Created November 5, 2025 00:29
import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs
import torch.fx.traceback as fx_traceback


def forward(x):
    # fx_traceback.annotate attaches custom metadata to the FX nodes traced
    # while the context is active; contexts can be nested.
    with fx_traceback.annotate({"pp_stage": 0}):
        with fx_traceback.annotate({"fsdp_bucket": 0}):
            sin = torch.sin(x)
        sub = sin - 2
        with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
            mul = sub * 2
    div = mul / 3
    return div, mul


backend = AotEagerAndRecordGraphs()
opt_m = torch.compile(forward, backend=backend, fullgraph=True)
x = torch.randn(10, requires_grad=True)
opt_m(x)[1].sum().backward()
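For quick inspection, the captured graphs can be read back off the backend object: AotEagerAndRecordGraphs keeps the compiled forward and backward graphs in its fw_graphs and bw_graphs lists. A minimal sketch, assuming the annotate metadata lands under node.meta["custom"] (the exact meta key can vary across PyTorch versions):

# Sketch: dump the annotations recorded on each node of the captured graphs.
# Assumption: node.meta["custom"] holds the fx_traceback.annotate metadata.
for label, graphs in (("forward", backend.fw_graphs), ("backward", backend.bw_graphs)):
    for gm in graphs:
        print(f"=== {label} graph ===")
        for node in gm.graph.nodes:
            print(node.name, node.meta.get("custom", {}))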
And if I run your gist with TORCH_LOGS="aot_graphs", I do see the same annotation: P2023694204
This is how I get the joint graph for this forward func:

The graph I got is below. Does the annotation look correct this way? I'm not sure how different this is from your torch.compile AotEagerAndRecordGraphs backend. The add node doesn't have an annotation because it's summing the gradients. Sometimes the "adding gradients" nodes are annotated, sometimes not; unfortunately, for now there's no good way to control whether those gradient-summing nodes get annotated.
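For context, here is a minimal sketch of one way to export a joint forward-backward graph directly and print each node's annotation. It assumes torch._functorch.aot_autograd.aot_export_module with trace_joint=True (which requires the traced output to be a scalar loss); this is not necessarily the method used above, and whether the annotate metadata survives this path may depend on the PyTorch version.

# Sketch: export a joint graph for a loss derived from the forward above.
# Assumptions: aot_export_module(trace_joint=True) wants a module returning a
# scalar loss, and the annotate metadata (if preserved) lands under
# node.meta["custom"].
import torch
import torch.fx.traceback as fx_traceback
from torch._functorch.aot_autograd import aot_export_module


class LossWrapper(torch.nn.Module):
    # Hypothetical wrapper: reduce "mul" to a scalar so it can serve as the loss.
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fsdp_bucket": 0}):
                sin = torch.sin(x)
            sub = sin - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                mul = sub * 2
        return mul.sum()


x = torch.randn(10, requires_grad=True)
# preserve_node_meta keeps traceback/annotation metadata on traced nodes; it may
# already be enabled internally depending on the PyTorch version.
with fx_traceback.preserve_node_meta():
    joint_gm, signature = aot_export_module(
        LossWrapper(), (x,), trace_joint=True, output_loss_index=0
    )

for node in joint_gm.graph.nodes:
    print(node.name, node.meta.get("custom", {}))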