PyTorch Graphs Three Ways: Data-Dependent Control Flow

March 16, 2025
python torch

Over the past few years, PyTorch has gone through a few iterations of turning Python code into a graph to improve performance:

  1. TorchScript can trace or parse your code to generate a TorchScript intermediate representation that works on a subset of Python. It is no longer in active development.
  2. FX Graphs: torch.fx.symbolic_trace traces your code to produce an FX Graph that we can mutate for optimizations.
  3. Torch Compile: torch.compile reads the Python bytecode to generate FX Graphs, while falling back to Python for code it does not recognize.

This blog looks at how each system handles data-dependent control flow with a simple example. You can run the code in this post on Google Colab.

Simple Function

First, we define a simple function that branches based on the input's data, i.e. whether the input contains all non-negative values:

import torch

def func(x: torch.Tensor):
    all_pos = torch.all(x >= 0)
    if all_pos:
        return x + 10
    else:
        return x - 10

Here, we pass all-positive and all-negative tensors to show the different code paths:

x_pos = torch.asarray([1, 2, 3])
x_neg = torch.asarray([-1, -2, -3])
func(x_neg), func(x_pos)
Out[3]:
(tensor([-11, -12, -13]), tensor([11, 12, 13]))

TorchScript

TorchScript allows us to trace a function by passing in a sample input and running torch.jit.trace:

func_jit_trace = torch.jit.trace(func, (x_pos,))
/var/folders/9l/pvs3_wlj23z_qxd4m8dv9w300000gn/T/ipykernel_31962/766430457.py:5: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if all_pos:

We see a warning because the tracer notices the control flow and only one of the branches will be traced. This means that the resulting function is incorrect for a negative tensor:

func_jit_trace(x_pos), func_jit_trace(x_neg)
Out[5]:
(tensor([11, 12, 13]), tensor([9, 8, 7]))
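We can confirm that only the x + 10 branch was recorded by printing the traced code via the code attribute on the traced function:

print(func_jit_trace.code)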

Note that TorchScript also has torch.jit.script, which can parse the control flow logic and give the correct result:

func_jit_script = torch.jit.script(func)
func_jit_script(x_neg), func_jit_script(x_pos)
Out[7]:
(tensor([-11, -12, -13]), tensor([11, 12, 13]))
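Because scripting parses the Python source, the control flow survives in the generated TorchScript, which we can verify by printing the scripted code:

print(func_jit_script.code)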

Torch FX

When we run torch.fx.symbolic_trace on the same function, it throws an error because the control flow prevents it from converting the function into a single FX graph:

try:
    torch.fx.symbolic_trace(func)
except Exception as e:
    print(e)
symbolically traced variables cannot be used as inputs to control flow
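For comparison, a branch-free function traces without issue, and we can print the FX graph it produces. Here func_add is a hypothetical helper for illustration:

def func_add(x: torch.Tensor):
    return x + 10

gm = torch.fx.symbolic_trace(func_add)
print(gm.graph)  # placeholder, add, and output nodes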

A workaround is to provide concrete_args so that the tracing only goes down one of the code paths:

fx_func = torch.fx.symbolic_trace(func, concrete_args={"x": x_pos})
/Users/thomasfan/micromamba/envs/torch/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py:906: UserWarning: Was not able to add assertion to guarantee correct input x to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.
  warnings.warn(

But running this function will give incorrect results for x_neg:

fx_func(x_neg), fx_func(x_pos)
Out[10]:
(tensor([11, 12, 13]), tensor([11, 12, 13]))
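Printing the generated code shows why: the x + 10 branch was baked into the graph, so the subtraction path is gone entirely:

print(fx_func.code)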

torch.compile

The newer torch.compile compiles the function and gives the correct results by default:

compiled_func = torch.compile(func)
compiled_func(x_neg), compiled_func(x_pos)
Out[12]:
(tensor([-11, -12, -13]), tensor([11, 12, 13]))

Under the covers, torch.compile builds two graphs and uses a graph break to handle the control flow:

explained = torch._dynamo.explain(func)(x_pos)

print(f"""Graphs: {explained.graph_count}
Graph Breaks: {explained.graph_break_count}""")
Graphs: 2
Graph Breaks: 1
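The explain output can also report why the graph broke. A quick check, assuming the break_reasons attribute exposed by recent versions of torch._dynamo.explain:

# each entry describes why dynamo split the graph, e.g. the data-dependent branch
for reason in explained.break_reasons:
    print(reason)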

If you want the most performance, it's best to avoid graph breaks, which we can enforce with fullgraph=True:

compiled_full_bad = torch.compile(func, fullgraph=True)

But this will result in an error because the conditional requires a graph break:

try:
    compiled_full_bad(x_pos)
except Exception as e:
    print(e)
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/var/folders/9l/pvs3_wlj23z_qxd4m8dv9w300000gn/T/ipykernel_31962/766430457.py", line 5, in func
    if all_pos:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To work around the graph break, we can use torch.cond to build the full graph:

def func_cond(x):
    return torch.cond(
        torch.all(x >= 0), lambda x: x + 10, lambda x: x - 10, (x,)
    )

compiled_full_good = torch.compile(func_cond, fullgraph=True)
compiled_full_good(x_neg), compiled_full_good(x_pos)
Out[17]:
(tensor([-11, -12, -13]), tensor([11, 12, 13]))

Using explain, we see that there are zero graph breaks:

explained_cond = torch._dynamo.explain(func_cond)(x_pos)

print(f"""Graphs: {explained_cond.graph_count}
Graph Breaks: {explained_cond.graph_break_count}""")
Graphs: 1
Graph Breaks: 0
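torch.cond is not the only option. When both branches are cheap to compute, a sketch using torch.where also avoids the graph break, since it evaluates both sides and selects based on the predicate tensor instead of converting it to a Python bool:

def func_where(x: torch.Tensor):
    # both branches are computed; where() selects based on the 0-dim predicate
    return torch.where(torch.all(x >= 0), x + 10, x - 10)

compiled_where = torch.compile(func_where, fullgraph=True)
compiled_where(x_neg), compiled_where(x_pos)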

Conclusion

TorchScript’s usability issue stems from only supporting a subset of Python, which means we likely need to update our code to make it work. FX Graphs come from a limited symbolic tracer that produces an intermediate representation we can mutate and pass to a compiler.

Using Python’s bytecode, torch.compile still produces FX graphs, but it “just works” with any other Python code. For code torch.compile does not recognize, it falls back to the Python interpreter. This means you can use torch.compile to get some initial improvements, then iterate to make your code run faster by reducing the number of graph breaks.
