torch.export For Serializing Models and Faster Loading

April 26, 2025
python torch

While torch.compile is great for just-in-time (JIT) compilation, it adds significant startup time at prediction time. With PyTorch 2.6, torch.export can ahead-of-time (AOT) compile your PyTorch code and serialize it into a single zip file. When it works, torch.export's AOT approach starts up faster than torch.compile's JIT. In this post, we learn how to serialize and load a model using torch.export.

Using torch.export

First, we load a ResNet152 nn.Module model onto a GPU and create data that represents typical model input:

import torch
from torchvision.models import resnet152

model = resnet152().to(torch.float32).cuda()
X = torch.randn(32, 3, 128, 128).to(torch.float32).cuda()

Next, we run export on the model and the input, while configuring a dynamic batch dimension:

from torch.export import export, Dim

batch_dim = Dim("batch", min=1, max=512)
# dynamic_shapes takes one entry per positional argument, so the dict is wrapped in a tuple
exported_program = export(model, args=(X,), dynamic_shapes=({0: batch_dim},))

Without dynamic_shapes, export specializes the captured model to the exact shapes of the example input. By mapping dimension 0 of X (the batch dimension) to batch_dim in dynamic_shapes, we allow the batch size to be any value between 1 and 512.

To run a forward pass with exported_program, we call .module() to get a callable nn.Module:

exported_module = exported_program.module()
X_out = exported_module(X)

Finally, we can save the exported model by calling save with a file path:

from torch.export import save

save(exported_program, "resnet152.pt2")

Loading the model

With a serialized model, we can now run torch.export.load to load the model and run a forward pass with the data:

from torch.export import load

loaded_program = timed(load, "resnet152.pt2")
# Duration: 1.2s

result = timed(loaded_program.module(), X)
# Duration: 0.2s
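The timed helper used above is not defined in the post. One plausible version is sketched below; it is hypothetical. The torch.cuda.synchronize calls are an assumption about how to time GPU work fairly: CUDA kernels launch asynchronously, so without a sync the clock could stop before the forward pass has actually finished.

```python
import time

import torch


def timed(fn, *args, **kwargs):
    """Hypothetical helper: call fn(*args, **kwargs), print the wall-clock duration, return the result."""
    # Wait for any queued GPU work so it is not counted against fn.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    # Wait for the GPU work launched by fn before stopping the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"Duration: {time.perf_counter() - start:.1f}s")
    return result
```

On a CPU-only machine the sync calls are skipped and the helper reduces to a plain wall-clock timer.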

Loading the model takes 1.2s, and a forward pass with the loaded model takes 0.2s. For comparison, we can JIT compile the same model with torch.compile:

model_opt = torch.compile(model, mode="reduce-overhead", fullgraph=True)
timed(model_opt, X)
# Duration: 21.1s

The first call takes 21.1s because torch.compile compiles the model from scratch. Running it again in a new process with a warm compilation cache takes 8.9s. Comparing the total time to initially load the model and run a prediction gives these runtimes:

Approach                      Total time (load + first prediction)
torch.export                  1.4s
torch.compile, warm cache     8.9s
torch.compile, cold cache     21.1s

torch.export's AOT approach starts up much faster than torch.compile's JIT!

Conclusion

Although torch.export starts up faster than torch.compile, export has one big tradeoff:

  • torch.export requires full graph capture, which is more restrictive. torch.compile can fall back to the Python interpreter for untraceable code (a graph break), so it can run arbitrary Python code.

If you can capture the full graph, then torch.export has two major benefits:

  • By compiling the exported program with AOTInductor, you can load and run the model from a non-Python environment such as C++. To learn more, see the AOTInductor documentation for torch.export.
  • As shown in this post, the startup times are faster.

Overall, I find it remarkable that we have the choice between using AOT and JIT for running our PyTorch code, each with its own tradeoffs. You can learn more in torch.export's documentation.
