@mlazos
Created November 5, 2025 00:29
import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs
import torch.fx.traceback as fx_traceback


def forward(x):
    with fx_traceback.annotate({"pp_stage": 0}):
        with fx_traceback.annotate({"fdsp_bucket": 0}):
            sin = torch.sin(x)
        sub = sin - 2
    with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
        mul = sub * 2
        div = mul / 3
    return div, mul


backend = AotEagerAndRecordGraphs()
opt_m = torch.compile(forward, backend=backend, fullgraph=True)
x = torch.randn(10, requires_grad=True)
opt_m(x)[1].sum().backward()
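
For reference, a minimal sketch of how the recorded graphs could be inspected after the call above, assuming the AotEagerAndRecordGraphs backend exposes fw_graphs/bw_graphs lists (as in torch/_dynamo/testing.py) and that fx_traceback.annotate stores its dict under node.meta["custom"]:

# Sketch (assumptions noted above): print each captured node and its annotation dict, if any.
for label, graphs in (("forward", backend.fw_graphs), ("backward", backend.bw_graphs)):
    for gm in graphs:
        print(f"=== {label} graph ===")
        for node in gm.graph.nodes:
            print(node.op, node.name, node.meta.get("custom"))
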
@yushangdi
This is how I get the joint graph for this forward func:

# in pytorch/test/functorch/test_aot_joint_with_descriptors.py
    def test_return_tuple(self):
        class TupleModule(nn.Module):
            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    with fx_traceback.annotate({"fdsp_bucket": 0}):
                        sin = torch.sin(x)
                    sub = sin - 2
                with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                    mul = sub * 2
                    div = mul / 3
                return div, mul

        model = TupleModule()
        x = torch.randn(10, requires_grad=True)
        graph_module = graph_capture(model, (x,), False)
        custom_metadata = fx_traceback._get_custom_metadata(graph_module)

        graph_module.print_readable()

The graph I got is below. The annotations look correct this way, right? I'm not sure how different this is from your torch.compile AotEagerAndRecordGraphs backend. The add node doesn't have an annotation because it's summing gradients. Sometimes the gradient-accumulation nodes are annotated, sometimes not; unfortunately, for now there's no good way to control whether those gradient-summing nodes get annotated.

class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[10]"; tangents_1: "f32[10]"; tangents_2: "f32[10]"; 
    
        primals_1, tangents_1, tangents_2, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # Annotation: {'pp_stage': 0, 'fdsp_bucket': 0} No stacktrace found for following nodes
        sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)
        
        # Annotation: {'pp_stage': 0} No stacktrace found for following nodes
        sub: "f32[10]" = torch.ops.aten.sub.Tensor(sin, 2);  sin = None
        
        # Annotation: {'cuda_stream': 2, 'fsdp_bucket': 1} No stacktrace found for following nodes
        mul: "f32[10]" = torch.ops.aten.mul.Tensor(sub, 2);  sub = None
        div: "f32[10]" = torch.ops.aten.div.Tensor(mul, 3)
        div_1: "f32[10]" = torch.ops.aten.div.Tensor(tangents_1, 3);  tangents_1 = None
        
        # No stacktrace found for following nodes
        add: "f32[10]" = torch.ops.aten.add.Tensor(tangents_2, div_1);  tangents_2 = div_1 = None
        
        # Annotation: {'cuda_stream': 2, 'fsdp_bucket': 1} No stacktrace found for following nodes
        mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(add, 2);  add = None
        
        # Annotation: {'pp_stage': 0, 'fdsp_bucket': 0} No stacktrace found for following nodes
        cos: "f32[10]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None
        mul_2: "f32[10]" = torch.ops.aten.mul.Tensor(mul_1, cos);  mul_1 = cos = None
        return pytree.tree_unflatten([div, mul, mul_2], self._out_spec)
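
As a quick sanity check of which joint-graph nodes actually carry an annotation (for example the gradient-accumulation add above, which has none), here is a small sketch that walks the captured module, assuming the annotations end up under node.meta["custom"] as _get_custom_metadata suggests:

# Sketch: list each call_function node in the joint graph with its annotation (or None).
# Assumes fx_traceback.annotate stores its dict under node.meta["custom"].
for node in graph_module.graph.nodes:
    if node.op == "call_function":
        print(node.name, node.target, "->", node.meta.get("custom"))
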

@yushangdi
And if I run your gist with TORCH_LOGS="aot_graphs", I do see the same annotations:

P2023694204
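
If you don't have access to that paste, the same logging can be reproduced locally, either by setting TORCH_LOGS="aot_graphs" in the environment before running the gist or, roughly, programmatically (a sketch assuming torch._logging.set_logs accepts the aot_graphs flag):

# Sketch: enable AOT graph logging from Python instead of the TORCH_LOGS env var,
# then run the gist; the logged joint/forward/backward graphs include the annotations.
import torch._logging

torch._logging.set_logs(aot_graphs=True)
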
