Auxiliary Losses
Auxiliary loss callables for training.
AuxiliaryLoss
Auxiliary loss that applies a loss function to predictions and targets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss_func` | `Callable` | loss function applied to (pred, targ) | required |
| `alpha` | `float` | scaling factor for the auxiliary loss | `1.0` |
Source code in tsfast/training/aux_losses.py
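The description above amounts to scaling a loss on (pred, targ) by `alpha`. The following is a minimal sketch of that idea in plain PyTorch. `ScaledAuxLoss` is a hypothetical stand-in written for illustration, not the tsfast class, and the way the result is added to the main loss is an assumption.

```python
import torch
import torch.nn.functional as F


class ScaledAuxLoss:
    """Hypothetical stand-in (not the tsfast class): alpha * loss_func(pred, targ)."""

    def __init__(self, loss_func, alpha=1.0):
        self.loss_func = loss_func
        self.alpha = alpha

    def __call__(self, pred, targ):
        # Apply the wrapped loss to predictions and targets, then scale it
        return self.alpha * self.loss_func(pred, targ)


pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targ = torch.tensor([[1.0, 1.0], [3.0, 2.0]])

main_loss = F.mse_loss(pred, targ)         # primary training loss
aux = ScaledAuxLoss(F.l1_loss, alpha=0.5)  # auxiliary L1 term at half weight
aux_loss = aux(pred, targ)
total = main_loss + aux_loss               # assumed combination: plain sum
```

With `alpha` below 1.0 the auxiliary term nudges training without dominating the primary loss.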
ActivationRegularizer
Bases: _ActivationHookMixin
L2 penalty on hooked module activations (activation regularization).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `modules` | `list[Module]` | modules to hook for capturing activations | required |
| `alpha` | `float` | coefficient for the L2 penalty | `1.0` |
| `dim` | `int \| None` | time axis index; auto-detected from the hooked layer output if None | `None` |
Source code in tsfast/training/aux_losses.py
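To make the idea of an L2 penalty on hooked activations concrete, here is a rough sketch built on PyTorch forward hooks. `L2ActivationPenalty` is a hypothetical class, not the tsfast implementation. The mean reduction and the handling of tuple outputs are assumptions, and the sketch ignores the `dim` parameter.

```python
import torch
from torch import nn


class L2ActivationPenalty:
    """Hypothetical sketch: alpha * mean(activation ** 2), summed over hooked modules."""

    def __init__(self, modules, alpha=1.0):
        self.alpha = alpha
        self.activations = []
        # A forward hook runs after each forward pass of the hooked module
        self.handles = [m.register_forward_hook(self._capture) for m in modules]

    def _capture(self, module, inputs, output):
        # RNN layers return (output, hidden); keep only the output sequence
        out = output[0] if isinstance(output, tuple) else output
        self.activations.append(out)

    def penalty(self):
        terms = [act.pow(2).mean() for act in self.activations]
        self.activations.clear()  # start fresh for the next forward pass
        return self.alpha * torch.stack(terms).sum()

    def remove(self):
        for handle in self.handles:
            handle.remove()


torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=4, batch_first=True)
reg = L2ActivationPenalty([rnn], alpha=2.0)

x = torch.randn(2, 5, 3)  # (batch, time, features)
out, _ = rnn(x)           # the hook stores `out` as a side effect
ar_loss = reg.penalty()   # differentiable, so it can be added to the main loss
reg.remove()
```

Clearing the captured activations after each call keeps the penalty tied to the most recent forward pass.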
TemporalActivationRegularizer
Bases: _ActivationHookMixin
L2 penalty on consecutive-timestep activation differences (temporal activation regularization).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `modules` | `list[Module]` | modules to hook for capturing activations | required |
| `beta` | `float` | coefficient for the L2 penalty on temporal differences | `1.0` |
| `dim` | `int \| None` | time axis index; auto-detected from the hooked layer output if None | `None` |
Source code in tsfast/training/aux_losses.py
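The penalty on consecutive-timestep differences can be illustrated with a small standalone function. This is a sketch under stated assumptions: `temporal_l2_penalty` is a hypothetical helper, the mean reduction is a guess at the normalization, and the time axis is passed explicitly rather than auto-detected.

```python
import torch


def temporal_l2_penalty(act, beta=1.0, dim=1):
    """Hypothetical helper: beta * mean((a[t+1] - a[t]) ** 2) along the time axis `dim`."""
    n = act.size(dim)
    # Difference between each timestep and the one before it
    diff = act.narrow(dim, 1, n - 1) - act.narrow(dim, 0, n - 1)
    return beta * diff.pow(2).mean()


# One sequence with four timesteps and one feature: shape (1, 4, 1)
act = torch.tensor([[[0.0], [1.0], [3.0], [3.0]]])
tar_loss = temporal_l2_penalty(act, beta=0.5, dim=1)

# A constant sequence has no temporal change, so its penalty is zero
flat_loss = temporal_l2_penalty(torch.ones(1, 4, 1), beta=0.5, dim=1)
```

In the example the stepwise differences are 1, 2 and 0, so the penalty grows with abrupt jumps and vanishes for smooth, constant activations.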
FranSysRegularizer
FranSysRegularizer(modules: list[Module], p_state_sync: float = 10000000.0, p_diag_loss: float = 0.0, p_osp_sync: float = 0, p_osp_loss: float = 0, p_tar_loss: float = 0, sync_type: str = 'mse', targ_loss_func: Callable = F.l1_loss, osp_n_skip: int | None = None, model: Module | None = None)
Regularizes FranSys output by syncing diagnosis and prognosis hidden states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `modules` | `list[Module]` | modules to hook (diagnosis + prognosis RNNs) | required |
| `p_state_sync` | `float` | scaling factor for hidden state sync loss | `10000000.0` |
| `p_diag_loss` | `float` | scaling factor for diagnosis loss through the final layer | `0.0` |
| `p_osp_sync` | `float` | scaling factor for one-step prediction hidden state sync | `0` |
| `p_osp_loss` | `float` | scaling factor for one-step prediction loss | `0` |
| `p_tar_loss` | `float` | scaling factor for temporal activation regularization | `0` |
| `sync_type` | `str` | distance metric for state synchronization | `'mse'` |
| `targ_loss_func` | `Callable` | loss function for target-based regularization | `l1_loss` |
| `osp_n_skip` | `int \| None` | elements to skip before one-step prediction (defaults to model.init_sz) | `None` |
| `model` | `Module \| None` | explicit FranSys model reference (auto-detected via unwrap_model if None) | `None` |
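As a rough illustration of the hidden state sync term, the sketch below scales an MSE distance between two hidden-state tensors by `p_state_sync`. `state_sync_penalty` is a hypothetical helper, not the tsfast implementation. Which diagnosis and prognosis states are compared, and how they are aligned in time, are assumptions left out here, and only the `'mse'` metric is covered.

```python
import torch
import torch.nn.functional as F


def state_sync_penalty(h_diag, h_prog, p_state_sync=1e7, sync_type="mse"):
    """Hypothetical helper: scaled distance between two hidden-state tensors."""
    if sync_type != "mse":
        # Other metrics are not covered by this sketch
        raise ValueError(f"unsupported sync_type in this sketch: {sync_type!r}")
    return p_state_sync * F.mse_loss(h_diag, h_prog)


# Two hidden states that differ by 1e-3 in a single component
h_diag = torch.tensor([[0.0, 1e-3]])
h_prog = torch.tensor([[0.0, 0.0]])
sync_loss = state_sync_penalty(h_diag, h_prog)
```

The example shows the arithmetic only: a discrepancy of 1e-3 gives an MSE of 5e-7, which the factor of 1e7 scales to roughly 5.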