Skip to content

Physics-Informed RNN

Physics-Informed RNN models with dual encoder architecture.

PIRNN

PIRNN(n_u: int, n_y: int, init_sz: int, n_y_supervised: int | None = None, n_x: int = 0, hidden_size: int = 100, rnn_layer: int = 1, state_encoder_hidden: int = 64, linear_layer: int = 1, final_layer: int = 0, init_diag_only: bool = False, default_encoder_mode: str = 'sequence', p_state_encoder: float = 0.0, init_sz_range: tuple[int, int] | None = None, **kwargs)

Bases: Module

Physics-Informed RNN with dual encoders: Sequence and State.

Uses a diagnosis RNN (sequence encoder) to estimate initial hidden state from an initialization window, then a prognosis RNN to predict forward. Alternatively, an MLP state encoder maps a single physical state to hidden state for faster initialization.

Parameters:

Name Type Description Default
n_u int

Number of inputs.

required
n_y int

Number of outputs (total: supervised + auxiliary).

required
init_sz int

Initialization sequence length.

required
n_y_supervised int | None

Number of supervised outputs (in dataset). Defaults to n_y.

None
n_x int

Number of extra states.

0
hidden_size int

Hidden state size.

100
rnn_layer int

Number of RNN layers.

1
state_encoder_hidden int

Hidden size for state encoder MLP.

64
linear_layer int

Linear layers in diagnosis RNN.

1
final_layer int

Final layer complexity.

0
init_diag_only bool

Limit diagnosis to init_sz.

False
default_encoder_mode str

Default encoder mode during inference.

'sequence'
p_state_encoder float

Probability of using the state encoder per training batch. When > 0, randomly alternates between the sequence and state encoders during training (randomness is active only in training mode, analogous to nn.Dropout). Has no effect during inference.

0.0
init_sz_range tuple[int, int] | None

If set, randomize init_sz uniformly over the inclusive range [min, max] during training. Has no effect during inference.

None
**kwargs

Additional arguments passed to RNN constructors.

{}
Source code in tsfast/pinn/pirnn.py
def __init__(
    self,
    n_u: int,
    n_y: int,
    init_sz: int,
    n_y_supervised: int | None = None,
    n_x: int = 0,
    hidden_size: int = 100,
    rnn_layer: int = 1,
    state_encoder_hidden: int = 64,
    linear_layer: int = 1,
    final_layer: int = 0,
    init_diag_only: bool = False,
    default_encoder_mode: str = "sequence",
    p_state_encoder: float = 0.0,
    init_sz_range: tuple[int, int] | None = None,
    **kwargs,
):
    """Build the dual-encoder PIRNN: prognosis RNN, diagnosis RNN, state encoder.

    Args:
        n_u: Number of inputs.
        n_y: Number of outputs (total: supervised + auxiliary).
        init_sz: Initialization sequence length.
        n_y_supervised: Number of supervised outputs (in dataset). Defaults to n_y.
        n_x: Number of extra states.
        hidden_size: Hidden state size.
        rnn_layer: Number of RNN layers.
        state_encoder_hidden: Hidden size for state encoder MLP.
        linear_layer: Linear layers in diagnosis RNN.
        final_layer: Final layer complexity.
        init_diag_only: Limit diagnosis to init_sz.
        default_encoder_mode: Default encoder mode during inference.
        p_state_encoder: Probability of using the state encoder per training batch.
        init_sz_range: If set, randomize init_sz during training.
        **kwargs: Additional arguments passed to RNN constructors.
    """
    super().__init__()
    if n_y_supervised is None:
        n_y_supervised = n_y

    # Bookkeeping read by forward() and by the learner factory.
    self.n_u = n_u
    self.n_y = n_y
    self.n_x = n_x
    self.n_y_supervised = n_y_supervised
    self.init_sz = init_sz
    self.init_diag_only = init_diag_only
    self.hidden_size = hidden_size
    self.rnn_layer = rnn_layer
    self.default_encoder_mode = default_encoder_mode
    self.p_state_encoder = p_state_encoder
    self.init_sz_range = init_sz_range

    # Caller-supplied kwargs take precedence over the defaults below.
    prognosis_kwargs = {
        "hidden_size": hidden_size,
        "num_layers": rnn_layer,
        "ret_full_hidden": True,
        **kwargs,
    }
    self.rnn_prognosis = RNN(n_u, **prognosis_kwargs)

    # Diagnosis (sequence encoder) sees inputs, extra states and the
    # supervised outputs only — auxiliary channels are not in the dataset.
    diag_in = n_u + n_x + n_y_supervised
    self.rnn_diagnosis = Diag_RNN(
        diag_in,
        self.rnn_prognosis.state_size,
        hidden_size=hidden_size,
        rnn_layer=rnn_layer,
        linear_layer=linear_layer,
        **kwargs,
    )

    # Final readout produces every channel (supervised + auxiliary).
    self.final = SeqLinear(hidden_size, n_y, hidden_layer=final_layer)

    # State encoder: single physical state -> flat hidden-state vector.
    self.state_encoder = nn.Sequential(
        nn.Linear(n_y_supervised, state_encoder_hidden),
        nn.ReLU(),
        nn.Linear(state_encoder_hidden, self.rnn_prognosis.state_size),
    )

forward

forward(x: Tensor, init_state: list | None = None, encoder_mode: str = 'default') -> torch.Tensor

Forward pass with encoder mode auto-detection or explicit selection.

Parameters:

Name Type Description Default
x Tensor

Input tensor [batch, seq, features].

required
init_state list | None

Initial hidden state. If None, estimated by encoder.

None
encoder_mode str

Encoder selection - 'none', 'sequence', or 'state'. The value 'default' resolves to the model's default_encoder_mode (or is randomized during training when p_state_encoder > 0).

'default'
Source code in tsfast/pinn/pirnn.py
def forward(
    self,
    x: torch.Tensor,
    init_state: list | None = None,
    encoder_mode: str = "default",
) -> torch.Tensor:
    """Forward pass with encoder mode auto-detection or explicit selection.

    Args:
        x: Input tensor [batch, seq, features].
        init_state: Initial hidden state. If None, estimated by encoder.
        encoder_mode: Encoder selection - 'none', 'sequence', or 'state'.
    """

    # Randomized warm-up window length applies only while training.
    if self.training and self.init_sz_range:
        init_sz = random.randint(*self.init_sz_range)
    else:
        init_sz = self.init_sz
    self._effective_init_sz = init_sz

    u = x[:, :, : self.n_u]
    # Initialization window: inputs + extra states + supervised outputs
    # (only supervised outputs exist in the data).
    x_init = x[:, :init_sz, : self.n_u + self.n_x + self.n_y_supervised]

    mode = encoder_mode
    if mode == "default":
        if self.training and self.p_state_encoder > 0:
            # Stochastically alternate encoders during training.
            mode = "state" if random.random() < self.p_state_encoder else "sequence"
        else:
            mode = self.default_encoder_mode

    if mode == "none":
        return self._forward_predictor(u, init_state)
    if mode == "sequence":
        return self._forward_sequence_encoder(u[:, init_sz:], x_init, init_state)
    if mode == "state":
        return self._forward_state_encoder(u[:, init_sz:], x_init, init_state)
    raise ValueError(f"encoder_mode must be 'none', 'sequence', or 'state', got {mode}")

encode_single_state

encode_single_state(physical_state: Tensor) -> torch.Tensor

Convert single physical state to flat hidden state vector.

Parameters:

Name Type Description Default
physical_state Tensor

Physical state [batch, n_y_supervised].

required

Returns:

Type Description
Tensor

Flat hidden state [batch, state_size].

Source code in tsfast/pinn/pirnn.py
def encode_single_state(self, physical_state: torch.Tensor) -> torch.Tensor:
    """Map a single physical state to a flat hidden-state vector.

    Args:
        physical_state: Physical state ``[batch, n_y_supervised]``.

    Returns:
        Flat hidden state ``[batch, state_size]``.
    """
    # Pure delegation to the MLP state encoder built in __init__.
    encoder = self.state_encoder
    return encoder(physical_state)

AuxiliaryOutputLoss

AuxiliaryOutputLoss(loss_func: Callable, n_supervised: int)

Wrapper that applies loss only to supervised outputs, ignoring auxiliary outputs.

Parameters:

Name Type Description Default
loss_func Callable

Loss function to wrap.

required
n_supervised int

Number of supervised output channels.

required
Source code in tsfast/pinn/pirnn.py
def __init__(
    self,
    loss_func: Callable,
    n_supervised: int,
):
    """Remember the wrapped loss and how many leading channels it applies to.

    Args:
        loss_func: Loss function to wrap.
        n_supervised: Number of supervised output channels.
    """
    self.loss_func, self.n_supervised = loss_func, n_supervised

__call__

__call__(pred: Tensor, targ: Tensor) -> torch.Tensor

Apply loss only to first n_supervised channels of predictions.

Source code in tsfast/pinn/pirnn.py
def __call__(
    self,
    pred: torch.Tensor,
    targ: torch.Tensor,
) -> torch.Tensor:
    """Apply the wrapped loss to the first ``n_supervised`` prediction channels only."""
    # Auxiliary (unsupervised) channels beyond n_supervised are ignored.
    supervised_pred = pred[..., : self.n_supervised]
    return self.loss_func(supervised_pred, targ)

PIRNNLearner

PIRNNLearner(dls, init_sz: int, n_aux_outputs: int = 0, attach_output: bool = False, loss_func: Callable = nn.L1Loss(), metrics: list | None = None, opt_func: Callable = torch.optim.Adam, lr: float = 0.003, transforms: list | None = None, augmentations: list | None = None, aux_losses: list | None = None, input_norm: type | None = StandardScaler, output_norm: type | None = None, **kwargs) -> Learner

Create PIRNN learner with appropriate configuration.

Parameters:

Name Type Description Default
dls

DataLoaders.

required
init_sz int

Initialization sequence length.

required
n_aux_outputs int

Number of auxiliary outputs (not in dataset).

0
attach_output bool

Whether to attach output to input via prediction_concat.

False
loss_func Callable

Loss function.

L1Loss()
metrics list | None

Metrics.

None
opt_func Callable

Optimizer.

Adam
lr float

Learning rate.

0.003
transforms list | None

Additional transforms (train + valid).

None
augmentations list | None

Additional augmentations (train only).

None
aux_losses list | None

Additional auxiliary losses.

None
input_norm type | None

Input normalization Scaler class.

StandardScaler
output_norm type | None

Output denormalization Scaler class.

None
**kwargs

Additional arguments for PIRNN.

{}
Source code in tsfast/pinn/pirnn.py
def PIRNNLearner(
    dls,
    init_sz: int,
    n_aux_outputs: int = 0,
    attach_output: bool = False,
    loss_func: Callable = nn.L1Loss(),
    metrics: list | None = None,
    opt_func: Callable = torch.optim.Adam,
    lr: float = 3e-3,
    transforms: list | None = None,
    augmentations: list | None = None,
    aux_losses: list | None = None,
    input_norm: type | None = StandardScaler,
    output_norm: type | None = None,
    **kwargs,
) -> Learner:
    """Create PIRNN learner with appropriate configuration.

    Args:
        dls: DataLoaders.
        init_sz: Initialization sequence length.
        n_aux_outputs: Number of auxiliary outputs (not in dataset).
        attach_output: Whether to attach output to input via prediction_concat.
        loss_func: Loss function.
        metrics: Metrics.
        opt_func: Optimizer.
        lr: Learning rate.
        transforms: Additional transforms (train + valid).
        augmentations: Additional augmentations (train only).
        aux_losses: Additional auxiliary losses.
        input_norm: Input normalization Scaler class.
        output_norm: Output denormalization Scaler class.
        **kwargs: Additional arguments for PIRNN.
    """
    metrics = [fun_rmse] if metrics is None else metrics
    # Defensive copies so caller-owned lists are never mutated in place.
    transforms = list(transforms) if transforms else []
    augmentations = list(augmentations) if augmentations else []
    aux_losses = list(aux_losses) if aux_losses else []

    # Infer channel counts from one sample batch.
    batch = dls.one_batch()
    n_inp = batch[0].shape[-1]
    n_sup = batch[1].shape[-1]  # supervised outputs present in the dataset
    n_y_total = n_sup + n_aux_outputs  # supervised + auxiliary channels

    norm_u, norm_y = dls.norm_stats
    # In both modes the model's effective input is [u, y].
    combined_input_stats = norm_u + norm_y

    if attach_output:
        # Outputs are concatenated onto the inputs on the fly, so the
        # full batch input width is the model's n_u.
        model = PIRNN(n_inp, n_y_total, init_sz, n_y_supervised=n_sup, **kwargs)
        if not any(isinstance(t, prediction_concat) for t in transforms):
            transforms.insert(0, prediction_concat(t_offset=0))
    else:
        # Prediction-mode dls already deliver [u, y]; only the u part is
        # the model's control input.
        model = PIRNN(n_inp - n_sup, n_y_total, init_sz, n_y_supervised=n_sup, **kwargs)

    # Wrap model with input normalization and optional output denormalization.
    if input_norm is not None:
        in_scaler = input_norm.from_stats(combined_input_stats)
        # NOTE(review): norm_y covers supervised channels only — confirm the
        # denormalization shape when n_aux_outputs > 0.
        out_scaler = None if output_norm is None else output_norm.from_stats(norm_y)
        model = ScaledModel(model, in_scaler, out_scaler)

    # Long sequences get truncated during training to keep batches cheap.
    seq_len = batch[0].shape[1]
    LENGTH_THRESHOLD = 300
    if seq_len > init_sz + LENGTH_THRESHOLD and not any(
        isinstance(a, truncate_sequence) for a in augmentations
    ):
        INITIAL_SEQ_LEN = 100
        augmentations.append(truncate_sequence(init_sz + INITIAL_SEQ_LEN))

    # With auxiliary channels present, loss and metrics must only score
    # the supervised outputs.
    if n_aux_outputs > 0:
        loss_func = AuxiliaryOutputLoss(loss_func, n_sup)
        metrics = [AuxiliaryOutputLoss(m, n_sup) for m in metrics]

    return Learner(
        model,
        dls,
        loss_func=loss_func,
        metrics=metrics,
        n_skip=init_sz,
        opt_func=opt_func,
        lr=lr,
        transforms=transforms,
        augmentations=augmentations,
        aux_losses=aux_losses,
    )