Skip to content

Quaternions

Quaternion math, loss functions, and augmentations.

QuaternionRegularizer

QuaternionRegularizer(modules: list, reg_unit: float = 0.0)

Regularization loss that penalizes non-unit quaternion outputs.

Parameters:

Name Type Description Default
modules list

list of nn.Module instances whose outputs are captured via hooks.

required
reg_unit float

weight for the unit-norm regularization term.

0.0
Source code in tsfast/quaternions/aux_losses.py
def __init__(self, modules: list, reg_unit: float = 0.0):
    """Store the hooked modules and the unit-norm penalty weight.

    Args:
        modules: modules whose forward outputs ``setup`` will hook.
        reg_unit: weight of the unit-norm term; ``0.0`` makes ``__call__``
            return a zero loss.
    """
    # Handles from register_forward_hook; teardown removes and clears them.
    self._hooks: list = []
    # Tensor read by __call__; presumably filled by the forward hook
    # (`_hook_fn`, not shown here) — None until something is captured.
    self._captured: torch.Tensor | None = None
    self.modules = modules
    self.reg_unit = reg_unit

setup

setup(trainer)

Register forward hooks on the target modules.

Source code in tsfast/quaternions/aux_losses.py
def setup(self, trainer):
    """Register forward hooks on the target modules."""
    for m in self.modules:
        self._hooks.append(m.register_forward_hook(self._hook_fn))

teardown

teardown(trainer)

Remove all registered hooks.

Source code in tsfast/quaternions/aux_losses.py
def teardown(self, trainer):
    """Remove all registered hooks."""
    for h in self._hooks:
        h.remove()
    self._hooks.clear()

__call__

__call__(pred: Tensor, yb: Tensor, xb: Tensor) -> torch.Tensor

Compute unit-norm regularization loss from captured hook output.

Source code in tsfast/quaternions/aux_losses.py
def __call__(self, pred: torch.Tensor, yb: torch.Tensor, xb: torch.Tensor) -> torch.Tensor:
    """Penalize deviation of the captured output from unit L2 norm.

    Returns ``reg_unit * mean((1 - ||h||)**2)``, where ``h`` is the captured
    tensor cast to float and the norm is taken over the last dimension. If
    nothing has been captured, or ``reg_unit`` is 0.0, a zero scalar is
    returned on ``pred``'s device. ``yb`` and ``xb`` are unused.
    """
    if self._captured is not None and self.reg_unit != 0.0:
        norms = self._captured.float().norm(dim=-1)
        return float(self.reg_unit) * ((1 - norms) ** 2).mean()
    return torch.tensor(0.0, device=pred.device)

QuaternionAugmentation

QuaternionAugmentation(inp_groups: list[list[int]])

Apply random quaternion rotation to grouped signals during training.

Each call samples a new random quaternion and applies it to all specified signal groups. Groups of size 4 are rotated as quaternions, groups of size 3 are rotated as vectors.

Parameters:

Name Type Description Default
inp_groups list[list[int]]

list of [start, end] index pairs defining signal groups (groups of size 4 are rotated as quaternions, size 3 as vectors).

required
Source code in tsfast/quaternions/transforms.py
def __init__(self, inp_groups: list[list[int]]):
    """Store and validate the channel groups to be rotated.

    Args:
        inp_groups: list of inclusive ``[start, end]`` channel index pairs.
            Each group must span exactly 3 channels (rotated as a vector) or
            4 channels (rotated as a quaternion).

    Raises:
        AttributeError: if a group spans a number of channels other than 3
            or 4. The exception type is kept for backward compatibility.
    """
    self.inp_groups = inp_groups
    for g in inp_groups:
        group_len = g[1] - g[0] + 1
        if group_len not in (3, 4):
            raise AttributeError(
                f"augmentation group {g} spans {group_len} channels; "
                "expected 3 (vector) or 4 (quaternion)"
            )

__call__

__call__(xb: Tensor, yb: Tensor) -> tuple[torch.Tensor, torch.Tensor]

Apply random quaternion rotation augmentation.

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of (augmented xb, augmented yb).

Source code in tsfast/quaternions/transforms.py
def __call__(self, xb: torch.Tensor, yb: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply one random quaternion rotation to the input groups and the target.

    A single quaternion from ``rand_quat()`` is used for the whole call:
    3-channel groups go through ``rot_vec`` and 4-channel groups are
    multiplied by it with ``multiplyQuat``. The target ``yb`` is multiplied
    by the same quaternion.

    ``xb`` is cloned before its groups are overwritten. This leaves the
    caller's tensor unmodified, matching how ``yb`` is treated (it is
    reassigned, never written in place).

    Returns:
        Tuple of (augmented xb, augmented yb).
    """
    r_quat = rand_quat()
    xb = xb.clone()

    # Augment input groups (inclusive [start, end] channel ranges)
    for g in self.inp_groups:
        channels = slice(g[0], g[1] + 1)
        group = xb[..., channels]
        if group.shape[-1] == 3:
            xb[..., channels] = rot_vec(group, r_quat)
        else:
            xb[..., channels] = multiplyQuat(group, r_quat)

    # Augment target (quaternion rotation)
    yb = multiplyQuat(yb, r_quat)

    return xb, yb

conjQuat

conjQuat(q: Tensor) -> torch.Tensor

Compute the conjugate of a quaternion.

Source code in tsfast/quaternions/ops.py
def conjQuat(q: torch.Tensor) -> torch.Tensor:
    """Compute the conjugate of a quaternion.

    Multiplies ``q`` element-wise by the module constant
    ``_conjugate_quaternion``, moved to ``q``'s device and dtype.
    NOTE(review): that constant is presumably ``[1, -1, -1, -1]`` (negating
    the vector part); it is defined outside this snippet.
    """
    signs = _conjugate_quaternion.to(q.device).type(q.dtype)
    return q * signs

diffQuat

diffQuat(q1: Tensor, q2: Tensor, norm: bool = True) -> torch.Tensor

Compute the difference quaternion between q1 and q2.

Parameters:

Name Type Description Default
q1 Tensor

first quaternion tensor.

required
q2 Tensor

second quaternion tensor.

required
norm bool

whether to normalize inputs before computing the difference.

True
Source code in tsfast/quaternions/ops.py
def diffQuat(q1: torch.Tensor, q2: torch.Tensor, norm: bool = True) -> torch.Tensor:
    """Compute the difference quaternion between q1 and q2.

    Delegates to ``relativeQuat(q1, q2)``, optionally normalizing both
    inputs to unit length first.

    Args:
        q1: first quaternion tensor.
        q2: second quaternion tensor.
        norm: whether to normalize inputs before computing the difference.
    """
    if not norm:
        return relativeQuat(q1, q2)
    unit_q1 = norm_quaternion(q1)
    unit_q2 = norm_quaternion(q2)
    return relativeQuat(unit_q1, unit_q2)

inclinationAngle

inclinationAngle(q1: Tensor, q2: Tensor) -> torch.Tensor

Inclination (tilt) angle between two quaternions.

Uses atan2 instead of acos for numerical stability.

Source code in tsfast/quaternions/ops.py
def inclinationAngle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Inclination (tilt) angle between two quaternions, in radians.

    Computed from the normalized difference quaternion as
    ``2 * atan2(sqrt(x^2 + y^2), sqrt(w^2 + z^2))``. The z (heading)
    component is grouped with w, so rotation about z does not count as tilt.

    Uses ``atan2`` instead of ``acos`` for numerical stability.
    """
    d = diffQuat(q1, q2)
    tilt_part = (d[..., 1] ** 2 + d[..., 2] ** 2).sqrt()
    heading_part = (d[..., 0] ** 2 + d[..., 3] ** 2).sqrt()
    return 2 * torch.atan2(tilt_part, heading_part)

inclinationAngleAbs

inclinationAngleAbs(q: Tensor) -> torch.Tensor

Absolute inclination angle relative to the identity quaternion.

Uses atan2 instead of acos for numerical stability.

Source code in tsfast/quaternions/ops.py
def inclinationAngleAbs(q: torch.Tensor) -> torch.Tensor:
    """Absolute inclination angle relative to the identity quaternion.

    Equivalent to ``inclinationAngle(q, identity)``, where the identity is
    the module constant ``_unit_quaternion`` moved to ``q``'s device.

    Uses ``atan2`` instead of ``acos`` for numerical stability.
    """
    identity = _unit_quaternion[None, :].to(q.device)
    return inclinationAngle(q, identity)

multiplyQuat

multiplyQuat(q1: Tensor, q2: Tensor) -> torch.Tensor

Multiply two quaternions element-wise.

Source code in tsfast/quaternions/ops.py
def multiplyQuat(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Multiply two quaternions element-wise (Hamilton product ``q1 * q2``).

    Quaternions are stored scalar-first (``[w, x, y, z]``) along the last
    dimension. Leading dimensions broadcast under normal torch rules.
    """
    w1, x1, y1, z1 = q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-1)
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack([w, x, y, z], dim=-1)

norm_quaternion

norm_quaternion(q: Tensor) -> torch.Tensor

Normalize quaternions to unit norm.

Source code in tsfast/quaternions/ops.py
def norm_quaternion(q: torch.Tensor) -> torch.Tensor:
    """Normalize quaternions to unit norm.

    Divides by the L2 norm over the last dimension. A zero quaternion yields
    NaNs, because no epsilon guard is applied.
    """
    magnitude = q.norm(p=2, dim=-1, keepdim=True)
    return q / magnitude

pitchAngle

pitchAngle(q1: Tensor, q2: Tensor) -> torch.Tensor

Euler pitch angle of the difference quaternion.

Uses atan2(sin, cos) instead of asin for numerical stability near gimbal lock (+/-pi/2).

Source code in tsfast/quaternions/ops.py
def pitchAngle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Euler pitch angle of the difference quaternion, in radians.

    ``sin(pitch) = 2 (w*y - z*x)`` is clamped to [-1, 1]. The cosine is
    rebuilt as ``sqrt(1 - sin^2)``, and the angle is obtained with
    ``atan2(sin, cos)`` instead of ``asin`` for numerical stability near
    gimbal lock (``+/-pi/2``).
    """
    d = diffQuat(q1, q2)
    w, x, y, z = d[..., 0], d[..., 1], d[..., 2], d[..., 3]
    sine = (2.0 * (w * y - z * x)).clamp(-1.0, 1.0)
    cosine = (1.0 - sine**2).sqrt()
    return torch.atan2(sine, cosine)

quatFromAngleAxis

quatFromAngleAxis(angle: Tensor, axis: Tensor) -> torch.Tensor

Create quaternions from angle-axis representation.

Parameters:

Name Type Description Default
angle Tensor

rotation angles, shape (N,) or (1,).

required
axis Tensor

rotation axes, shape (3,) or (N, 3) or (1, 3).

required
Source code in tsfast/quaternions/ops.py
def quatFromAngleAxis(angle: torch.Tensor, axis: torch.Tensor) -> torch.Tensor:
    """Create quaternions from angle-axis representation.

    Axes need not be normalized; each axis is divided by its L2 norm. An
    exactly-zero axis produces the identity quaternion ``[1, 0, 0, 0]``
    instead of NaNs. This matches ``quatFromAngleAxis_np``.

    Args:
        angle: rotation angles in radians, shape (N,) or (1,). A (1,) angle
            is broadcast over all N axes.
        axis: rotation axes, shape (3,) or (N, 3) or (1, 3).

    Returns:
        Quaternions ``[w, x, y, z]``, shape (N, 4).

    Raises:
        AssertionError: if ``angle``/``axis`` shapes are incompatible.
    """
    if len(axis.shape) == 2:
        N = max(angle.shape[0], axis.shape[0])
        assert angle.shape in ((1,), (N,))
        assert axis.shape == (N, 3) or axis.shape == (1, 3)
    else:
        N = angle.shape[0]
        assert angle.shape == (N,)
        assert axis.shape == (3,)
        axis = axis[None, :]

    norm = torch.norm(axis, dim=1)[:, None]
    zero_axis = norm == 0
    # divide by 1 instead of 0 for degenerate axes; those rows are replaced below
    axis = axis / torch.where(zero_axis, torch.ones_like(norm), norm)

    # broadcast a (1,) angle to all N axes; torch.cat does not broadcast
    half_angle = (angle / 2).expand(N)
    quat = torch.cat([torch.cos(half_angle)[:, None], axis * torch.sin(half_angle)[:, None]], dim=-1)

    identity = torch.zeros_like(quat)
    identity[:, 0] = 1
    return torch.where(zero_axis, identity, quat)

quatInterp

quatInterp(quat: Tensor, ind: Tensor, extend: bool = False) -> torch.Tensor

Interpolate quaternions at non-integer indices using Slerp.

Sampling indices are in the range 0..N-1. For values outside this range, depending on extend, the first/last element or NaN is returned.

Parameters:

Name Type Description Default
quat Tensor

input quaternions, shape (N(xB)x4).

required
ind Tensor

sampling indices, shape (M,).

required
extend bool

if true, extend input by repeating first/last value.

False

Returns:

Type Description
Tensor

Interpolated quaternions, shape (Mx4).

Source code in tsfast/quaternions/ops.py
def quatInterp(quat: torch.Tensor, ind: torch.Tensor, extend: bool = False) -> torch.Tensor:
    """Interpolate quaternions at non-integer indices using Slerp.

    Sampling indices are in the range 0..N-1. For values outside this range,
    depending on ``extend``, the first/last element or NaN is returned.

    For a fractional index ``i = i0 + t`` the result is
    ``(q[i1] * inv(q[i0]))**t * q[i0]``. This equals ``q[i0]`` at ``t = 0``
    and moves along the shortest arc towards ``q[i1]`` as ``t`` grows.

    Args:
        quat: input quaternions, shape (N(xB)x4).
        ind: sampling indices, shape (M,).
        extend: if true, extend input by repeating first/last value.

    Returns:
        Interpolated quaternions, shape (Mx4), or (MxBx4) for batched input,
        cast back to ``quat``'s dtype.
    """
    N = quat.shape[0]
    M = ind.shape[0]
    assert quat.shape[-1] == 4
    assert ind.shape == (M,)

    ind = ind.to(quat.device)
    ind0 = torch.clamp(torch.floor(ind).type(torch.long), 0, N - 1)
    ind1 = torch.clamp(torch.ceil(ind).type(torch.long), 0, N - 1)

    # the interpolation itself is carried out in float64
    q0 = quat[ind0].type(torch.float64)
    q1 = quat[ind1].type(torch.float64)
    # diffQuat(a, b) == a * inv(b) (normalized), so this is the rotation that
    # takes q0 to q1 when left-multiplied: q_1_0 * q0 == q1
    q_1_0 = diffQuat(q1, q0)

    # normalize the quaternion for positive w component to ensure
    # that the angle will be [0, 180deg]
    invert_sign_ind = q_1_0[..., 0] < 0
    q_1_0[invert_sign_ind] = -q_1_0[invert_sign_ind]

    angle = 2 * torch.atan2(q_1_0[..., 1:].norm(dim=-1), q_1_0[..., 0])
    axis = q_1_0[..., 1:]

    # copy over (almost) direct hits
    direct_ind = angle < 1e-06
    quat_out = torch.empty_like(q0)
    quat_out[direct_ind] = q0[direct_ind]

    interp_ind = ~direct_ind
    t01 = ind - ind0
    if len(quat.shape) == 3:
        t01 = t01[:, None]  # extend shape if batches are part of the tensor
    q_t_0 = quatFromAngleAxis((t01 * angle)[interp_ind], axis[interp_ind])
    # apply the fractional relative rotation on the left of q0 (see q_1_0 above)
    quat_out[interp_ind] = multiplyQuat(q_t_0, q0[interp_ind])

    if not extend:
        quat_out[ind < 0] = np.nan
        quat_out[ind > N - 1] = np.nan

    return quat_out.type_as(quat)

rad2deg

rad2deg(t: Tensor) -> torch.Tensor

Convert radians to degrees.

Source code in tsfast/quaternions/ops.py
def rad2deg(t: torch.Tensor) -> torch.Tensor:
    """Convert radians to degrees.

    Uses the module constant ``_pi``, moved to ``t``'s device and dtype.
    """
    pi = _pi.to(t.device).type(t.dtype)
    return 180.0 * t / pi

rand_quat

rand_quat() -> torch.Tensor

Generate a random unit quaternion.

Source code in tsfast/quaternions/ops.py
def rand_quat() -> torch.Tensor:
    """Generate a random unit quaternion, uniformly distributed over rotations.

    A 4-D standard-normal sample is isotropic, so normalizing it gives a
    point uniformly distributed on the unit 3-sphere, i.e. a uniformly random
    rotation. Normalizing a sample from the cube ``[-1, 1]^4`` instead would
    over-weight directions towards the cube's corners.

    Returns:
        Float32 tensor of shape (4,) with unit L2 norm.
    """
    q = torch.randn(4)
    return q / q.norm()

relativeAngle

relativeAngle(q1: Tensor, q2: Tensor) -> torch.Tensor

Full rotation angle between two quaternions.

Uses atan2 instead of acos for numerical stability.

Source code in tsfast/quaternions/ops.py
def relativeAngle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Full rotation angle between two quaternions, in radians.

    Taking ``|w|`` of the difference quaternion makes ``q`` and ``-q`` yield
    the same angle, so the result lies in ``[0, pi]``.

    Uses ``atan2`` instead of ``acos`` for numerical stability.
    """
    d = diffQuat(q1, q2)
    vector_norm = d[..., 1:].norm(dim=-1)
    scalar_abs = d[..., 0].abs()
    return 2 * torch.atan2(vector_norm, scalar_abs)

relativeQuat

relativeQuat(q1: Tensor, q2: Tensor) -> torch.Tensor

Compute the relative quaternion as quat1*inv(quat2).

Source code in tsfast/quaternions/ops.py
def relativeQuat(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute the relative quaternion as quat1*inv(quat2).

    This is the Hamilton product ``q1 * conj(q2)`` written out component-wise
    (scalar-first layout). It equals ``q1 * inv(q2)`` for unit ``q2``.
    """
    w1, x1, y1, z1 = q1.unbind(-1)
    w2, x2, y2, z2 = q2.unbind(-1)
    w = w1 * w2 + x1 * x2 + y1 * y2 + z1 * z2
    x = -w1 * x2 + x1 * w2 - y1 * z2 + z1 * y2
    y = -w1 * y2 + x1 * z2 + y1 * w2 - z1 * x2
    z = -w1 * z2 - x1 * y2 + y1 * x2 + z1 * w2
    return torch.stack([w, x, y, z], dim=-1)

rollAngle

rollAngle(q1: Tensor, q2: Tensor) -> torch.Tensor

Compute the roll angle between two quaternions.

Source code in tsfast/quaternions/ops.py
def rollAngle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Compute the roll angle between two quaternions, in radians.

    Roll of the difference quaternion:
    ``atan2(2 (w*x + y*z), 1 - 2 (x^2 + y^2))``.
    """
    d = diffQuat(q1, q2)
    w, x, y, z = d[..., 0], d[..., 1], d[..., 2], d[..., 3]
    numerator = 2.0 * (w * x + y * z)
    denominator = 1.0 - 2.0 * (x * x + y * y)
    return torch.atan2(numerator, denominator)

rot_vec

rot_vec(v: Tensor, q: Tensor) -> torch.Tensor

Rotate a 3D vector by a quaternion.

Source code in tsfast/quaternions/ops.py
def rot_vec(v: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Rotate a 3D vector by a quaternion.

    ``v`` is embedded as a pure quaternion ``[0, v]`` (zero scalar prepended
    along the last dimension). The function returns the vector part of
    ``conj(q) * [0, v] * q``.
    NOTE(review): for unit ``q`` this is the inverse of the active
    ``q * v * conj(q)`` rotation, i.e. a frame change; confirm this
    convention is what callers expect.
    """
    as_quat = F.pad(v, (1, 0), "constant", 0)
    rotated = multiplyQuat(conjQuat(q), multiplyQuat(as_quat, q))
    return rotated[..., 1:]

safe_acos

safe_acos(t: Tensor, eps: float = 4e-08) -> torch.Tensor

Numerically stable variant of arccosine.

Source code in tsfast/quaternions/ops.py
def safe_acos(t: torch.Tensor, eps: float = 4e-8) -> torch.Tensor:
    """Numerically stable variant of arccosine.

    Clamps the input to ``[-1 + eps, 1 - eps]`` before ``acos``. This keeps
    the result finite for out-of-range inputs and keeps the gradient away
    from the +/-1 singularities.
    """
    bounded = torch.clamp(t, -1.0 + eps, 1.0 - eps)
    return torch.acos(bounded)

safe_acos_double

safe_acos_double(t: Tensor, eps: float = 1e-16) -> torch.Tensor

Numerically stable arccosine using float64 internally for accuracy.

Source code in tsfast/quaternions/ops.py
def safe_acos_double(t: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    """Numerically stable arccosine using float64 internally for accuracy.

    The input is cast to float64, clamped to ``[-1 + eps, 1 - eps]``, passed
    through ``acos``, and cast back to ``t``'s dtype. If the float64 cast
    raises ``TypeError`` (device without float64 support), a warning is
    emitted and a float32 path with a fixed ``1e-6`` margin is used instead.

    NOTE(review): with the default ``eps`` the lower bound ``-1.0 + 1e-16``
    rounds to exactly ``-1.0`` in float64, so only the upper bound is
    actually shifted.
    """
    try:
        wide = t.type(torch.float64)
        clamped = wide.clamp(-1.0 + eps, 1.0 - eps)
        return clamped.acos().type(t.dtype)
    except TypeError as e:
        warnings.warn(
            f"Float64 precision not supported on {t.device} device. Falling back to float32. This may reduce numerical accuracy of quaternion operations. Error: {e}"
        )
        return t.clamp(-1.0 + 1e-6, 1.0 - 1e-6).acos()

abs_inclination

abs_inclination(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean absolute inclination angle.

Source code in tsfast/quaternions/losses.py
def abs_inclination(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean absolute inclination angle (radians) over all elements."""
    per_element = inclination_angle(q1, q2)
    return per_element.abs().mean()

abs_rel_angle

abs_rel_angle(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean absolute relative angle.

Source code in tsfast/quaternions/losses.py
def abs_rel_angle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean absolute relative angle (radians) over all elements."""
    per_element = rel_angle(q1, q2)
    return per_element.abs().mean()

angle_loss

angle_loss(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element absolute angle error from difference quaternion w component.

Source code in tsfast/quaternions/losses.py
def angle_loss(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element absolute angle error from difference quaternion w component.

    Returns ``|w - 1|``, where ``w`` is the scalar part of the normalized
    difference quaternion.
    NOTE(review): the signed ``w`` is used, so a difference of ``-identity``
    (the same rotation) gives 2 rather than 0. Confirm that predictions and
    targets are kept in the same hemisphere.
    """
    w = diffQuat(q1, q2)[..., 0]
    return (w - 1).abs()

angle_loss_opt

angle_loss_opt(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element absolute angle error (optimized, no full quaternion multiply).

Source code in tsfast/quaternions/losses.py
def angle_loss_opt(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element absolute angle error (optimized, no full quaternion multiply).

    Computes only the scalar part of ``norm(q1) * conj(norm(q2))`` and
    returns ``|w - 1|``.
    """
    a = norm_quaternion(q1)
    b = conjQuat(norm_quaternion(q2))
    w = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1] - a[..., 2] * b[..., 2] - a[..., 3] * b[..., 3]
    return (w - 1).abs()

deg_rmse

deg_rmse(inp: Tensor, targ: Tensor) -> torch.Tensor

RMSE metric converted to degrees.

Source code in tsfast/quaternions/losses.py
def deg_rmse(inp: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """RMSE metric converted to degrees.

    ``fun_rmse`` is imported lazily from the sibling ``training`` package.
    """
    from ..training import fun_rmse

    rmse_rad = fun_rmse(inp, targ)
    return rad2deg(rmse_rad)

inclination_angle

inclination_angle(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element inclination angle.

Source code in tsfast/quaternions/losses.py
def inclination_angle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element inclination angle in radians (alias of ``inclinationAngle``)."""
    angle = inclinationAngle(q1, q2)
    return angle

inclination_error

inclination_error(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element inclination error from difference quaternion.

Source code in tsfast/quaternions/losses.py
def inclination_error(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element inclination error from difference quaternion.

    Returns ``sqrt(z^2 + w^2) - 1`` of the normalized difference quaternion.
    This is 0 for a pure heading (z) rotation and negative otherwise.
    """
    d = diffQuat(q1, q2)
    w, z = d[..., 0], d[..., 3]
    return (z**2 + w**2).sqrt() - 1

inclination_loss

inclination_loss(q1: Tensor, q2: Tensor) -> torch.Tensor

RMS inclination error.

Source code in tsfast/quaternions/losses.py
def inclination_loss(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Root-mean-square of ``inclination_error`` over all elements."""
    err = inclination_error(q1, q2)
    return err.pow(2).mean().sqrt()

inclination_loss_abs

inclination_loss_abs(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean absolute inclination error.

Source code in tsfast/quaternions/losses.py
def inclination_loss_abs(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean absolute ``inclination_error`` over all elements."""
    err = inclination_error(q1, q2)
    return err.abs().mean()

inclination_loss_smooth

inclination_loss_smooth(q1: Tensor, q2: Tensor) -> torch.Tensor

Smooth L1 inclination error.

Source code in tsfast/quaternions/losses.py
def inclination_loss_smooth(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Smooth L1 inclination error (``_smooth_l1`` applied to ``inclination_error``)."""
    err = inclination_error(q1, q2)
    return _smooth_l1(err)

inclination_loss_squared

inclination_loss_squared(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean squared inclination error.

Source code in tsfast/quaternions/losses.py
def inclination_loss_squared(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean squared ``inclination_error`` over all elements."""
    err = inclination_error(q1, q2)
    return err.pow(2).mean()

mean_inclination_deg

mean_inclination_deg(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean inclination angle in degrees.

Source code in tsfast/quaternions/losses.py
def mean_inclination_deg(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean inclination angle in degrees (converted per element, then averaged)."""
    degrees = rad2deg(inclination_angle(q1, q2))
    return degrees.mean()

mean_rel_angle_deg

mean_rel_angle_deg(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean relative angle in degrees.

Source code in tsfast/quaternions/losses.py
def mean_rel_angle_deg(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean relative angle in degrees (averaged in radians, then converted)."""
    mean_rad = rel_angle(q1, q2).mean()
    return rad2deg(mean_rad)

ms_inclination

ms_inclination(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean squared inclination angle.

Source code in tsfast/quaternions/losses.py
def ms_inclination(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean squared inclination angle (radians squared)."""
    angle = inclination_angle(q1, q2)
    return angle.pow(2).mean()

ms_rel_angle

ms_rel_angle(q1: Tensor, q2: Tensor) -> torch.Tensor

Mean squared relative angle.

Source code in tsfast/quaternions/losses.py
def ms_rel_angle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Mean squared relative angle (radians squared)."""
    angle = rel_angle(q1, q2)
    return angle.pow(2).mean()

pitch_angle

pitch_angle(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element pitch angle.

Source code in tsfast/quaternions/losses.py
def pitch_angle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element pitch angle in radians (alias of ``pitchAngle``)."""
    angle = pitchAngle(q1, q2)
    return angle

rel_angle

rel_angle(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element relative angle.

Source code in tsfast/quaternions/losses.py
def rel_angle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element relative angle in radians (alias of ``relativeAngle``)."""
    angle = relativeAngle(q1, q2)
    return angle

rms_inclination

rms_inclination(q1: Tensor, q2: Tensor) -> torch.Tensor

RMS inclination angle.

Source code in tsfast/quaternions/losses.py
def rms_inclination(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Root-mean-square inclination angle in radians."""
    angle = inclination_angle(q1, q2)
    return angle.pow(2).mean().sqrt()

rms_inclination_deg

rms_inclination_deg(q1: Tensor, q2: Tensor) -> torch.Tensor

RMS inclination angle in degrees.

Source code in tsfast/quaternions/losses.py
def rms_inclination_deg(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Root-mean-square inclination angle in degrees."""
    degrees = rad2deg(inclination_angle(q1, q2))
    squared_mean = degrees.pow(2).mean()
    return squared_mean.sqrt()

rms_pitch_deg

rms_pitch_deg(q1: Tensor, q2: Tensor) -> torch.Tensor

RMS pitch angle in degrees.

Source code in tsfast/quaternions/losses.py
def rms_pitch_deg(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """RMS pitch angle in degrees (via the module helper ``_deg_rms``)."""
    angle = pitch_angle(q1, q2)
    return _deg_rms(angle)

rms_rel_angle_deg

rms_rel_angle_deg(q1: Tensor, q2: Tensor) -> torch.Tensor

RMS relative angle in degrees.

Source code in tsfast/quaternions/losses.py
def rms_rel_angle_deg(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """RMS relative angle in degrees (via the module helper ``_deg_rms``)."""
    angle = rel_angle(q1, q2)
    return _deg_rms(angle)

rms_roll_deg

rms_roll_deg(q1: Tensor, q2: Tensor) -> torch.Tensor

RMS roll angle in degrees.

Source code in tsfast/quaternions/losses.py
def rms_roll_deg(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """RMS roll angle in degrees (via the module helper ``_deg_rms``)."""
    angle = roll_angle(q1, q2)
    return _deg_rms(angle)

roll_angle

roll_angle(q1: Tensor, q2: Tensor) -> torch.Tensor

Per-element roll angle.

Source code in tsfast/quaternions/losses.py
def roll_angle(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Per-element roll angle in radians (alias of ``rollAngle``)."""
    angle = rollAngle(q1, q2)
    return angle

smooth_inclination

smooth_inclination(q1: Tensor, q2: Tensor) -> torch.Tensor

Smooth L1 inclination angle.

Source code in tsfast/quaternions/losses.py
def smooth_inclination(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Smooth L1 of the inclination angle (``_smooth_l1`` of ``inclination_angle``)."""
    angle = inclination_angle(q1, q2)
    return _smooth_l1(angle)

augmentation_groups

augmentation_groups(u_groups: list[int]) -> list[list[int]]

Convert channel group sizes into start/end index pairs.

Parameters:

Name Type Description Default
u_groups list[int]

list of group sizes (number of channels per group).

required
Source code in tsfast/quaternions/transforms.py
def augmentation_groups(u_groups: list[int]) -> list[list[int]]:
    """Convert channel group sizes into start/end index pairs.

    Consecutive groups are laid out back to back starting at channel 0. Each
    pair is inclusive, e.g. sizes ``[3, 4]`` become ``[[0, 2], [3, 6]]``. The
    indices are numpy integer scalars produced by ``np.cumsum``.

    Args:
        u_groups: list of group sizes (number of channels per group).
    """
    boundaries = np.cumsum([0] + u_groups)
    return [[first, after - 1] for first, after in zip(boundaries[:-1], boundaries[1:])]

multiplyQuat_np

multiplyQuat_np(q1: ndarray, q2: ndarray) -> np.ndarray

Multiply two quaternions element-wise (numpy).

Source code in tsfast/quaternions/numpy_ops.py
def multiplyQuat_np(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Multiply two quaternions element-wise (numpy).

    Computes the Hamilton product ``q1 * q2`` for scalar-first quaternions
    (``[w, x, y, z]`` along the last axis).

    Shape handling:
        - A single quaternion of shape (4,) is treated as a (1x4) row.
        - The inputs are then broadcast against each other, so a single
          quaternion combines with (Nx4), and equal shapes, including batched
          (..., 4), multiply pairwise.

    Args:
        q1: left quaternion(s), shape (4,), (1x4), (Nx4) or (..., 4).
        q2: right quaternion(s), same conventions as ``q1``.

    Returns:
        Float64 array with the broadcast shape of the (row-promoted) inputs.

    Raises:
        ValueError: if the input shapes cannot be broadcast together.
    """
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)
    # promote single quaternions to 1x4 rows; this also makes (4,) x (4,)
    # and (1x4) x (4,) produce a valid (1x4) output buffer
    if q1.ndim == 1:
        q1 = q1[np.newaxis]
    if q2.ndim == 1:
        q2 = q2[np.newaxis]
    output = np.zeros(shape=np.broadcast(q1, q2).shape)
    # index the last axis so batched (..., 4) inputs are handled correctly
    output[..., 0] = q1[..., 0] * q2[..., 0] - q1[..., 1] * q2[..., 1] - q1[..., 2] * q2[..., 2] - q1[..., 3] * q2[..., 3]
    output[..., 1] = q1[..., 0] * q2[..., 1] + q1[..., 1] * q2[..., 0] + q1[..., 2] * q2[..., 3] - q1[..., 3] * q2[..., 2]
    output[..., 2] = q1[..., 0] * q2[..., 2] - q1[..., 1] * q2[..., 3] + q1[..., 2] * q2[..., 0] + q1[..., 3] * q2[..., 1]
    output[..., 3] = q1[..., 0] * q2[..., 3] + q1[..., 1] * q2[..., 2] - q1[..., 2] * q2[..., 1] + q1[..., 3] * q2[..., 0]
    return output

quatFromAngleAxis_np

quatFromAngleAxis_np(angle: ndarray, axis: ndarray) -> np.ndarray

Create quaternions from angle-axis representation (numpy).

If angle is 0, the output will be an identity quaternion. If axis is a zero vector (norm below float64 machine epsilon), the corresponding output is the identity quaternion regardless of the angle; no error is raised.

Parameters:

Name Type Description Default
angle ndarray

scalar or N angles in radians.

required
axis ndarray

rotation axes, shape (3,) or (Nx3) or (1x3).

required

Returns:

Type Description
ndarray

Quaternion array, shape (Nx4) or (1x4).

Source code in tsfast/quaternions/numpy_ops.py
def quatFromAngleAxis_np(angle: np.ndarray, axis: np.ndarray) -> np.ndarray:
    """Create quaternions from angle-axis representation (numpy).

    If angle is 0, the output will be an identity quaternion. If axis is a
    zero vector (norm below float64 machine epsilon), the corresponding
    output row is the identity quaternion regardless of its angle; no error
    is raised.

    Args:
        angle: scalar, (1,), (1x1) or (N,) angles in radians.
        axis: rotation axes, shape (3,) or (Nx3) or (1x3); each axis is
            divided by its norm, so it need not be normalized.

    Returns:
        Float64 quaternion array ``[w, x, y, z]``, shape (Nx4), or (4,) when a
        single angle is combined with a (3,) axis.

    Raises:
        AssertionError: if the ``angle``/``axis`` shapes are incompatible.
    """

    angle = np.asarray(angle, np.float64)
    axis = np.asarray(axis, np.float64)

    # single angle + single (3,) axis -> return a flat (4,) quaternion at the end
    is1D = (angle.shape == tuple() or angle.shape == (1,)) and axis.shape == (3,)

    if angle.shape == tuple():
        angle = angle.reshape(1)  # equivalent to np.atleast_1d
    if axis.shape == (3,):
        axis = axis.reshape((1, 3))

    N = max(angle.shape[0], axis.shape[0])

    # for (1x1) case
    if angle.shape == (1, 1):
        angle = angle.ravel()

    assert angle.shape == (N,) or angle.shape == (1,), f"invalid angle shape: {angle.shape}"
    assert axis.shape == (N, 3) or axis.shape == (1, 3), f"invalid axis shape: {axis.shape}"

    # expand (1,) / (1x3) inputs so every output row has an angle and an axis
    angle_brodcasted = np.broadcast_to(angle, (N,))
    axis_brodcasted = np.broadcast_to(axis, (N, 3))

    norm = np.linalg.norm(axis_brodcasted, axis=1)

    # rows with a (numerically) zero axis become the identity instead of dividing by ~0
    identity = norm < np.finfo(np.float64).eps

    q = np.zeros((N, 4), np.float64)
    q[identity] = np.array([1, 0, 0, 0])
    # [cos(angle/2), sin(angle/2) * axis / |axis|] for all non-degenerate rows
    q[~identity] = np.concatenate(
        (
            np.cos(angle_brodcasted[~identity][:, np.newaxis] / 2),
            axis_brodcasted[~identity]
            * np.array(np.sin(angle_brodcasted[~identity] / 2.0) / norm[~identity])[:, np.newaxis],
        ),
        axis=1,
    )

    if is1D:
        q = q.reshape((4,))

    return q

quatInterp_np

quatInterp_np(quat: ndarray, ind: ndarray, extend: bool = True) -> np.ndarray

Interpolate quaternions at non-integer indices using Slerp (numpy).

Sampling indices are in the range 0..N-1. For values outside this range, depending on extend, the first/last element or NaN is returned.

Parameters:

Name Type Description Default
quat ndarray

input quaternions, shape (Nx4).

required
ind ndarray

sampling indices, shape (M,).

required
extend bool

if true, extend input by repeating first/last value.

True

Returns:

Type Description
ndarray

Interpolated quaternions, shape (Mx4).

Source code in tsfast/quaternions/numpy_ops.py
def quatInterp_np(quat: np.ndarray, ind: np.ndarray, extend: bool = True) -> np.ndarray:
    """Interpolate quaternions at non-integer indices using Slerp (numpy).

    Sampling indices are in the range 0..N-1. For values outside this range,
    depending on ``extend``, the first/last element or NaN is returned.

    For a fractional index ``i = i0 + t`` the result is
    ``q[i0] * (inv(q[i0]) * q[i1])**t``, taking the shorter arc.

    Args:
        quat: input quaternions, shape (Nx4).
        ind: sampling indices, shape (M,) (a scalar is promoted to (1,)).
        extend: if true, extend input by repeating first/last value.

    Returns:
        Interpolated quaternions, shape (Mx4).
    """
    ind = np.atleast_1d(ind)
    N = quat.shape[0]
    M = ind.shape[0]
    assert quat.shape == (N, 4)
    assert ind.shape == (M,)

    # neighbouring samples; clipping makes out-of-range indices hit the first/last sample
    ind0 = np.clip(np.floor(ind).astype(int), 0, N - 1)
    ind1 = np.clip(np.ceil(ind).astype(int), 0, N - 1)

    q0 = quat[ind0]
    q1 = quat[ind1]
    # relativeQuat_np(q0, q1) == inv(q0) * q1, the rotation from q0 to q1
    q_1_0 = relativeQuat_np(q0, q1)

    # normalize the quaternion for positive w component to ensure
    # that the angle will be [0, 180deg]
    invert_sign_ind = q_1_0[:, 0] < 0
    q_1_0[invert_sign_ind] = -q_1_0[invert_sign_ind]

    angle = 2 * np.arccos(np.clip(q_1_0[:, 0], -1, 1))
    axis = q_1_0[:, 1:]

    # copy over (almost) direct hits
    # (errstate silences the warning from comparing NaN angles, e.g. from NaN input)
    with np.errstate(invalid="ignore"):
        direct_ind = angle < 1e-06
    quat_out = np.empty_like(q0)
    quat_out[direct_ind] = q0[direct_ind]

    interp_ind = ~direct_ind
    # fractional part of the index: how far to rotate from q0 towards q1
    t01 = ind - ind0
    q_t_0 = quatFromAngleAxis_np((t01 * angle)[interp_ind], axis[interp_ind])
    quat_out[interp_ind] = multiplyQuat_np(q0[interp_ind], q_t_0)

    if not extend:
        quat_out[ind < 0] = np.nan
        quat_out[ind > N - 1] = np.nan

    return quat_out

relativeQuat_np

relativeQuat_np(q1: ndarray, q2: ndarray) -> np.ndarray

Compute the relative quaternion as inv(quat1)*quat2 (numpy).

Source code in tsfast/quaternions/numpy_ops.py
def relativeQuat_np(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
    """Compute the relative quaternion as inv(quat1)*quat2 (numpy).

    Evaluates ``conj(q1) * q2`` component-wise for scalar-first quaternions
    (``[w, x, y, z]`` along the last axis). This equals ``inv(q1) * q2`` for
    unit ``q1``.

    Shape handling:
        - A single quaternion of shape (4,) is treated as a (1x4) row.
        - The inputs are then broadcast against each other, so a single
          quaternion combines with (Nx4), and equal shapes, including batched
          (..., 4), are processed pairwise.

    Args:
        q1: reference quaternion(s), shape (4,), (1x4), (Nx4) or (..., 4).
        q2: target quaternion(s), same conventions as ``q1``.

    Returns:
        Float64 array with the broadcast shape of the (row-promoted) inputs.

    Raises:
        ValueError: if the input shapes cannot be broadcast together.
    """
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)
    # promote single quaternions to 1x4 rows; this also makes (4,) x (4,)
    # and (1x4) x (4,) produce a valid (1x4) output buffer
    if q1.ndim == 1:
        q1 = q1[np.newaxis]
    if q2.ndim == 1:
        q2 = q2[np.newaxis]
    output = np.zeros(shape=np.broadcast(q1, q2).shape)
    # index the last axis so batched (..., 4) inputs are handled correctly
    output[..., 0] = q1[..., 0] * q2[..., 0] + q1[..., 1] * q2[..., 1] + q1[..., 2] * q2[..., 2] + q1[..., 3] * q2[..., 3]
    output[..., 1] = q1[..., 0] * q2[..., 1] - q1[..., 1] * q2[..., 0] - q1[..., 2] * q2[..., 3] + q1[..., 3] * q2[..., 2]
    output[..., 2] = q1[..., 0] * q2[..., 2] + q1[..., 1] * q2[..., 3] - q1[..., 2] * q2[..., 0] - q1[..., 3] * q2[..., 1]
    output[..., 3] = q1[..., 0] * q2[..., 3] - q1[..., 1] * q2[..., 2] + q1[..., 2] * q2[..., 1] - q1[..., 3] * q2[..., 0]
    return output

plot_quaternion_inclination

plot_quaternion_inclination(axs: list, in_sig: Tensor, targ_sig: Tensor, out_sig: Tensor | None = None, **kwargs)

Plot quaternion inclination target, prediction, and error.

Parameters:

Name Type Description Default
axs list

list of matplotlib axes to plot on.

required
in_sig Tensor

input signal tensor.

required
targ_sig Tensor

target quaternion tensor.

required
out_sig Tensor | None

predicted quaternion tensor, or None for batch display.

None
Source code in tsfast/quaternions/viz.py
def plot_quaternion_inclination(
    axs: list, in_sig: torch.Tensor, targ_sig: torch.Tensor, out_sig: torch.Tensor | None = None, **kwargs
):
    """Plot quaternion inclination target, prediction, and error.

    ``axs[0]`` shows the absolute inclination (degrees) of the target and,
    if given, of the prediction. ``axs[1]`` shows the prediction error and,
    if ``kwargs["ref"]`` is present, the target-vs-reference inclination.
    ``axs[-1]`` shows the raw input signal.

    Args:
        axs: list of matplotlib axes to plot on.
        in_sig: input signal tensor.
        targ_sig: target quaternion tensor.
        out_sig: predicted quaternion tensor, or None for batch display.
    """

    def to_deg(rad: torch.Tensor):
        return rad2deg(rad).detach().cpu().numpy()

    top = axs[0]
    top.plot(to_deg(inclinationAngleAbs(targ_sig)))
    top.label_outer()
    top.legend(["y"])
    top.set_ylabel("inclination[deg]")

    if out_sig is not None:
        top.plot(to_deg(inclinationAngleAbs(out_sig)))
        top.legend(["y", "y-hat"])
        error_ax = axs[1]
        error_ax.plot(to_deg(inclinationAngle(out_sig, targ_sig)))
        error_ax.label_outer()
        error_ax.set_ylabel("error[deg]")
        if "ref" in kwargs:
            error_ax.plot(to_deg(inclinationAngle(targ_sig, kwargs["ref"])))
            error_ax.legend(["y-hat", "y_ref"])

    axs[-1].plot(in_sig)

plot_quaternion_rel_angle

plot_quaternion_rel_angle(axs: list, in_sig: Tensor, targ_sig: Tensor, out_sig: Tensor | None = None, **kwargs)

Plot relative quaternion angle target, prediction, and error.

Parameters:

Name Type Description Default
axs list

list of matplotlib axes to plot on.

required
in_sig Tensor

input signal tensor.

required
targ_sig Tensor

target quaternion tensor.

required
out_sig Tensor | None

predicted quaternion tensor, or None for batch display.

None
Source code in tsfast/quaternions/viz.py
def plot_quaternion_rel_angle(
    axs: list, in_sig: torch.Tensor, targ_sig: torch.Tensor, out_sig: torch.Tensor | None = None, **kwargs
):
    """Plot relative quaternion angle target, prediction, and error.

    Angles in ``axs[0]`` are measured (in degrees) relative to the first
    target sample, repeated along the first dimension. ``axs[1]`` shows the
    prediction-vs-target angle, and ``axs[-1]`` shows the raw input signal.
    The repeat call assumes ``targ_sig`` is 2-D (time x 4) — TODO confirm.

    Args:
        axs: list of matplotlib axes to plot on.
        in_sig: input signal tensor.
        targ_sig: target quaternion tensor.
        out_sig: predicted quaternion tensor, or None for batch display.
    """

    def to_deg(rad: torch.Tensor):
        return rad2deg(rad).detach().cpu().numpy()

    n_steps = targ_sig.shape[0]
    start = targ_sig[0].repeat(n_steps, 1)

    top = axs[0]
    top.plot(to_deg(relativeAngle(start, targ_sig)))
    top.label_outer()
    top.legend(["y"])
    top.set_ylabel("angle[deg]")

    if out_sig is not None:
        top.plot(to_deg(relativeAngle(start, out_sig)))
        top.legend(["y", "y-hat"])
        error_ax = axs[1]
        error_ax.plot(to_deg(relativeAngle(out_sig, targ_sig)))
        error_ax.label_outer()
        error_ax.set_ylabel("error[deg]")

    axs[-1].plot(in_sig)

plot_scalar_inclination

plot_scalar_inclination(axs: list, in_sig: Tensor, targ_sig: Tensor, out_sig: Tensor | None = None, **kwargs)

Plot scalar inclination target, prediction, and error.

Parameters:

Name Type Description Default
axs list

list of matplotlib axes to plot on.

required
in_sig Tensor

input signal tensor.

required
targ_sig Tensor

target inclination tensor.

required
out_sig Tensor | None

predicted inclination tensor, or None for batch display.

None
Source code in tsfast/quaternions/viz.py
def plot_scalar_inclination(
    axs: list, in_sig: torch.Tensor, targ_sig: torch.Tensor, out_sig: torch.Tensor | None = None, **kwargs
):
    """Plot scalar inclination target, prediction, and error.

    Target and prediction are plotted in degrees on ``axs[0]``. The
    difference ``targ_sig - out_sig`` goes on ``axs[1]``, and the raw input
    signal on ``axs[-1]``. No legend is set when ``out_sig`` is None.

    Args:
        axs: list of matplotlib axes to plot on.
        in_sig: input signal tensor.
        targ_sig: target inclination tensor.
        out_sig: predicted inclination tensor, or None for batch display.
    """

    def to_deg(rad: torch.Tensor):
        return rad2deg(rad).detach().cpu().numpy()

    top = axs[0]
    top.plot(to_deg(targ_sig))
    top.label_outer()
    top.set_ylabel("inclination[deg]")

    if out_sig is not None:
        top.plot(to_deg(out_sig))
        top.legend(["y", "y-hat"])
        error_ax = axs[1]
        error_ax.plot(to_deg(targ_sig - out_sig))
        error_ax.label_outer()
        error_ax.set_ylabel("error[deg]")

    axs[-1].plot(in_sig)