While `torch.compile` is great for just-in-time (JIT) compilation, it adds significant startup overhead at prediction time. With PyTorch 2.6, `torch.export` can ahead-of-time (AOT) compile your PyTorch code and serialize it into a single zip file. When it works, `torch.export`'s AOT approach has faster startup times than `torch.compile`'s JIT. In this post, we learn how to serialize and load a model using `torch.export`.
## Using `torch.export`
First, we load a ResNet152 `nn.Module` onto a GPU, along with data that represents a typical model input:
```python
import torch
from torchvision.models import resnet152

model = resnet152().to(torch.float32).cuda()
X = torch.randn(32, 3, 128, 128).to(torch.float32).cuda()
```
Next, we run `export` on the model and the input, while configuring a dynamic batch dimension:
```python
from torch.export import export, Dim

# Let the batch dimension (index 0) of the single positional input vary between 1 and 512.
batch_dim = Dim("batch", min=1, max=512)
exported_program = export(model, args=(X,), dynamic_shapes=({0: batch_dim},))
```
Normally, `export` will AOT compile the model to work only with the exact shape of the example input. By setting `dynamic_shapes=({0: batch_dim},)`, we configure the compiler to allow the batch dimension (index 0) to be any value between 1 and 512.
With the `exported_program`, we can run the model through a forward pass by calling `.module()` to extract the module:
```python
exported_module = exported_program.module()
X_out = exported_module(X)
```
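Because we marked the batch dimension as dynamic, the same exported module should accept any batch size in the 1–512 range, not just 32. A quick check (this snippet is illustrative and not from the original post):

```python
# The module was exported with a batch of 32, but the dynamic batch dimension
# allows any batch size between 1 and 512.
X_small = torch.randn(1, 3, 128, 128).to(torch.float32).cuda()
X_large = torch.randn(256, 3, 128, 128).to(torch.float32).cuda()

print(exported_module(X_small).shape)  # torch.Size([1, 1000])
print(exported_module(X_large).shape)  # torch.Size([256, 1000])
```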
Finally, we can save the exported model by calling `save` with a file path:
```python
from torch.export import save

save(exported_program, "resnet152.pt2")
```
## Loading the model
With a serialized model, we can now run `torch.export.load` to load the model and run a forward pass with the data.
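The snippets in this post time each step with a `timed` helper whose definition isn't shown. Here is a minimal sketch of what such a helper might look like (the name, signature, and output format are assumptions):

```python
import time
import torch

def timed(fn, *args):
    """Call fn(*args), print the elapsed wall-clock time, and return the result."""
    torch.cuda.synchronize()  # flush pending GPU work so it doesn't skew the start time
    start = time.perf_counter()
    result = fn(*args)
    torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    print(f"Duration: {time.perf_counter() - start:.1f}s")
    return result
```

With a helper like that in place, loading the model and running a prediction looks like this: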
```python
from torch.export import load

loaded_program = timed(load, "resnet152.pt2")
# Duration: 1.2s

result = timed(loaded_program.module(), X)
# Duration: 0.2s
```
Loading the model takes 1.2s, and a forward pass with the loaded module takes 0.2s. For comparison, we can compile the same model with `torch.compile`'s JIT:
```python
model_opt = torch.compile(model, mode="reduce-overhead", fullgraph=True)

timed(model_opt, X)
# Duration: 21.1s
```
The first run takes 21.1s because `torch.compile` is compiling from scratch. When we run it again with a warm cache, it takes 8.9s. Comparing the total time to initially load the model and run a prediction, we get these runtimes:
| `torch.export` | `torch.compile` (warm cache) | `torch.compile` (cold cache) |
|---|---|---|
| 1.4s | 8.9s | 21.1s |
`torch.export`'s AOT approach has faster startup times than `torch.compile`'s JIT!
## Conclusion
Although `torch.export` starts up faster than `torch.compile`, `export` has one big tradeoff: `torch.export` requires full graph capture, which is more restrictive. `torch.compile` has the option to fall back to the Python interpreter for untraceable code, enabling `compile` to run arbitrary Python code.
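To make the tradeoff concrete, here is a small example (not from the original post) of data-dependent Python control flow: `torch.compile` simply graph-breaks around the branch and keeps running, while `torch.export` cannot capture it into a single full graph:

```python
import torch
from torch import nn
from torch.export import export

class Gate(nn.Module):
    def forward(self, x):
        # Data-dependent Python control flow cannot be traced into one graph.
        if x.sum() > 0:
            return x * 2
        return x - 1

x = torch.randn(4)

# torch.compile falls back to the Python interpreter around the branch (a graph break).
compiled = torch.compile(Gate())
print(compiled(x))

# torch.export requires the full graph, so capturing this module raises an error.
try:
    export(Gate(), args=(x,))
except Exception as e:
    print(f"export failed: {type(e).__name__}")
```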
If you can capture the full graph, then `torch.export` has two major benefits:
- You can load the serialized model from a non-Python environment such as C++. To learn more about this feature, see the AOTInductor documentation for torch.export; a rough sketch of that workflow follows this list.
- As shown in this post, the startup times are faster.
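For reference, the AOTInductor path looks roughly like the sketch below. The function names come from the AOTInductor documentation for recent PyTorch releases, but treat the exact module paths and signatures as assumptions that can differ between versions:

```python
import torch
import torch._inductor
from torch.export import export
from torchvision.models import resnet152

model = resnet152().to(torch.float32).cuda()
X = torch.randn(32, 3, 128, 128).to(torch.float32).cuda()

exported_program = export(model, args=(X,))

# Compile the exported program into a self-contained .pt2 package that
# AOTInductor's C++ runtime can load without a Python interpreter.
package_path = torch._inductor.aoti_compile_and_package(
    exported_program, package_path="resnet152_aoti.pt2"
)

# The same package can also be loaded back in Python.
compiled = torch._inductor.aoti_load_package(package_path)
print(compiled(X).shape)
```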
Overall, I find it remarkable that we have the choice between using AOT and JIT for running our PyTorch code, each with its own tradeoffs. You can learn more in `torch.export`'s documentation.