Skip to content

Learner

Learner, TbpttLearner, CudaGraphTbpttLearner, and Recorder — a pure-PyTorch training loop.

Recorder

Recorder()

Stores training history: values[epoch] = [train_loss, valid_loss, *metrics], with metric values ordered by metric name.

Source code in tsfast/training/learner.py
def __init__(self):
    """Create an empty history; one row is appended per recorded epoch."""
    self.values: list[list[float]] = list()

Learner

Learner(model: Module, dls: DataLoaders, loss_func: Callable, metrics: list[Callable] | None = None, lr: float = 0.003, opt_func: type = torch.optim.Adam, transforms: list | None = None, augmentations: list | None = None, aux_losses: list | None = None, n_skip: int = 0, grad_clip: float | None = None, plot_fn: Callable | None = None, device: device | None = None)

Pure-PyTorch training loop for time-series models.

Parameters:

Name Type Description Default
model Module

the model to train

required
dls DataLoaders

train/valid/test DataLoaders

required
loss_func Callable

primary loss function

required
metrics list[Callable] | None

list of metric functions (pred, targ) -> scalar

None
lr float

default learning rate

0.003
opt_func type

optimizer class

Adam
transforms list | None

list of (xb, yb) -> (xb, yb) applied to train + valid

None
augmentations list | None

list of (xb, yb) -> (xb, yb) applied to train only

None
aux_losses list | None

list of (pred, yb, xb) -> loss_term added to primary loss

None
n_skip int

number of initial time steps to skip in loss and metric computation

0
grad_clip float | None

maximum gradient norm (None disables clipping)

None
plot_fn Callable | None

plotting function for show_batch/show_results

None
device device | None

target device (auto-detected if None)

None
Source code in tsfast/training/learner.py
def __init__(
    self,
    model: nn.Module,
    dls: DataLoaders,
    loss_func: Callable,
    metrics: list[Callable] | None = None,
    lr: float = 3e-3,
    opt_func: type = torch.optim.Adam,
    transforms: list | None = None,
    augmentations: list | None = None,
    aux_losses: list | None = None,
    n_skip: int = 0,
    grad_clip: float | None = None,
    plot_fn: Callable | None = None,
    device: torch.device | None = None,
):
    """Store the training configuration.

    ``metrics``, ``transforms``, ``augmentations`` and ``aux_losses`` are
    shallow-copied. ``add_transform`` / ``add_augmentation`` / ``add_aux_loss``
    append in place, so they never mutate a list the caller passed in. This
    also keeps one list from being shared between several learners.

    ``plot_fn`` defaults to ``viz.plot_sequence``. ``device`` defaults to
    ``_auto_device()``.
    """
    self.model = model
    self.dls = dls
    self.loss_func = loss_func
    # Shallow copies: the add_* methods append to these lists in place.
    self.metrics = list(metrics) if metrics else []
    self.lr = lr
    self.opt_func = opt_func
    self.transforms = list(transforms) if transforms else []
    self.augmentations = list(augmentations) if augmentations else []
    self.aux_losses = list(aux_losses) if aux_losses else []
    self.n_skip = n_skip
    self.grad_clip = grad_clip
    self.plot_fn = plot_fn or viz.plot_sequence
    self.device = device or _auto_device()
    self.recorder = Recorder()
    # NOTE(review): not updated in this block; presumably a training-progress fraction maintained elsewhere — confirm.
    self.pct_train: float = 0.0
    # Toggled off temporarily by no_bar(); passed as tqdm's `disable` flag in fit().
    self._show_bar: bool = True

add_aux_loss

add_aux_loss(obj)

Append an auxiliary loss composable.

Source code in tsfast/training/learner.py
def add_aux_loss(self, obj):
    """Register ``obj`` as an extra loss term, placed after any existing auxiliary losses."""
    self.aux_losses.extend((obj,))

add_transform

add_transform(obj)

Append a transform composable (applied train + valid).

Source code in tsfast/training/learner.py
def add_transform(self, obj):
    """Register ``obj`` as a transform; transforms run in training and evaluation, before augmentations."""
    self.transforms.extend((obj,))

add_augmentation

add_augmentation(obj)

Append an augmentation composable (applied train only).

Source code in tsfast/training/learner.py
def add_augmentation(self, obj):
    """Register ``obj`` as an augmentation; augmentations run only during training, after transforms."""
    self.augmentations.extend((obj,))

no_bar

no_bar()

Suppress tqdm progress bars (useful for Ray Tune).

Source code in tsfast/training/learner.py
@contextmanager
def no_bar(self):
    """Temporarily disable tqdm progress bars (handy under Ray Tune).

    The previous setting is restored on exit, even if the body raises.
    """
    saved_flag, self._show_bar = self._show_bar, False
    try:
        yield
    finally:
        self._show_bar = saved_flag

training_step

training_step(batch: tuple[Tensor, Tensor], optimizer, state=None, n_skip: int | None = None) -> tuple[float | None, object]

Single training step: apply transforms/augmentations, forward, loss, backward, step.

Source code in tsfast/training/learner.py
def training_step(
    self, batch: tuple[Tensor, Tensor], optimizer, state=None, n_skip: int | None = None
) -> tuple[float | None, object]:
    """Single training step: apply transforms/augmentations, forward, loss, backward, step."""
    xb, yb = batch
    for t in self.transforms:
        xb, yb = t(xb, yb)
    for a in self.augmentations:
        xb, yb = a(xb, yb)
    return self._forward_backward_step(xb, yb, optimizer, state, n_skip)

validate

validate(dl=None) -> tuple[float, dict[str, float]]

Run validation and compute loss + metrics on concatenated predictions.

Returns:

Type Description
tuple[float, dict[str, float]]

(val_loss, {metric_name: value})

Source code in tsfast/training/learner.py
def validate(self, dl=None) -> tuple[float, dict[str, float]]:
    """Run validation and compute loss + metrics on concatenated predictions.

    The first ``self.n_skip`` time steps are excluded from both the loss and
    the metrics. Metric keys come from each metric's ``__name__``, falling
    back to its type name. Duplicate names get a ``_2``, ``_3``, ... suffix
    instead of overwriting each other.

    Args:
        dl: DataLoader to evaluate; ``None`` uses ``self.dls.valid``.

    Returns:
        (val_loss, {metric_name: value})
    """
    # Explicit None check: a DataLoader's truthiness calls __len__, which is 0
    # for an empty loader and raises TypeError for iterable-style datasets.
    if dl is None:
        dl = self.dls.valid
    preds, targs = self.get_preds(dl=dl)

    if self.n_skip > 0:
        preds, targs = preds[:, self.n_skip :], targs[:, self.n_skip :]

    val_loss = self.loss_func(preds, targs).item()

    metrics_dict: dict[str, float] = {}
    for m in self.metrics:
        base = getattr(m, "__name__", type(m).__name__)
        # e.g. several functools.partial metrics would all be named "partial".
        name, k = base, 1
        while name in metrics_dict:
            k += 1
            name = f"{base}_{k}"
        metrics_dict[name] = m(preds, targs).item()

    return val_loss, metrics_dict

fit

fit(n_epoch: int, lr: float | None = None, make_scheduler: Callable | None = None)

Train for n_epoch epochs.

Parameters:

Name Type Description Default
n_epoch int

number of epochs

required
lr float | None

learning rate (uses self.lr if None)

None
make_scheduler Callable | None

factory (optimizer, total_steps) -> scheduler, stepped once per training batch (None = no scheduler)

None
Source code in tsfast/training/learner.py
def fit(
    self,
    n_epoch: int,
    lr: float | None = None,
    make_scheduler: Callable | None = None,
):
    """Train for n_epoch epochs.

    A fresh optimizer (``self.opt_func``) is created on every call. After each
    epoch the model is evaluated with ``validate``. A row
    ``[train_loss, val_loss, *metrics]`` (metrics sorted by name) is then
    appended to ``self.recorder``.

    Args:
        n_epoch: number of epochs
        lr: learning rate (uses self.lr if None; an explicit 0.0 is honoured)
        make_scheduler: factory ``(optimizer, total_steps) -> scheduler`` (None = no scheduler);
            the scheduler is stepped once per training batch
    """
    # `lr or self.lr` would silently replace an explicit lr=0.0 with the default.
    if lr is None:
        lr = self.lr
    self.model.to(self.device)
    optimizer = self.opt_func(self.model.parameters(), lr=lr)

    n_batches = len(self.dls.train)
    total_steps = n_epoch * n_batches
    scheduler = make_scheduler(optimizer, total_steps) if make_scheduler is not None else None

    self._setup_composables()
    try:
        step = 0
        for epoch in range(n_epoch):
            self.model.train()
            train_losses = []
            with tqdm(
                total=n_batches,
                desc=f"Epoch {epoch + 1}/{n_epoch}",
                disable=not self._show_bar,
                mininterval=0.5,
            ) as pbar:
                for batch in self.dls.train:
                    train_losses.extend(self._train_one_batch(batch, optimizer, step, total_steps))
                    if scheduler is not None:
                        scheduler.step()
                    step += 1
                    pbar.update(1)

                # max(1, ...) avoids division by zero when no losses were collected.
                train_loss = sum(train_losses) / max(1, len(train_losses))

                # Validate
                val_loss, metrics_dict = self.validate()

                # Record (metric columns in sorted-name order)
                row = [train_loss, val_loss] + [metrics_dict[k] for k in sorted(metrics_dict)]
                self.recorder.append(row)
                self._log_epoch(epoch, train_loss, val_loss, metrics_dict, pbar)
    finally:
        # Composables are torn down even if training raises or is interrupted.
        self._teardown_composables()

fit_flat_cos

fit_flat_cos(n_epoch: int, lr: float | None = None, pct_start: float = 0.75)

Convenience: flat LR then cosine decay.

Source code in tsfast/training/learner.py
def fit_flat_cos(self, n_epoch: int, lr: float | None = None, pct_start: float = 0.75):
    """Convenience: flat LR then cosine decay.

    Wraps ``fit`` with a ``LambdaLR`` scheduler whose multiplier is
    ``sched_flat_cos(step / total_steps, pct_start)``.
    """

    def _make_scheduler(opt, total_steps):
        # LambdaLR evaluates the lambda at construction, so total_steps == 0
        # (n_epoch=0 or an empty train loader) would raise ZeroDivisionError.
        denom = max(1, total_steps)
        return LambdaLR(opt, lambda s: sched_flat_cos(s / denom, pct_start))

    self.fit(n_epoch, lr=lr, make_scheduler=_make_scheduler)

get_preds

get_preds(ds_idx: int = 1, dl=None, with_inputs: bool = False)

Batch-concatenated predictions and targets.

Parameters:

Name Type Description Default
ds_idx int

DataLoader index (0=train, 1=valid)

1
dl

explicit DataLoader (overrides ds_idx)

None
with_inputs bool

if True, also return concatenated inputs

False
Source code in tsfast/training/learner.py
def get_preds(self, ds_idx: int = 1, dl=None, with_inputs: bool = False):
    """Batch-concatenated predictions and targets.

    The model runs in eval mode under ``torch.no_grad()``. Transforms (not
    augmentations) are applied to each batch. If the model returns a tuple,
    its first element is used as the prediction.

    Args:
        ds_idx: DataLoader index (0=train, 1=valid)
        dl: explicit DataLoader (overrides ds_idx)
        with_inputs: if True, also return concatenated inputs

    Returns:
        ``(preds, targs)``, or ``(preds, targs, inputs)`` when ``with_inputs``,
        each concatenated along dim 0 on the CPU.
    """
    # Explicit None check: a DataLoader's truthiness calls __len__, which is 0
    # for an empty loader and raises TypeError for iterable-style datasets.
    if dl is None:
        dl = self._get_dl(ds_idx)
    # Only move when truly needed — a redundant .to() triggers
    # nn.RNN flatten_parameters which reallocates weight memory and
    # would invalidate any captured CUDA graph. Parameter-less models
    # (p is None) have nothing to reallocate, so moving them is harmless.
    p = next(self.model.parameters(), None)
    if (
        p is None
        or p.device.type != self.device.type
        or (p.device.index or 0) != (self.device.index or 0)
    ):
        self.model.to(self.device)
    self.model.eval()
    all_preds, all_targs, all_inputs = [], [], []

    with torch.no_grad():
        for batch in dl:
            xb, yb = self._to_device(batch)
            for t in self.transforms:
                xb, yb = t(xb, yb)

            result = self.model(xb)
            pred = result[0] if isinstance(result, tuple) else result

            all_preds.append(pred.cpu())
            all_targs.append(yb.cpu())
            if with_inputs:
                all_inputs.append(xb.cpu())

    preds = torch.cat(all_preds, dim=0)
    targs = torch.cat(all_targs, dim=0)
    if with_inputs:
        return preds, targs, torch.cat(all_inputs, dim=0)
    return preds, targs

get_worst

get_worst(max_n: int = 4, ds_idx: int = 1) -> tuple[Tensor, Tensor, Tensor]

Inputs, targets, and predictions for the samples with highest loss.

Returns:

Type Description
tuple[Tensor, Tensor, Tensor]

(inputs, targets, predictions) sliced to the max_n worst samples

Source code in tsfast/training/learner.py
def get_worst(self, max_n: int = 4, ds_idx: int = 1) -> tuple[Tensor, Tensor, Tensor]:
    """Inputs, targets, and predictions for the samples with highest loss.

    Samples are ranked by ``self.loss_func`` evaluated after dropping the
    first ``self.n_skip`` time steps, mirroring the ``n_skip`` handling in
    ``validate``. The returned tensors are the full, unsliced sequences.

    If the loss exposes a ``reduction`` attribute, it is temporarily set to
    ``"none"`` to obtain element-wise losses. The original value is always
    restored, even if the loss raises. Otherwise the loss is evaluated once
    per sample.

    Returns:
        (inputs, targets, predictions) sliced to the ``max_n`` worst samples
    """
    preds, targs, inputs = self.get_preds(ds_idx=ds_idx, with_inputs=True)
    pred_skip = preds[:, self.n_skip :] if self.n_skip > 0 else preds
    targ_skip = targs[:, self.n_skip :] if self.n_skip > 0 else targs

    if hasattr(self.loss_func, "reduction"):
        orig = self.loss_func.reduction
        self.loss_func.reduction = "none"
        try:
            raw = self.loss_func(pred_skip, targ_skip)
        finally:
            # Leaving reduction="none" behind would make later training and
            # validation losses non-scalar.
            self.loss_func.reduction = orig
        per_sample = raw.reshape(len(preds), -1).mean(dim=1)
    else:
        per_sample = torch.tensor(
            [self.loss_func(pred_skip[i : i + 1], targ_skip[i : i + 1]).item() for i in range(len(preds))]
        )
    idxs = per_sample.argsort(descending=True)[:max_n]
    return inputs[idxs], targs[idxs], preds[idxs]

show_batch

show_batch(max_n: int = 4, dl=None)

Plot a batch of input/target pairs.

Source code in tsfast/training/learner.py
def show_batch(self, max_n: int = 4, dl=None):
    """Plot input/target pairs from the first batch of a DataLoader.

    Args:
        max_n: maximum number of samples to plot
        dl: DataLoader to draw from; ``None`` uses ``self.dls.valid``
    """
    # Explicit None check: a DataLoader's truthiness calls __len__, which is 0
    # for an empty loader and raises TypeError for iterable-style datasets.
    if dl is None:
        dl = self.dls.valid
    batch = next(iter(dl))
    xb, yb = self._to_device(batch)
    # Same transforms as evaluation; augmentations are not applied here.
    for t in self.transforms:
        xb, yb = t(xb, yb)

    n = min(xb.shape[0], max_n)
    samples = [(xb[i].cpu(), yb[i].cpu()) for i in range(n)]
    viz.layout_samples(n, yb.shape[-1], samples, self.plot_fn, signal_names=get_signal_names(dl))

show_results

show_results(max_n: int = 4, ds_idx: int = 1)

Plot predictions vs targets.

Source code in tsfast/training/learner.py
def show_results(self, max_n: int = 4, ds_idx: int = 1):
    """Plot predictions vs targets for the first batch of a DataLoader.

    Args:
        max_n: maximum number of samples to plot
        ds_idx: DataLoader index (0=train, 1=valid)
    """
    dl = self._get_dl(ds_idx)
    # Move only when needed, as get_preds does: a redundant .to() can trigger
    # nn.RNN flatten_parameters, which reallocates weight memory and would
    # invalidate a captured CUDA graph.
    p = next(self.model.parameters(), None)
    if (
        p is None
        or p.device.type != self.device.type
        or (p.device.index or 0) != (self.device.index or 0)
    ):
        self.model.to(self.device)
    self.model.eval()

    batch = next(iter(dl))
    xb, yb = self._to_device(batch)

    with torch.no_grad():
        for t in self.transforms:
            xb, yb = t(xb, yb)
        result = self.model(xb)
        # Models may return (pred, state); only the prediction is plotted.
        pred = result[0] if isinstance(result, tuple) else result

    n = min(xb.shape[0], max_n)
    samples = [(xb[i].cpu(), yb[i].cpu()) for i in range(n)]
    outs = [(pred[i].cpu(),) for i in range(n)]
    viz.layout_samples(n, yb.shape[-1], samples, self.plot_fn, outs, signal_names=get_signal_names(dl))

show_worst

show_worst(max_n: int = 4, ds_idx: int = 1)

Plot samples with highest per-sample loss.

Source code in tsfast/training/learner.py
def show_worst(self, max_n: int = 4, ds_idx: int = 1):
    """Plot the ``max_n`` samples with the largest per-sample loss (see ``get_worst``)."""
    worst_x, worst_y, worst_pred = self.get_worst(max_n=max_n, ds_idx=ds_idx)
    loader = self._get_dl(ds_idx)
    pairs = list(zip(worst_x, worst_y))
    predictions = [(row,) for row in worst_pred]
    viz.layout_samples(
        len(worst_x),
        worst_y.shape[-1],
        pairs,
        self.plot_fn,
        predictions,
        signal_names=get_signal_names(loader),
    )

TbpttLearner

TbpttLearner(*args, sub_seq_len: int, **kwargs)

Bases: Learner

Learner with truncated backpropagation through time (TBPTT).

Full sequences are loaded from the DataLoader, then split into sub-sequences of sub_seq_len. Hidden state is carried across sub-sequences within a batch but reset between batches.

Parameters:

Name Type Description Default
sub_seq_len int

length of each sub-sequence chunk

required
Source code in tsfast/training/learner.py
def __init__(self, *args, sub_seq_len: int, **kwargs):
    """Create a TBPTT learner.

    Args:
        *args, **kwargs: forwarded unchanged to ``Learner.__init__``
        sub_seq_len: length of each sub-sequence chunk; must be >= 1

    Raises:
        ValueError: if ``sub_seq_len`` is less than 1.
    """
    # Validate before building the base learner so an invalid chunk length
    # fails fast at construction time.
    if sub_seq_len < 1:
        raise ValueError(f"sub_seq_len must be >= 1, got {sub_seq_len}")
    super().__init__(*args, **kwargs)
    self.sub_seq_len = sub_seq_len

CudaGraphTbpttLearner

CudaGraphTbpttLearner(*args, **kwargs)

Bases: TbpttLearner

TbpttLearner accelerated with CUDA Graphs.

Captures the forward + backward pass for a single TBPTT chunk into a CUDA graph and replays it, eliminating per-kernel CPU launch overhead. The optimizer step runs outside the graph so that LR schedulers work normally.

When n_skip > 0, a second graph is captured for the first chunk (which has different loss tensor shapes due to skip-slicing).

Constraints
  • Requires CUDA device
  • win_sz % sub_seq_len == 0 (all chunks must have equal shape)
  • Model must use return_state=True
  • Loss function must have static tensor shapes (use "nanmean" reduction, not ignore_nan)
Source code in tsfast/training/learner.py
def __init__(self, *args, **kwargs):
    """Initialise the TBPTT learner and reset all CUDA-graph slots to ``None``.

    Arguments are forwarded unchanged to the parent constructor. Nothing is
    captured here; every graph-related attribute starts empty.
    """
    super().__init__(*args, **kwargs)
    # Tensors prefixed `_s_` — presumably the fixed ("static") buffers a CUDA
    # graph replays against; confirm against the capture code.
    self._s_xb: Tensor | None = None
    self._s_yb: Tensor | None = None
    self._s_state: list | None = None
    self._s_new_state: list | None = None
    self._s_loss: Tensor | None = None
    # Main graph plus a second one; the class docs say a second graph is
    # captured for the first chunk when n_skip > 0 — presumably `_graph_skip`.
    self._graph: torch.cuda.CUDAGraph | None = None
    self._graph_skip: torch.cuda.CUDAGraph | None = None
    self._s_new_state_skip: list | None = None
    self._s_loss_skip: Tensor | None = None