
Losses and Metrics

Loss functions and metrics for training.

nan_mean

nan_mean(fn: Callable, fill: list | float) -> Callable

Wrap a per-element loss into a NaN-safe, CUDA-graph-compatible mean.

NaN targets are replaced with fill (static shapes preserved), and a masked mean ensures only valid positions contribute to the gradient.

Parameters:

| Name   | Type            | Description                                  | Default  |
| ------ | --------------- | -------------------------------------------- | -------- |
| `fn`   | `Callable`      | per-element function `(inp, targ) -> Tensor` | required |
| `fill` | `list \| float` | value to substitute for NaN targets          | required |
Source code in tsfast/training/losses.py
def nan_mean(fn: Callable, fill: list | float) -> Callable:
    """Wrap a per-element loss into a NaN-safe, CUDA-graph-compatible mean.

    NaN targets are replaced with *fill* (static shapes preserved), and a
    masked mean ensures only valid positions contribute to the gradient.

    Args:
        fn: per-element function ``(inp, targ) -> Tensor``
        fill: value to substitute for NaN targets
    """
    _cache: dict[tuple, Tensor] = {}

    @functools.wraps(fn)
    def wrapper(inp: Tensor, targ: Tensor) -> Tensor:
        key = (targ.device, targ.dtype)
        if key not in _cache:
            _cache[key] = torch.as_tensor(fill, dtype=targ.dtype, device=targ.device)
        mask = ~torch.isnan(targ).any(dim=-1)
        targ = torch.where(mask.unsqueeze(-1), targ, _cache[key])
        elem = fn(inp, targ)
        if elem.dim() > mask.dim():
            mask = mask.unsqueeze(-1)
        return (elem * mask).sum() / mask.expand_as(elem).sum()

    return wrapper

mse

mse(inp: Tensor, targ: Tensor) -> Tensor

Mean squared error.

Source code in tsfast/training/losses.py
def mse(inp: Tensor, targ: Tensor) -> Tensor:
    """Mean squared error."""
    diff = inp - targ
    return (diff * diff).mean()

ignore_nan

ignore_nan(func: Callable) -> Callable

Decorator that removes NaN samples from (inp, targ) before computing a loss.

A sample is removed if any feature in the target is NaN. Reduces tensors to a flat array.

Parameters:

| Name   | Type       | Description                                          | Default  |
| ------ | ---------- | ---------------------------------------------------- | -------- |
| `func` | `Callable` | loss function with signature `(inp, targ) -> Tensor` | required |
Source code in tsfast/training/losses.py
def ignore_nan(func: Callable) -> Callable:
    """Decorator that removes NaN samples from (inp, targ) before computing a loss.

    A sample is removed if any feature in the target is NaN.
    Reduces tensors to a flat array.

    Args:
        func: loss function with signature (inp, targ) -> Tensor
    """

    @functools.wraps(func)
    def _nan_filtered(inp: Tensor, targ: Tensor) -> Tensor:
        # Keep only rows whose target has no NaN in any feature.
        keep = torch.isnan(targ).any(dim=-1).logical_not()
        return func(inp[keep], targ[keep])

    return _nan_filtered

float64_func

float64_func(func: Callable) -> Callable

Decorator that computes a function in float64 and converts the result back.

Parameters:

| Name   | Type       | Description                             | Default  |
| ------ | ---------- | --------------------------------------- | -------- |
| `func` | `Callable` | function to wrap with float64 promotion | required |
Source code in tsfast/training/losses.py
def float64_func(func: Callable) -> Callable:
    """Decorator that computes a function in float64 and converts the result back.

    Tensor positional arguments are promoted to float64 before calling
    *func*, and the result is cast back to the dtype of the first argument.
    If the device rejects float64 (the raised ``TypeError`` mentions
    "doesn't support float64"), the call is retried with the ORIGINAL,
    unconverted arguments and a warning is emitted.

    Args:
        func: function to wrap with float64 promotion
    """

    @functools.wraps(func)
    def float64_func_decorator(*args, **kwargs):
        typ = args[0].dtype  # assumes the first positional argument is a Tensor
        try:
            # Bind promoted args to a NEW name: the fallback below must see
            # the original-precision args. Previously the rebound `args`
            # could already be float64 when `func` itself raised, so the
            # "original precision" retry re-failed in float64.
            promoted = tuple(x.double() if isinstance(x, Tensor) else x for x in args)
            return func(*promoted, **kwargs).type(typ)
        except TypeError as e:
            if "doesn't support float64" in str(e):
                warnings.warn(
                    f"Float64 precision not supported on {args[0].device} device. Using original precision. This may reduce numerical accuracy. Error: {e}"
                )
                return func(*args, **kwargs)
            else:
                raise

    return float64_func_decorator

cut_loss

cut_loss(fn: Callable, l_cut: int = 0, r_cut: int | None = None) -> Callable

Loss-function modifier that slices the sequence from l_cut to r_cut.

Parameters:

| Name    | Type          | Description                                        | Default  |
| ------- | ------------- | -------------------------------------------------- | -------- |
| `fn`    | `Callable`    | base loss function to wrap                         | required |
| `l_cut` | `int`         | left index to start the slice                      | `0`      |
| `r_cut` | `int \| None` | right index to end the slice (None keeps the rest) | `None`   |
Source code in tsfast/training/losses.py
def cut_loss(fn: Callable, l_cut: int = 0, r_cut: int | None = None) -> Callable:
    """Loss-function modifier that slices the sequence from l_cut to r_cut.

    Args:
        fn: base loss function to wrap
        l_cut: left index to start the slice
        r_cut: right index to end the slice (None keeps the rest)
    """

    @functools.wraps(fn)
    def _inner(input, target):
        return fn(input[:, l_cut:r_cut], target[:, l_cut:r_cut])

    return _inner

norm_loss

norm_loss(fn: Callable, norm_stats, scaler_cls: type | None = None) -> Callable

Loss wrapper that normalizes predictions and targets before computing loss.

Parameters:

| Name         | Type           | Description                                       | Default  |
| ------------ | -------------- | ------------------------------------------------- | -------- |
| `fn`         | `Callable`     | base loss function to wrap                        | required |
| `norm_stats` |                | normalization statistics used to build the scaler | required |
| `scaler_cls` | `type \| None` | scaler class to use (defaults to StandardScaler)  | `None`   |
Source code in tsfast/training/losses.py
def norm_loss(fn: Callable, norm_stats, scaler_cls: type | None = None) -> Callable:
    """Loss wrapper that normalizes predictions and targets before computing loss.

    Args:
        fn: base loss function to wrap
        norm_stats: normalization statistics used to build the scaler
        scaler_cls: scaler class to use (defaults to StandardScaler)
    """
    from ..models.scaling import StandardScaler

    chosen_cls = StandardScaler if scaler_cls is None else scaler_cls
    scaler = chosen_cls.from_stats(norm_stats)

    @functools.wraps(fn)
    def _normalized(input, target):
        # Keep the scaler on the same device as the incoming batch.
        scaler.to(input.device)
        return fn(scaler.normalize(input), scaler.normalize(target))

    return _normalized

weighted_mae

weighted_mae(input: Tensor, target: Tensor) -> Tensor

Weighted MAE with log-spaced weights decaying along the sequence axis.

Source code in tsfast/training/losses.py
def weighted_mae(input: Tensor, target: Tensor) -> Tensor:
    """Weighted MAE with log-spaced weights decaying along the sequence axis.

    Weights fall log-spaced from 1.0 at the first timestep to 0.1 at the
    last, are normalized to sum to 1, and are applied along dim 1 of
    ``(batch, seq, features)`` tensors — assumes that layout; TODO confirm.
    """
    import math

    max_weight = 1.0
    min_weight = 0.1
    seq_len = input.shape[1]

    device = input.device
    compute_device = device
    if device.type == "mps":
        # torch.logspace has no MPS kernel; build on CPU and move over.
        compute_device = "cpu"
        warnings.warn(
            f"torch.logspace not supported on {device} device. Using cpu. This may reduce numerical performance"
        )
    # torch.logspace accepts plain float exponents: math.log10 avoids
    # allocating throwaway tensors just to take a scalar log10.
    weights = torch.logspace(
        start=math.log10(max_weight),
        end=math.log10(min_weight),
        steps=seq_len,
        device=compute_device,
    ).to(device)

    # Normalize to sum to 1 and broadcast over batch/feature dims.
    weights = (weights / weights.sum())[None, :, None]

    return ((input - target).abs() * weights).sum(dim=1).mean()

rand_seq_len_loss

rand_seq_len_loss(fn: Callable, min_idx: int = 1, max_idx: int | None = None, mid_idx: int | None = None) -> Callable

Loss-function modifier that randomly truncates each sequence in the minibatch individually.

Uses a triangular distribution. Slow for very large batch sizes.

Parameters:

| Name      | Type          | Description                                               | Default  |
| --------- | ------------- | --------------------------------------------------------- | -------- |
| `fn`      | `Callable`    | base loss function to wrap                                | required |
| `min_idx` | `int`         | minimum sequence length                                   | `1`      |
| `max_idx` | `int \| None` | maximum sequence length (defaults to full sequence)       | `None`   |
| `mid_idx` | `int \| None` | mode of the triangular distribution (defaults to min_idx) | `None`   |
Source code in tsfast/training/losses.py
def rand_seq_len_loss(
    fn: Callable, min_idx: int = 1, max_idx: int | None = None, mid_idx: int | None = None
) -> Callable:
    """Loss-function modifier that randomly truncates each sequence in the minibatch individually.

    Uses a triangular distribution. Slow for very large batch sizes.

    Args:
        fn: base loss function to wrap
        min_idx: minimum sequence length
        max_idx: maximum sequence length (defaults to full sequence)
        mid_idx: mode of the triangular distribution (defaults to min_idx)
    """

    @functools.wraps(fn)
    def _inner(input, target):
        bs, seq_len, _ = input.shape
        _max = max_idx if max_idx is not None else seq_len
        _mid = mid_idx if mid_idx is not None else min_idx
        len_list = np.random.triangular(min_idx, _mid, _max, (bs,)).astype(int)
        return torch.stack([fn(input[i, : len_list[i]], target[i, : len_list[i]]) for i in range(bs)]).mean()

    return _inner

fun_rmse

fun_rmse(inp: Tensor, targ: Tensor) -> Tensor

RMSE loss function defined as a plain function.

Source code in tsfast/training/losses.py
def fun_rmse(inp: Tensor, targ: Tensor) -> Tensor:
    """RMSE loss function defined as a plain function."""
    return F.mse_loss(inp, targ).sqrt()

cos_sim_loss

cos_sim_loss(inp: Tensor, targ: Tensor) -> Tensor

Mean cosine similarity loss (1 - cosine similarity).

Source code in tsfast/training/losses.py
def cos_sim_loss(inp: Tensor, targ: Tensor) -> Tensor:
    """Mean cosine similarity loss (1 - cosine similarity)."""
    similarity = F.cosine_similarity(inp, targ, dim=-1)
    return torch.mean(1 - similarity)

cos_sim_loss_pow

cos_sim_loss_pow(inp: Tensor, targ: Tensor) -> Tensor

Mean squared cosine similarity loss.

Source code in tsfast/training/losses.py
def cos_sim_loss_pow(inp: Tensor, targ: Tensor) -> Tensor:
    """Mean squared cosine similarity loss."""
    dissim = 1 - F.cosine_similarity(inp, targ, dim=-1)
    return torch.mean(dissim * dissim)

nrmse

nrmse(inp: Tensor, targ: Tensor) -> Tensor

RMSE loss normalized by variance of each target variable.

Source code in tsfast/training/losses.py
def nrmse(inp: Tensor, targ: Tensor) -> Tensor:
    """RMSE loss normalized by variance of each target variable."""
    sq_err = (inp - targ) ** 2
    # Per-variable MSE over batch and sequence dims, scaled by the
    # per-variable target variance, then averaged across variables.
    return torch.sqrt(sq_err.mean(dim=[0, 1]) / targ.var(dim=[0, 1])).mean()

nrmse_std

nrmse_std(inp: Tensor, targ: Tensor) -> Tensor

RMSE loss normalized by standard deviation of each target variable.

Source code in tsfast/training/losses.py
def nrmse_std(inp: Tensor, targ: Tensor) -> Tensor:
    """RMSE loss normalized by standard deviation of each target variable.

    NOTE(review): as written this computes ``sqrt(mse / std)`` per
    variable, not ``sqrt(mse) / std`` — the latter would equal ``nrmse``.
    Confirm the intended normalization before changing.
    """
    per_var_mse = (inp - targ).pow(2).mean(dim=[0, 1])
    return (per_var_mse / targ.std(dim=[0, 1])).sqrt().mean()

mean_vaf

mean_vaf(inp: Tensor, targ: Tensor) -> Tensor

Variance accounted for (VAF) metric, returned as a percentage.

Source code in tsfast/training/losses.py
def mean_vaf(inp: Tensor, targ: Tensor) -> Tensor:
    """Variance accounted for (VAF) metric, returned as a percentage."""
    residual_var = (targ - inp).var()
    return (1 - residual_var / targ.var()) * 100

zero_loss

zero_loss(pred: Tensor, targ: Tensor) -> Tensor

Always-zero loss that preserves the computation graph.

Source code in tsfast/training/losses.py
def zero_loss(pred: Tensor, targ: Tensor) -> Tensor:
    """Always-zero loss that preserves the computation graph."""
    # Multiplying by zero keeps pred in the autograd graph while the
    # scalar result (and hence its gradient) is exactly zero.
    return torch.sum(pred * 0)