Example 15: ONNX Export and Deployment¶
Once you have trained a model, you often want to deploy it without requiring the full PyTorch stack. ONNX (Open Neural Network Exchange) is an open format that lets you run models with lightweight runtimes like ONNX Runtime. This example exports a trained model to ONNX and verifies numerical agreement with the PyTorch version.
Prerequisites¶
This notebook builds on Examples 00-02. In particular, you should be familiar
with InferenceWrapper from Example 02 (numpy-based inference with a trained
model). Make sure ONNX Runtime is installed:
uv sync --extra dev
Setup¶
import numpy as np
from tsfast.tsdata.benchmark import create_dls_silverbox
from tsfast.models.rnn import RNNLearner
from tsfast.inference import InferenceWrapper
from tsfast.inference.onnx import export_onnx, OnnxInferenceWrapper
from tsfast.training import fun_rmse
Train a Model¶
We train a quick LSTM on the Silverbox benchmark so we have a model to export. See Example 00 for a detailed walkthrough of this step.
Parameters:
- bs=16 -- batch size
- win_sz=500 -- window length in timesteps
- stp_sz=10 -- stride between consecutive windows
- hidden_size=40 -- number of LSTM hidden units
- metrics=[fun_rmse] -- track RMSE during training
dls = create_dls_silverbox(bs=16, win_sz=500, stp_sz=10)
lrn = RNNLearner(dls, rnn_type='lstm', hidden_size=40, metrics=[fun_rmse])
lrn.fit_flat_cos(n_epoch=5, lr=3e-3)
epoch 1 | train=0.0137 | valid=0.0071 | fun_rmse=0.0113
epoch 2 | train=0.0047 | valid=0.0034 | fun_rmse=0.0095
epoch 3 | train=0.0045 | valid=0.0039 | fun_rmse=0.0096
epoch 4 | train=0.0042 | valid=0.0036 | fun_rmse=0.0096
epoch 5 | train=0.0033 | valid=0.0030 | fun_rmse=0.0095
Training: 100%|██████████| 1500/1500 [00:25<00:00, 59.80it/s]
lrn.show_results(max_n=2)
PyTorch Inference with InferenceWrapper¶
As covered in Example 02, InferenceWrapper provides numpy-in / numpy-out
inference using the PyTorch model. It handles normalization automatically:
raw numpy arrays go in, raw numpy predictions come out.
pytorch_wrapper = InferenceWrapper(lrn)
xb, yb = dls.valid.one_batch()
np_input = xb.cpu().numpy()
y_pytorch = pytorch_wrapper.inference(np_input)
print(f"Input shape: {np_input.shape}")
print(f"Output shape: {y_pytorch.shape}")
Input shape: (16, 500, 1)
Output shape: (16, 500, 1)
Export to ONNX¶
export_onnx converts the trained model to ONNX format. Normalization (input
scaling and output denormalization) is baked into the ONNX graph -- the
exported model accepts raw numpy inputs and produces raw outputs, just like
InferenceWrapper.
Parameters:
- lrn -- the trained Learner to export
- path -- output file path (a .onnx suffix is added if missing)
- opset_version=17 (default) -- ONNX operator set version. Higher versions support more operations; 17 is a safe default for most runtimes.
- seq_len=None (default) -- override the sequence length for the dummy input used during tracing. By default it uses the window size from the DataLoaders. The exported model accepts any sequence length at runtime thanks to dynamic axes.
onnx_path = export_onnx(lrn, '/tmp/tsfast_model.onnx')
print(f"Exported to: {onnx_path}")
Exported to: /tmp/tsfast_model.onnx
/home/pheenix/Development/tsfast/.venv/lib/python3.12/site-packages/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py:4445: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. return _generic_rnn(
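On the deployment side, the exported file can be consumed with plain ONNX Runtime, without tsfast or PyTorch installed. The following is a minimal sketch of what that looks like, assuming onnxruntime is available and reusing the /tmp/tsfast_model.onnx path from above; the run_onnx helper is illustrative, not part of the tsfast API:

```python
import os
import numpy as np

def run_onnx(model_path, x):
    """Run one forward pass through an exported ONNX model.

    Illustrative helper: all a deployment target needs is numpy and
    onnxruntime -- no PyTorch, no tsfast.
    """
    import onnxruntime as ort
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name      # name of the (single) graph input
    x = np.asarray(x, dtype=np.float32)         # exported graphs expect float32
    (y,) = sess.run(None, {input_name: x})      # None -> return all outputs
    return y

# Only run if the model exported above is actually on disk:
if os.path.exists("/tmp/tsfast_model.onnx"):
    y = run_onnx("/tmp/tsfast_model.onnx", np.random.randn(1, 500, 1))
    print(y.shape)
```

Because the model was exported with dynamic axes, the batch size and sequence length of x need not match the training window.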
Load and Run with ONNX Runtime¶
OnnxInferenceWrapper loads the exported ONNX model and provides the same
.inference() API as InferenceWrapper. Under the hood it uses ONNX Runtime,
a lightweight inference engine that does not require PyTorch.
onnx_wrapper = OnnxInferenceWrapper(onnx_path)
y_onnx = onnx_wrapper.inference(np_input)
print(f"ONNX output shape: {y_onnx.shape}")
ONNX output shape: (16, 500, 1)
Verify Numerical Agreement¶
The PyTorch and ONNX outputs should be nearly identical. Small floating-point differences are expected due to different execution backends, but the maximum absolute difference should be well below 1e-4.
max_diff = np.max(np.abs(y_pytorch - y_onnx))
mean_diff = np.mean(np.abs(y_pytorch - y_onnx))
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Mean absolute difference: {mean_diff:.2e}")
assert max_diff < 1e-4, f"Outputs differ by {max_diff}"
print("Outputs match within tolerance!")
Max absolute difference: 4.54e-07
Mean absolute difference: 1.02e-07
Outputs match within tolerance!
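A flat max-abs-diff threshold treats all magnitudes equally; np.allclose instead allows a per-element tolerance of atol + rtol * |reference|, which scales with the signal. The sketch below illustrates both checks on synthetic stand-ins for the two outputs (not the actual model predictions), with float32-scale noise as the assumed discrepancy:

```python
import numpy as np

# Synthetic stand-ins for y_pytorch / y_onnx: the same signal plus
# tiny noise, mimicking backend-level float32 differences.
rng = np.random.default_rng(0)
a = rng.standard_normal((16, 500, 1)).astype(np.float32)
b = a + rng.standard_normal(a.shape).astype(np.float32) * 1e-7

# Flat threshold, as in the cell above:
max_diff = np.max(np.abs(a - b))
assert max_diff < 1e-4

# np.allclose checks |a - b| <= atol + rtol * |b| element-wise,
# so large values get a proportionally larger allowance.
assert np.allclose(a, b, rtol=1e-5, atol=1e-6)
```

For float32 models, tolerances much tighter than ~1e-6 relative are usually unrealistic; differences at that scale reflect operator reordering, not export bugs.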
Limitations¶
Autoregressive models (AR_RNNLearner, AR_TCNLearner) cannot be exported
to ONNX. These models run a Python-side loop that feeds each prediction back
as input at the next timestep, and the tracing-based ONNX exporter cannot
capture this data-dependent feedback loop. If you try to export an
autoregressive model, export_onnx raises a ValueError with a clear message.
For autoregressive models, use InferenceWrapper instead -- it runs the
prediction loop in Python and works with any model type.
Key Takeaways¶
- export_onnx converts trained models to ONNX format with normalization baked into the graph. The exported model accepts raw inputs and produces raw outputs.
- OnnxInferenceWrapper provides the same numpy interface as InferenceWrapper, making it a drop-in replacement for deployment.
- ONNX Runtime is lightweight -- deploy your models without installing PyTorch.
- Always verify numerical agreement between PyTorch and ONNX outputs to catch export issues early.
- Autoregressive models cannot be exported due to their sequential prediction loop. Use InferenceWrapper for those models instead.