@justinchuby
Created September 18, 2025 16:48
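import torch
import onnx_ir as ir


class ControlFlowModel(torch.nn.Module):
    def forward(self, x):
        # Branch taken when the predicate is true.
        def times_2(x):
            return x * 2

        # Branch taken when the predicate is false.
        def neg(x):
            return -x

        # torch.cond captures both branches, so the choice is made at run time
        # from the data (x.sum() > 0) instead of being fixed during export.
        return torch.cond(x.sum() > 0, times_2, neg, (x,))


# Export with the torch.export-based (dynamo) ONNX exporter.
onnx_program = torch.onnx.export(ControlFlowModel(), (torch.tensor([0.0, 1.0]),), dynamo=True)
print(ir.to_onnx_text(onnx_program.model))  # ONNX graph as text
print(onnx_program.exported_program)  # intermediate ExportedProgram
onnx_program.save("control_flow_model.onnx")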