Tensor Quaternion Module

PyTorch Models for Sequential Data

Quaternion Type


source

TensorQuaternionAngle

 TensorQuaternionAngle (x, **kwargs)

A Tensor which supports subclass pickling and maintains metadata when casting or after methods


source

TensorQuaternionInclination

 TensorQuaternionInclination (x, **kwargs)

A Tensor which supports subclass pickling and maintains metadata when casting or after methods

# Build example dataloaders from the HDF5 files in <project_root>/test_data/orientation.
from nbdev.config import get_config
project_root = get_config().config_file.parent
f_path = project_root / 'test_data/orientation'
hdf_files = get_hdf_files(f_path)
# Cut each file into windows (win_sz=1000, stp_sz=100 per the parameter names; window column 'acc_x').
tfm_src = CreateDict([DfHDFCreateWindows(win_sz=1000,stp_sz=100,clm='acc_x')])
# Input channels (presumably accelerometer acc_* and gyroscope gyr_* axes).
u = ['acc_x','acc_y','acc_z','gyr_x','gyr_y','gyr_z']
# u = ['acc_x','acc_y','acc_z','gyr_x','gyr_y','gyr_z','mag_x','mag_y','mag_z']
# Four target columns, loaded as TensorQuaternionInclination (one quaternion per sample, assumed).
y =['opt_a','opt_b','opt_c','opt_d']
# 10% random validation split, batch size 128.
dls = DataBlock(blocks=(SequenceBlock.from_hdf(u,TensorSequencesInput),
                        SequenceBlock.from_hdf(y,TensorQuaternionInclination)),
                get_items=tfm_src,
                splitter=RandomSplitter(0.1)
               ).dataloaders(hdf_files,shufflish=True,bs=128)

Basic Operations

# Two small example batches of quaternions (scalar-first: [1,0,0,0] acts as identity in multiplyQuat below).
tq1 = tensor([
    [1,0,0,0],
    [0.5,0.5,0.5,0.5],
    ])
tq2 = tensor([
    [0.5,0.5,0.5,0.5],
    [0.5,0.5,0.5,0.5],
    ])
tq1.shape
torch.Size([2, 4])

source

rad2deg

 rad2deg (t)
# pi radians must map to 180 degrees.
test_eq(float(rad2deg(_pi)),180)

`torch.jit.ScriptFunction object at 0x7f83233af100` (TorchScript-compiled)

quat1 * quat2

# Disabled timing experiments, kept for reference.
# q = tq1.repeat(1000,1)
# %%timeit
# torch.cumprod(q,dim=-1)
# %%timeit
# [q*q for _ in range(1000)]
# Quaternion product: identity * q == q, and (.5,.5,.5,.5) squared == (-.5,.5,.5,.5).
test_eq(multiplyQuat(tq1,tq2),tensor([[ 0.5000,  0.5000,  0.5000,  0.5000],
                                    [-0.5000,  0.5000,  0.5000,  0.5000]]))

source

norm_quaternion

 norm_quaternion (q)
# Normalising is scale-invariant (already-unit tq1 is recovered) and keeps an added leading batch dim.
test_eq(norm_quaternion(tq1*5),tq1)
test_eq(norm_quaternion(tq1/_pi),tq1)
test_eq(norm_quaternion(tq1[None,...]),tq1[None,...])

source

conjQuat

 conjQuat (q)
# The conjugate keeps the scalar (first) component and negates the three vector components.
test_eq(conjQuat(tq1),tensor([[ 1.0000, -0.0000, -0.0000, -0.0000],
                             [ 0.5000, -0.5000, -0.5000, -0.5000]]))

`torch.jit.ScriptFunction object at 0x7f83234a8220` (TorchScript-compiled)

quat1 * inv(quat2)


source

diffQuat

 diffQuat (q1, q2, norm=True)
# With norm=True (default) the result does not depend on the scale of q2; with norm=False it does.
test_eq(diffQuat(tq1,tq2),diffQuat(tq1,tq2*5))
test_ne(diffQuat(tq1,tq2),diffQuat(tq1,tq2*5,norm=False))
# A leading batch dim changes the result (presumably the output keeps the extra dimension).
test_ne(diffQuat(tq1,tq2),diffQuat(tq1[None,...],tq2[None,...]))

source

safe_acos

 safe_acos (t, eps=4e-08)

Numerically stable variant of arccosine.

# acos(1) must not come out as exactly 0 (input presumably kept eps away from +-1), and acos(0) == pi/2.
test_ne(safe_acos(tensor(1.))*1e6,0)
test_eq(safe_acos(tensor(-0.)),_pi/2)

source

safe_acos_double

 safe_acos_double (t, eps=1e-16)

Numerically stable variant of arccosine. Uses 64-bit floats internally for increased accuracy and better gradient propagation (falls back to float32, with a warning, on devices without float64 support such as MPS).

# Same checks as for safe_acos, for the float64 variant.
test_ne(safe_acos_double(tensor(1.))*1e6,0)
test_eq(safe_acos_double(tensor(-0.)),_pi/2)

source

relativeAngle

 relativeAngle (q1, q2)

source

inclinationAngle

 inclinationAngle (q1, q2)
# Compare inclination error with the full relative rotation angle between tq1 and tq2, in degrees.
print('inclination:', rad2deg(inclinationAngle(tq1,tq2)))
print('relative:', rad2deg(relativeAngle(tq1,tq2)))
inclination: tensor([9.0000e+01, 1.7075e-06])
relative: tensor([1.2000e+02, 1.7075e-06])

source

pitchAngle

 pitchAngle (q1, q2)

source

rollAngle

 rollAngle (q1, q2)
# Roll and pitch angles between tq1 and tq2, in degrees.
print('roll:', rad2deg(rollAngle(tq1,tq2)))
print('pitch:', rad2deg(pitchAngle(tq1,tq2)))
roll: tensor([0., 0.])
pitch: tensor([-90.,   0.])

source

inclinationAngleAbs

 inclinationAngleAbs (q)
# Absolute inclination of each quaternion in tq1, in degrees (the identity quaternion gives 0).
rad2deg(inclinationAngleAbs(tq1))
tensor([ 0., 90.])

source

rand_quat

 rand_quat ()

source

rot_vec

 rot_vec (v, q)
# Rotate 5 identical vectors (9.81 along the first axis) by one random quaternion; the output is random.
g = tensor([[9.81,0,0]]*5)
r_quat = rand_quat()
rot_vec(g,r_quat)
tensor([[ 1.3754,  7.8102, -5.7746],
        [ 1.3754,  7.8102, -5.7746],
        [ 1.3754,  7.8102, -5.7746],
        [ 1.3754,  7.8102, -5.7746],
        [ 1.3754,  7.8102, -5.7746]])

source

quatFromAngleAxis

 quatFromAngleAxis (angle, axis)

source

quatInterp

 quatInterp (quat, ind, extend=False)

Interpolates an array of quaternions at (non-integer) indices using Slerp. Sampling indices are in the range 0..N-1; for values outside of this range, depending on “extend”, the first/last element or NaN is returned.

See also csg_bigdata.dp.utils.vecInterp.

- quat: array of input quaternions, shape (N(xB)x4)
- ind: vector containing the sampling indices, shape (M,)
- extend: if true, the input data is virtually extended by the first/last value
- returns: interpolated quaternions, shape (Mx4)

# 1e-3 rad expressed in degrees (~0.057 deg).
1e-3/_pi*180
tensor([0.0573])
# One million random quaternions, shape (N, 1, 4), components drawn uniformly from [-1, 1).
q = torch.rand((1000000,1,4))*2-1
# Normalise each quaternion along the last dim so that every row is a unit quaternion.
# The previous `q /= q.norm()` divided by the norm of the *whole* tensor,
# shrinking each quaternion to a length of roughly 1e-3 instead of 1.
q /= q.norm(dim=-1, keepdim=True)
# 150001 evenly spaced fractional sampling indices covering 0..N-1.
x = torch.linspace(0,q.shape[0]-1,150001)
# Slerp-interpolate at the fractional indices and count NaNs in the resulting inclination angles.
q_i = quatInterp(q,x)
torch.isnan(inclinationAngleAbs(q_i)).sum()
tensor(0)
# Recompute the interpolation (keeps this cell self-contained), then plot the original absolute
# inclination per index against the interpolated one at the fractional indices x.
q_i = quatInterp(q,x)
plt.figure()
plt.plot(inclinationAngleAbs(q))
plt.plot(x,inclinationAngleAbs(q_i))

Loss Functions


source

inclination_loss

 inclination_loss (q1, q2)
# Loss value on the example batches tq1/tq2.
inclination_loss(tq1,tq2)
tensor(0.2071)

source

inclination_loss_abs

 inclination_loss_abs (q1, q2)
# Loss value on the example batches tq1/tq2.
inclination_loss_abs(tq1,tq2)
tensor(0.1464)

source

inclination_loss_squared

 inclination_loss_squared (q1, q2)
# %%timeit
# Loss value on the example batches tq1/tq2 (timing magic disabled above).
inclination_loss_squared(tq1,tq2)
tensor(0.0429)

source

inclination_loss_smooth

 inclination_loss_smooth (q1, q2)
# %%timeit
# Loss value on the example batches tq1/tq2 (timing magic disabled above).
inclination_loss_smooth(tq1,tq2)
tensor(0.0214)

source

abs_inclination

 abs_inclination (q1, q2)
# Mean absolute inclination error in radians: mean(pi/2, ~0) ~= pi/4 for tq1/tq2.
abs_inclination(tq1,tq2)
tensor(0.7854)

source

ms_inclination

 ms_inclination (q1, q2)
# Mean squared inclination error in rad^2: ((pi/2)^2 + ~0)/2 ~= 1.2337 for tq1/tq2.
ms_inclination(tq1,tq2)
tensor(1.2337)

source

rms_inclination

 rms_inclination (q1, q2)
# Root mean squared inclination error in radians (~ (pi/2)/sqrt(2) for tq1/tq2).
rms_inclination(tq1,tq2)
tensor(1.1107)

source

smooth_inclination

 smooth_inclination (q1, q2)
# Metric value on the example batches tq1/tq2.
smooth_inclination(tq1,tq2)
tensor(0.5354)

source

rms_inclination_deg

 rms_inclination_deg (q1, q2)
# RMS inclination error in degrees: sqrt((90^2 + ~0)/2) ~= 63.64 for tq1/tq2.
rms_inclination_deg(tq1,tq2)
tensor([63.6396])

source

rms_pitch_deg

 rms_pitch_deg (q1, q2)
# RMS pitch error in degrees: pitch angles are (-90, 0), giving ~63.64.
rms_pitch_deg(tq1,tq2)
tensor([63.6396])

source

rms_roll_deg

 rms_roll_deg (q1, q2)
# RMS roll error in degrees: roll angles are (0, 0), giving 0.
rms_roll_deg(tq1,tq2)
tensor([0.])

source

mean_inclination_deg

 mean_inclination_deg (q1, q2)
# Mean inclination error in degrees: mean(90, ~0) = 45 for tq1/tq2.
mean_inclination_deg(tq1,tq2)
tensor([45.])

source

angle_loss

 angle_loss (q1, q2)

source

angle_loss_opt

 angle_loss_opt (q1, q2)

source

ms_rel_angle

 ms_rel_angle (q1, q2)
# Mean squared relative rotation angle in rad^2: ((2*pi/3)^2 + ~0)/2 ~= 2.1932 for tq1/tq2.
ms_rel_angle(tq1,tq2)
tensor(2.1932)

source

abs_rel_angle

 abs_rel_angle (q1, q2)

source

rms_rel_angle_deg

 rms_rel_angle_deg (q1, q2)
# RMS relative rotation angle in degrees: sqrt((120^2 + ~0)/2) ~= 84.85 for tq1/tq2.
rms_rel_angle_deg(tq1,tq2)
tensor([84.8528])

source

mean_rel_angle_deg

 mean_rel_angle_deg (q1, q2)
# Mean relative rotation angle in degrees: mean(120, ~0) = 60 for tq1/tq2.
mean_rel_angle_deg(tq1,tq2)
tensor([60.])

source

deg_rmse

 deg_rmse (inp, targ)

Callbacks

To ensure that the model outputs stay close to unit quaternions, their distance from unit norm is added to the loss.


source

QuaternionRegularizer

 QuaternionRegularizer (reg_unit=0.0, detach=False, modules=None,
                        every=None, remove_end=True, is_forward=True,
                        cpu=True, include_paramless=False, hook=None)

Callback that adds AR and TAR to the loss, calculated from the output of the provided layer.


source

augmentation_groups

 augmentation_groups (u_groups)

Returns the rotation list corresponding to the input groups.

# Two groups of three channels map to the inclusive index ranges [0,2] and [3,5].
u_raw_groups = [3,3]
test_eq(augmentation_groups(u_raw_groups),[[0,2],[3,5]])

source

QuaternionAugmentation

 QuaternionAugmentation (inp_groups, **kwargs)

A transform that calls `before_call` to update its state at each `__call__`.

# Build a TCN learner on the quaternion-inclination dataloaders defined above.
n_skip = 2**8


inp,out = get_inp_out_size(dls)
# model = SimpleGRU(inp,out,num_layers=1,hidden_size=100)
model = TCN(inp,out,hl_depth=8,hl_width=10)

# NOTE(review): `skip` is not used by any visible cell; kept in case later cells rely on it.
skip = partial(SkipNLoss,n_skip=n_skip)
metrics=rms_inclination_deg
# Regularizer that pushes the model outputs towards unit quaternions.
cbs = [QuaternionRegularizer(reg_unit=1,modules=[model])]

# `cbs` must be handed to the Learner: previously the list was built but never passed,
# so the unit-norm regularizer was silently inactive during training.
lrn = Learner(dls,model,loss_func=ms_inclination,opt_func=ranger,metrics=metrics,cbs=cbs)
/opt/homebrew/Caskroom/miniforge/base/envs/env_tsfast/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
# Train for one epoch with the one-cycle schedule, peak learning rate 3e-3.
lrn.fit_one_cycle(1,lr_max=3e-3)
epoch train_loss valid_loss rms_inclination_deg time
0 2.289876 1.290059 65.076988 00:01
/var/folders/pc/13zbh_m514n1tp522cx9npt00000gn/T/ipykernel_31206/3976315585.py:9: UserWarning: Float64 precision not supported on mps:0 device. Falling back to float32. This may reduce numerical accuracy of quaternion operations. Error: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  warnings.warn(f"Float64 precision not supported on {t.device} device. Falling back to float32. This may reduce numerical accuracy of quaternion operations. Error: {e}")

Resampling Model


source

Quaternion_ResamplingModel

 Quaternion_ResamplingModel (model, fs_targ, fs_mean=0, fs_std=1,
                             quaternion_sampling=True)

Module that resamples the signal before and after the prediction of its model. Useful for applying models to datasets with different sampling rates.

sampling_method: method used for resampling [‘resample’,‘interpolate’]

# Disabled example: train a TCN wrapped in Quaternion_ResamplingModel, with 'dt' as an extra
# input column and a split on files whose path contains 'experiment2'.
# dls = DataBlock(blocks=(SequenceBlock.from_hdf(u + ['dt'],TensorSequencesInput),
#                         SequenceBlock.from_hdf(y,TensorQuaternionInclination)),
#                 get_items=tfm_src,
#                 splitter=ApplyToDict(FuncSplitter(lambda o: 'experiment2' in str(o)))
#                ).dataloaders(hdf_files,shufflish=True,bs=128)
# model = TCN(inp,out,hl_depth=8,hl_width=10)
# Learner(dls,Quaternion_ResamplingModel(model,10,quaternion_sampling=False),loss_func=ms_inclination).fit(1)

Quaternion Datablock


source

HDF2Quaternion

 HDF2Quaternion (clm_names, clm_shift=None, truncate_sz=None,
                 to_cls=<function noop>, cached=True, fs_idx=None,
                 dt_idx=None, fast_resample=True)

Delegates (__call__,decode,setup) to (encodes,decodes,setups) if split_idx matches


source

QuaternionBlock

 QuaternionBlock (seq_extract, padding=False)

A basic wrapper that links default transforms for the data block API.

# Quaternion-target dataloaders with resampling. DfResamplingFactor is presumably given the source
# rate (2000/7, assumed Hz) and 10 target rates from 50 to 500 (np.linspace); verify against its definition.
tfm_src = CreateDict([DfResamplingFactor(2000/7,np.linspace(50,500,10)),DfHDFCreateWindows(win_sz=1000,stp_sz=100,clm='acc_x')])
# 50/50 random split, batch size 2.
dls = DataBlock(blocks=(SequenceBlock.from_hdf(u,TensorSequencesInput),
                        QuaternionBlock.from_hdf(y)),
                get_items=tfm_src,
                splitter=RandomSplitter(0.5)
               ).dataloaders(hdf_files,bs=2)
#test_eq(len(dls.items),83877)

Inclination Datablock


source

TensorInclination

 TensorInclination (x, **kwargs)

A Tensor which supports subclass pickling and maintains metadata when casting or after methods


source

HDF2Inclination

 HDF2Inclination (clm_names, clm_shift=None, truncate_sz=None,
                  to_cls=<function noop>, cached=True, fs_idx=None,
                  dt_idx=None, fast_resample=True)

Delegates (__call__,decode,setup) to (encodes,decodes,setups) if split_idx matches


source

InclinationBlock

 InclinationBlock (seq_extract, padding=False)

A basic wrapper that links default transforms for the data block API.

# Disabled example: InclinationBlock targets on a dataset at a local absolute path
# (not part of the project's test_data), split on files whose path contains 'experiment2'.
# f_paths = '/mnt/Data/Systemidentification/Orientation_Estimation/'
# hdf_files = get_hdf_files(f_paths)
# tfm_src = CreateDict([DfHDFCreateWindows(win_sz=1000,stp_sz=100,clm='acc_x')])
# u = ['acc_x','acc_y','acc_z','gyr_x','gyr_y','gyr_z']
# # u = ['acc_x','acc_y','acc_z','gyr_x','gyr_y','gyr_z','mag_x','mag_y','mag_z']
# y =['opt_a','opt_b','opt_c','opt_d']
# dls = DataBlock(blocks=(SequenceBlock.from_hdf(u),
#                         InclinationBlock.from_hdf(y)),
#                 get_items=tfm_src,
#                 splitter=ApplyToDict(FuncSplitter(lambda o: 'experiment2' in str(o)))
#                ).dataloaders(hdf_files,shufflish=True,bs=128)

Show Results


source

plot_scalar_inclination

 plot_scalar_inclination (axs, in_sig, targ_sig, out_sig=None, **kwargs)

source

plot_quaternion_inclination

 plot_quaternion_inclination (axs, in_sig, targ_sig, out_sig=None,
                              **kwargs)

source

plot_quaternion_rel_angle

 plot_quaternion_rel_angle (axs, in_sig, targ_sig, out_sig=None, **kwargs)
# Plot three samples from dataset index 0 (the training split in fastai's convention).
dls.show_batch(max_n=3,ds_idx=0)

# Plot predictions vs. targets for three shuffled training samples; quat=True is a project-specific
# option (presumably selects the quaternion plot functions above).
lrn.show_results(max_n=3,ds_idx=0,shuffle=True,quat=True)
/var/folders/pc/13zbh_m514n1tp522cx9npt00000gn/T/ipykernel_31206/3976315585.py:9: UserWarning: Float64 precision not supported on mps:0 device. Falling back to float32. This may reduce numerical accuracy of quaternion operations. Error: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
  warnings.warn(f"Float64 precision not supported on {t.device} device. Falling back to float32. This may reduce numerical accuracy of quaternion operations. Error: {e}")