@yushangdi
Last active November 3, 2025 23:52
augment cudagraph trace
import os
import json
from collections import defaultdict

import torch
import torch.distributed as dist
from cuda.bindings import runtime, driver


def check_errors(result, success, err_fn):
    # Unpack a cuda-python (err, *outputs) result tuple, raising if err is not the success code.
    err, *out = result
    if err != success:
        raise RuntimeError(err_fn(err))
    if len(out) == 0:
        return None
    elif len(out) == 1:
        return out[0]
    else:
        return out


def check_cuda(result):
    return check_errors(result, runtime.cudaError_t.cudaSuccess, runtime.cudaGetErrorString)


def check_cuda_driver(result):
    return check_errors(result, driver.CUresult.CUDA_SUCCESS, driver.cuGetErrorString)


def fn(x, group, rank):
    # Each op is wrapped in _RecordFunctionFast so the profiler trace carries a per-line tag.
    with torch._C._profiler._RecordFunctionFast("line_32_all_reduce"):
        dist.all_reduce(x, group=group)
    with torch._C._profiler._RecordFunctionFast("line_33_multiply"):
        y = 2 * x
    with torch._C._profiler._RecordFunctionFast("line_34_add_sum"):
        y = y + y.sum()
    with torch._C._profiler._RecordFunctionFast("line_35_matmul"):
        z = y @ x
    with torch._C._profiler._RecordFunctionFast("line_36_to_bfloat16"):
        z = z.to(torch.bfloat16)
    src = 0  # rank // 2 * 2
    with torch._C._profiler._RecordFunctionFast("line_38_broadcast"):
        dist.broadcast(z, group_src=src, group=group)
    with torch._C._profiler._RecordFunctionFast("line_39_to_float32"):
        result = z.to(torch.float32)
    return result


def repeat_fn(x, group, rank, repeats):
    for _ in range(repeats):
        x = fn(x, group, rank)
    return x


local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group(
    backend="nccl",
    rank=rank,
    world_size=world_size,
    # store=dist.FileStore(FILE_PATH, world_size),
)
my_group = dist.group.WORLD
print("CREATED GROUPS")

torch.cuda.memory._set_allocator_settings("expandable_segments:True")

x0 = torch.randn(1024, 1024, device="cuda")
x1 = torch.randn(1024, 1024, device="cuda")
s0 = torch.cuda.Stream()
s1 = torch.cuda.Stream()

repeat_fn(x0, my_group, rank, 4)
torch.cuda.synchronize()
print("RAN WARMUP")

# Profile without CUDA graphs to build kernel-to-tag mapping
print("Profiling without CUDA graphs...")
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
) as p_no_graph:
    repeat_fn(x0, my_group, rank, 1)
p_no_graph.export_chrome_trace(f"trace_no_graph{rank}.json")

# Build kernel-to-tag mapping by parsing the chrome trace JSON
print("Building kernel-to-tag mapping from chrome trace...")
with open(f"trace_no_graph{rank}.json", "r") as f:
    trace_data = json.load(f)

kernel_to_tag = {}

# Extract events from the trace
events = trace_data.get("traceEvents", [])

# First pass: collect the time ranges of the "line_*" record-function events per CPU thread
line_events = {}  # (pid, tid) -> list of {name, ts, dur, ts_end}
for event in events:
    if event.get("ph") == "X" and event.get("name", "").startswith("line_"):
        pid = event.get("pid")
        tid = event.get("tid")
        key = (pid, tid)
        if key not in line_events:
            line_events[key] = []
        line_events[key].append({
            'name': event["name"],
            'ts': event["ts"],
            'dur': event["dur"],
            'ts_end': event["ts"] + event["dur"],
        })

# Second pass: map correlation IDs to line tags via cudaLaunchKernel
correlation_to_line = {}
for event in events:
    if event.get("name") == "cudaLaunchKernel" and event.get("ph") == "X":
        correlation = event.get("args", {}).get("correlation")
        if correlation is not None:
            pid = event.get("pid")
            tid = event.get("tid")
            ts = event.get("ts")
            key = (pid, tid)
            # Find which line event contains this launch
            if key in line_events:
                for line_event in line_events[key]:
                    if line_event['ts'] <= ts <= line_event['ts_end']:
                        if correlation not in correlation_to_line:
                            correlation_to_line[correlation] = []
                        if line_event['name'] not in correlation_to_line[correlation]:
                            correlation_to_line[correlation].append(line_event['name'])
                        break

# Third pass: map CUDA kernels to line tags using correlation
for event in events:
    if event.get("cat") == "kernel" and event.get("ph") == "X":
        kernel_name = event.get("name", "")
        correlation = event.get("args", {}).get("correlation")
        if correlation is not None and correlation in correlation_to_line:
            if kernel_name not in kernel_to_tag:
                kernel_to_tag[kernel_name] = []
            for line_tag in correlation_to_line[correlation]:
                if line_tag not in kernel_to_tag[kernel_name]:
                    kernel_to_tag[kernel_name].append(line_tag)

print(f"Built kernel-to-tag mapping with {len(kernel_to_tag)} kernels")
for kernel, tags in sorted(kernel_to_tag.items())[:5]:  # Print first 5
    print(f"  {kernel[:80]}... -> {tags}")

# Capture the CUDA graph (keep_graph=True so the raw cudaGraph_t stays queryable before instantiate())
g = torch.cuda.CUDAGraph(keep_graph=True)
g.enable_debug_mode()
with torch.cuda.graph(g):
    repeat_fn(x0, my_group, rank, 4)

# Query the raw graph's edges: the first call returns the count, the second fetches the edge list
raw_pointer = g.raw_cuda_graph()
cudart_cuda_graph = runtime.cudaGraph_t(init_value=raw_pointer)
_, _, num_edges = check_cuda(
    runtime.cudaGraphGetEdges(cudart_cuda_graph, numEdges=0))
fr, to, num_edges = check_cuda(
    runtime.cudaGraphGetEdges(cudart_cuda_graph, numEdges=num_edges))
edges = list(zip(fr, to))
print("EDGES", len(edges))

_, _, num_edges = check_cuda(
    runtime.cudaGraphGetEdges(cudart_cuda_graph, numEdges=0))
print("New num edges", num_edges)

g.instantiate()
# g.debug_dump(f"graph{rank}.png")

print("Profiling CUDA graph replay...")
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
) as p_graph:
    g.replay()
p_graph.export_chrome_trace(f"trace_graph{rank}.json")

# Augment the graph trace with source line information
print("\nAugmenting trace_graph with source line information...")
with open(f"trace_graph{rank}.json", "r") as f:
    graph_trace_data = json.load(f)

graph_events = graph_trace_data.get("traceEvents", [])
kernels_augmented = 0
kernels_total = 0
for event in graph_events:
    if event.get("cat") == "kernel" and event.get("ph") == "X":
        kernels_total += 1
        kernel_name = event.get("name", "")
        if kernel_name in kernel_to_tag:
            # Add source line information to the args
            if "args" not in event:
                event["args"] = {}
            event["args"]["source_lines"] = ", ".join(kernel_to_tag[kernel_name])
            kernels_augmented += 1

# Write the augmented trace back
with open(f"trace_graph_augmented{rank}.json", "w") as f:
    json.dump(graph_trace_data, f)

print(f"Augmented {kernels_augmented}/{kernels_total} kernels with source line info")
print(f"Augmented trace saved to: trace_graph_augmented{rank}.json")

yushangdi commented Nov 3, 2025

1. Mark each line with _RecordFunctionFast (if we have the aten graph, we can also just map to each aten op directly without this).

2. Map each cudaLaunchKernel to its aten op / record-function tag via the profiler correlation IDs.

3. Post-process the replay() profiler result JSON to attach the tag to each kernel event (see the inspection sketch below).
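
A rough usage sketch (not part of the gist): launch the script with torchrun, e.g. torchrun --nproc-per-node=2 on the saved file, then load the augmented replay trace it writes and check the attached tags. The file name below assumes rank 0 and the naming used in the script above.

import json

# Load the augmented CUDA graph replay trace written by the script (rank 0 here).
with open("trace_graph_augmented0.json") as f:
    trace = json.load(f)

# Print each kernel event together with the source-line tags attached in post-processing.
for event in trace.get("traceEvents", []):
    if event.get("cat") == "kernel" and event.get("ph") == "X":
        tags = event.get("args", {}).get("source_lines", "<untagged>")
        print(f"{event.get('name', '')[:60]:60s} -> {tags}")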
