Auxiliary Losses

Auxiliary loss callables for training.

AuxiliaryLoss

AuxiliaryLoss(loss_func: Callable, alpha: float = 1.0)

Auxiliary loss that applies a loss function to predictions and targets and scales the result by alpha.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_func | Callable | loss function applied to (pred, targ) | required |
| alpha | float | scaling factor for the auxiliary loss | 1.0 |
Source code in tsfast/training/aux_losses.py
def __init__(self, loss_func: Callable, alpha: float = 1.0):
    self.loss_func = loss_func
    self.alpha = alpha
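
The source above only shows the constructor; a minimal sketch of the assumed call behavior (the wrapped loss scaled by alpha) looks like this. `AuxiliaryLossSketch` and `mae` are hypothetical names for illustration, not part of tsfast:

```python
from typing import Callable

class AuxiliaryLossSketch:
    """Stand-in mirroring AuxiliaryLoss: wraps a loss and scales it by alpha."""
    def __init__(self, loss_func: Callable, alpha: float = 1.0):
        self.loss_func = loss_func
        self.alpha = alpha

    def __call__(self, pred, targ):
        # Assumed behavior: weight the wrapped loss by the scaling factor.
        return self.alpha * self.loss_func(pred, targ)

def mae(pred, targ):
    # Mean absolute error over two equal-length sequences.
    return sum(abs(p - t) for p, t in zip(pred, targ)) / len(pred)

aux = AuxiliaryLossSketch(mae, alpha=0.5)
print(aux([1.0, 2.0], [0.0, 2.0]))  # 0.5 * mae = 0.5 * 0.5 = 0.25
```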

ActivationRegularizer

ActivationRegularizer(modules: list[Module], alpha: float = 1.0, dim: int | None = None)

Bases: _ActivationHookMixin

L2 penalty on hooked module activations (activation regularization).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| modules | list[Module] | modules to hook for capturing activations | required |
| alpha | float | coefficient for the L2 penalty | 1.0 |
| dim | int \| None | time axis index; auto-detected from the hooked layer output if None | None |
Source code in tsfast/training/aux_losses.py
def __init__(self, modules: list[nn.Module], alpha: float = 1.0, dim: int | None = None):
    self.modules = modules
    self.alpha = alpha
    self.dim = dim
    self._hooks: list[torch.utils.hooks.RemovableHandle] = []
    self._out: Tensor | None = None
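
The penalty itself is a mean-squared magnitude over the captured activations. A pure-Python sketch of that term (the function name is hypothetical; the actual class computes this over hooked tensor outputs):

```python
def activation_l2_penalty(acts, alpha=1.0):
    """L2 (mean squared) penalty over captured activations,
    mirroring the term ActivationRegularizer adds to the loss."""
    flat = [a for row in acts for a in row]
    # Mean of squared activation values, scaled by alpha.
    return alpha * sum(a * a for a in flat) / len(flat)

acts = [[1.0, -1.0], [2.0, 0.0]]  # e.g. two timesteps of a hooked layer
print(activation_l2_penalty(acts, alpha=0.5))  # 0.5 * (1 + 1 + 4 + 0) / 4 = 0.75
```

Large activations are penalized quadratically, which discourages the hooked layer from drifting toward extreme values during training.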

TemporalActivationRegularizer

TemporalActivationRegularizer(modules: list[Module], beta: float = 1.0, dim: int | None = None)

Bases: _ActivationHookMixin

L2 penalty on consecutive-timestep activation differences (temporal activation regularization).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| modules | list[Module] | modules to hook for capturing activations | required |
| beta | float | coefficient for the L2 penalty on temporal differences | 1.0 |
| dim | int \| None | time axis index; auto-detected from the hooked layer output if None | None |
Source code in tsfast/training/aux_losses.py
def __init__(self, modules: list[nn.Module], beta: float = 1.0, dim: int | None = None):
    self.modules = modules
    self.beta = beta
    self.dim = dim
    self._hooks: list[torch.utils.hooks.RemovableHandle] = []
    self._out: Tensor | None = None
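
Unlike the plain L2 penalty, TAR penalizes the squared difference between consecutive timesteps, encouraging smooth hidden-state trajectories. A sketch along a single activation channel (the function name is hypothetical; the class applies this along the detected time axis):

```python
def temporal_ar_penalty(seq, beta=1.0):
    """Temporal activation regularization: mean squared difference
    between consecutive timesteps of an activation sequence."""
    diffs = [(b - a) ** 2 for a, b in zip(seq, seq[1:])]
    return beta * sum(diffs) / len(diffs)

seq = [0.0, 1.0, 1.0, 3.0]          # activations of one unit over time
print(temporal_ar_penalty(seq))     # (1 + 0 + 4) / 3
```

The constant segment (1.0 to 1.0) contributes nothing, while the jump from 1.0 to 3.0 dominates the penalty.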

FranSysRegularizer

FranSysRegularizer(modules: list[Module], p_state_sync: float = 10000000.0, p_diag_loss: float = 0.0, p_osp_sync: float = 0, p_osp_loss: float = 0, p_tar_loss: float = 0, sync_type: str = 'mse', targ_loss_func: Callable = F.l1_loss, osp_n_skip: int | None = None, model: Module | None = None)

Regularizes FranSys output by syncing diagnosis and prognosis hidden states.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| modules | list[Module] | modules to hook (diagnosis + prognosis RNNs) | required |
| p_state_sync | float | scaling factor for hidden state sync loss | 1e7 |
| p_diag_loss | float | scaling factor for diagnosis loss through the final layer | 0.0 |
| p_osp_sync | float | scaling factor for one-step prediction hidden state sync | 0 |
| p_osp_loss | float | scaling factor for one-step prediction loss | 0 |
| p_tar_loss | float | scaling factor for temporal activation regularization | 0 |
| sync_type | str | distance metric for state synchronization | 'mse' |
| targ_loss_func | Callable | loss function for target-based regularization | l1_loss |
| osp_n_skip | int \| None | elements to skip before one-step prediction (defaults to model.init_sz) | None |
| model | Module \| None | explicit FranSys model reference (auto-detected via unwrap_model if None) | None |
Source code in tsfast/training/aux_losses.py
def __init__(
    self,
    modules: list[nn.Module],
    p_state_sync: float = 1e7,
    p_diag_loss: float = 0.0,
    p_osp_sync: float = 0,
    p_osp_loss: float = 0,
    p_tar_loss: float = 0,
    sync_type: str = "mse",
    targ_loss_func: Callable = F.l1_loss,
    osp_n_skip: int | None = None,
    model: nn.Module | None = None,
):
    if sync_type not in self._SYNC_FNS:
        raise ValueError(f"Unknown sync_type: {sync_type!r}")
    self.modules = modules
    self.p_state_sync = p_state_sync
    self.p_diag_loss = p_diag_loss
    self.p_osp_sync = p_osp_sync
    self.p_osp_loss = p_osp_loss
    self.p_tar_loss = p_tar_loss
    self.sync_type = sync_type
    self.targ_loss_func = targ_loss_func
    self.osp_n_skip = osp_n_skip
    self.inner_model = model
    self._hooks: list[torch.utils.hooks.RemovableHandle] = []
    self._out_diag: Tensor | None = None
    self._out_prog: Tensor | None = None
    self._output_norm = None
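
The dominant term is the hidden-state synchronization loss, which pulls the diagnosis and prognosis hidden states toward each other. A pure-Python sketch of that single term, assuming the `'mse'` sync metric is a mean squared distance between the two state vectors (the function name and exact reduction are hypothetical; the real class combines this with the other weighted terms):

```python
def state_sync_loss(diag_state, prog_state, p_state_sync=1e7, sync_type="mse"):
    """Assumed hidden-state sync term: MSE distance between diagnosis
    and prognosis hidden states, scaled by p_state_sync."""
    if sync_type != "mse":
        raise ValueError(f"Unknown sync_type: {sync_type!r}")
    n = len(diag_state)
    return p_state_sync * sum((d - p) ** 2 for d, p in zip(diag_state, prog_state)) / n

# Two toy hidden-state vectors that disagree in the second component.
print(state_sync_loss([0.1, 0.2], [0.1, 0.0], p_state_sync=100.0))  # 100 * (0.0 + 0.2**2) / 2
```

The large default scaling factor (1e7) reflects that raw hidden-state distances are typically tiny compared to the main prediction loss.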