Over the past few years, PyTorch has gone through a few iterations of systems that turn Python code into a graph to improve performance:
- TorchScript can trace or parse your code to generate a TorchScript intermediate representation that works on a subset of Python. It is no longer in active development.
- FX Graphs: torch.fx.symbolic_trace traces your code to produce an FX Graph that we can mutate for optimizations.
- Torch Compile: torch.compile reads the Python bytecode to generate FX Graphs while also falling back to Python for code it does not recognize.
This blog post looks at how each system handles data-dependent control flow using a simple example. You can run the code in this post on Google Colab.
Simple Function
First, we define a simple function that branches based on the input's data, i.e., on whether the input contains only non-negative values:
import torch
def func(x: torch.Tensor):
    all_pos = torch.all(x >= 0)
    if all_pos:
        return x + 10
    else:
        return x - 10
Here, we pass an all-positive and an all-negative tensor to exercise both code paths:
x_pos = torch.asarray([1, 2, 3])
x_neg = torch.asarray([-1, -2, -3])
func(x_neg), func(x_pos)
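Running eagerly, each call takes the correct branch, so the output should look like:
(tensor([-11, -12, -13]), tensor([11, 12, 13]))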
TorchScript
TorchScript allows us to trace a function by passing in a sample input and running torch.jit.trace:
func_jit_trace = torch.jit.trace(func, (x_pos,))
We see a warning because tracing notices the control flow and records only the branch taken for the sample input. This means the resulting function is incorrect for a negative tensor:
func_jit_trace(x_pos), func_jit_trace(x_neg)
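Because the trace recorded only the x + 10 branch taken by x_pos, both calls add 10, so the output should look like:
(tensor([11, 12, 13]), tensor([9, 8, 7]))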
Note that TorchScript also has torch.jit.script, which parses the control flow logic and gives the correct result:
func_jit_script = torch.jit.script(func)
func_jit_script(x_neg), func_jit_script(x_pos)
Torch FX
When we run torch.fx.symbolic_trace on the same function, it throws an error: the control flow prevents converting the function into a single FX graph:
try:
    torch.fx.symbolic_trace(func)
except Exception as e:
    print(e)
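The message comes from FX's tracer and reads along the lines of:
symbolically traced variables cannot be used as inputs to control flow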
A workaround is to provide concrete_args, so the tracing only goes down one of the code paths:
fx_func = torch.fx.symbolic_trace(func, concrete_args={"x": x_pos})
But running this function gives incorrect results for x_neg:
fx_func(x_neg), fx_func(x_pos)
torch.compile
The newer torch.compile compiles the function and gives the correct results by default:
compiled_func = torch.compile(func)
compiled_func(x_neg), compiled_func(x_pos)
Under the covers, torch.compile builds two graphs and uses a graph break to handle the control flow:
explained = torch._dynamo.explain(func)(x_pos)
print(f"""Graphs: {explained.graph_count}
Graph Breaks: {explained.graph_break_count}""")
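Matching the two graphs described above, this should print:
Graphs: 2
Graph Breaks: 1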
If you want the most performance, it's best to avoid graph breaks by using fullgraph=True:
compiled_full_bad = torch.compile(func, fullgraph=True)
But calling this function results in an error, because the conditional requires a graph break:
try:
    compiled_full_bad(x_pos)
except Exception as e:
    print(e)
To work around the graph break, we can use torch.cond to build the full graph:
def func_cond(x):
    return torch.cond(
        torch.all(x >= 0), lambda x: x + 10, lambda x: x - 10, (x,)
    )
compiled_full_good = torch.compile(func_cond, fullgraph=True)
compiled_full_good(x_neg), compiled_full_good(x_pos)
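Note that torch.cond takes a boolean predicate, a callable for the true branch, a callable for the false branch, and a tuple of operands passed to whichever branch runs; both branches must return tensors with matching metadata, so the graph has a single, fixed structure.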
Using explain, we confirm that there are zero graph breaks:
explained_cond = torch._dynamo.explain(func_cond)(x_pos)
print(f"""Graphs: {explained_cond.graph_count}
Graph Breaks: {explained_cond.graph_break_count}""")
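Since fullgraph=True forces a single graph, this should print:
Graphs: 1
Graph Breaks: 0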
Conclusion
TorchScript’s usability issue stems from only supporting a subset of Python, which means we likely need to update our code to make it work. FX Graphs have a limited symbolic tracer that produces an intermediate representation we can mutate and pass to a compiler.
By reading Python’s bytecode, torch.compile still produces FX graphs, but it “just works” with any other Python code. For code torch.compile does not recognize, it falls back to the Python interpreter. This means you can apply torch.compile to get some initial improvements, then iterate to make your code run faster by reducing the number of graph breaks.
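To track down where those breaks happen, recent PyTorch versions can log them as they occur. A minimal sketch, assuming PyTorch 2.x (the logging API lives under the private torch._logging namespace):
import torch

# Ask dynamo to report each graph break and its cause;
# equivalent to running with the TORCH_LOGS="graph_breaks" env var.
torch._logging.set_logs(graph_breaks=True)

compiled = torch.compile(func)
compiled(x_pos)  # the break at the data-dependent branch is logged here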