Augment CUDA graph trace
import os
import json
from collections import defaultdict

import torch
import torch.distributed as dist
from cuda.bindings import runtime, driver


def check_errors(result, success, err_fn):
    err, *out = result
    if err != success:
        raise RuntimeError(err_fn(err))
    if len(out) == 0:
        return None
    elif len(out) == 1:
        return out[0]
    else:
        return out


def check_cuda(result):
    return check_errors(result, runtime.cudaError_t.cudaSuccess, runtime.cudaGetErrorString)


def check_cuda_driver(result):
    return check_errors(result, driver.CUresult.CUDA_SUCCESS, driver.cuGetErrorString)
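# Workload under test: each op is wrapped in a _RecordFunctionFast range so the
# profiler emits a named "line_*" CPU event around every kernel launch.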
def fn(x, group, rank):
    with torch._C._profiler._RecordFunctionFast("line_32_all_reduce"):
        dist.all_reduce(x, group=group)
    with torch._C._profiler._RecordFunctionFast("line_33_multiply"):
        y = 2 * x
    with torch._C._profiler._RecordFunctionFast("line_34_add_sum"):
        y = y + y.sum()
    with torch._C._profiler._RecordFunctionFast("line_35_matmul"):
        z = y @ x
    with torch._C._profiler._RecordFunctionFast("line_36_to_bfloat16"):
        z = z.to(torch.bfloat16)
    src = 0  # rank // 2 * 2
    with torch._C._profiler._RecordFunctionFast("line_38_broadcast"):
        dist.broadcast(z, group_src=src, group=group)
    with torch._C._profiler._RecordFunctionFast("line_39_to_float32"):
        result = z.to(torch.float32)
    return result


def repeat_fn(x, group, rank, repeats):
    for _ in range(repeats):
        x = fn(x, group, rank)
    return x
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group(
    backend="nccl",
    rank=rank,
    world_size=world_size,
    # store=dist.FileStore(FILE_PATH, world_size),
)
my_group = dist.group.WORLD
print("CREATED GROUPS")

torch.cuda.memory._set_allocator_settings("expandable_segments:True")

x0 = torch.randn(1024, 1024, device="cuda")
x1 = torch.randn(1024, 1024, device="cuda")
s0 = torch.cuda.Stream()
s1 = torch.cuda.Stream()

repeat_fn(x0, my_group, rank, 4)
torch.cuda.synchronize()
print("RAN WARMUP")
# Profile without CUDA graphs to build kernel-to-tag mapping
print("Profiling without CUDA graphs...")
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
) as p_no_graph:
    repeat_fn(x0, my_group, rank, 1)
p_no_graph.export_chrome_trace(f"trace_no_graph{rank}.json")

# Build kernel-to-tag mapping by parsing the chrome trace JSON
print("Building kernel-to-tag mapping from chrome trace...")
with open(f"trace_no_graph{rank}.json", "r") as f:
    trace_data = json.load(f)

kernel_to_tag = {}

# Extract events from the trace
events = trace_data.get("traceEvents", [])

# First pass: build mapping of time ranges to line tags (CPU thread)
line_events = {}  # (pid, tid) -> list of {name, ts, dur, ts_end}
for event in events:
    if event.get("ph") == "X" and event.get("name", "").startswith("line_"):
        pid = event.get("pid")
        tid = event.get("tid")
        key = (pid, tid)
        if key not in line_events:
            line_events[key] = []
        line_events[key].append({
            'name': event["name"],
            'ts': event["ts"],
            'dur': event["dur"],
            'ts_end': event["ts"] + event["dur"]
        })
# Second pass: map correlation IDs to line tags via cudaLaunchKernel
correlation_to_line = {}
for event in events:
    if event.get("name") == "cudaLaunchKernel" and event.get("ph") == "X":
        correlation = event.get("args", {}).get("correlation")
        if correlation is not None:
            pid = event.get("pid")
            tid = event.get("tid")
            ts = event.get("ts")
            key = (pid, tid)
            # Find which line event contains this launch
            if key in line_events:
                for line_event in line_events[key]:
                    if line_event['ts'] <= ts <= line_event['ts_end']:
                        if correlation not in correlation_to_line:
                            correlation_to_line[correlation] = []
                        if line_event['name'] not in correlation_to_line[correlation]:
                            correlation_to_line[correlation].append(line_event['name'])
                        break

# Third pass: map CUDA kernels to line tags using correlation
for event in events:
    if event.get("cat") == "kernel" and event.get("ph") == "X":
        kernel_name = event.get("name", "")
        correlation = event.get("args", {}).get("correlation")
        if correlation is not None and correlation in correlation_to_line:
            if kernel_name not in kernel_to_tag:
                kernel_to_tag[kernel_name] = []
            for line_tag in correlation_to_line[correlation]:
                if line_tag not in kernel_to_tag[kernel_name]:
                    kernel_to_tag[kernel_name].append(line_tag)

print(f"Built kernel-to-tag mapping with {len(kernel_to_tag)} kernels")
for kernel, tags in sorted(kernel_to_tag.items())[:5]:  # Print first 5
    print(f"  {kernel[:80]}... -> {tags}")
g = torch.cuda.CUDAGraph(keep_graph=True)
g.enable_debug_mode()
with torch.cuda.graph(g):
    repeat_fn(x0, my_group, rank, 4)

raw_pointer = g.raw_cuda_graph()
cudart_cuda_graph = runtime.cudaGraph_t(init_value=raw_pointer)

_, _, num_edges = check_cuda(
    runtime.cudaGraphGetEdges(cudart_cuda_graph, numEdges=0))
fr, to, num_edges = check_cuda(
    runtime.cudaGraphGetEdges(cudart_cuda_graph, numEdges=num_edges))
edges = list(zip(fr, to))
print("EDGES", len(edges))

_, _, num_edges = check_cuda(
    runtime.cudaGraphGetEdges(cudart_cuda_graph, numEdges=0))
print("New num edges", num_edges)

g.instantiate()
# g.debug_dump(f"graph{rank}.png")
| print("Profiling CUDA graph replay...") | |
| with torch.profiler.profile( | |
| activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], | |
| record_shapes=True, | |
| ) as p_graph: | |
| g.replay() | |
| p_graph.export_chrome_trace(f"trace_graph{rank}.json") | |
| # Augment the graph trace with source line information | |
| print("\nAugmenting trace_graph with source line information...") | |
| with open(f"trace_graph{rank}.json", "r") as f: | |
| graph_trace_data = json.load(f) | |
| graph_events = graph_trace_data.get("traceEvents", []) | |
| kernels_augmented = 0 | |
| kernels_total = 0 | |
| for event in graph_events: | |
| if event.get("cat") == "kernel" and event.get("ph") == "X": | |
| kernels_total += 1 | |
| kernel_name = event.get("name", "") | |
| if kernel_name in kernel_to_tag: | |
| # Add source line information to the args | |
| if "args" not in event: | |
| event["args"] = {} | |
| event["args"]["source_lines"] = ", ".join(kernel_to_tag[kernel_name]) | |
| kernels_augmented += 1 | |
| # Write the augmented trace back | |
| with open(f"trace_graph_augmented{rank}.json", "w") as f: | |
| json.dump(graph_trace_data, f) | |
| print(f"Augmented {kernels_augmented}/{kernels_total} kernels with source line info") | |
| print(f"Augmented trace saved to: trace_graph_augmented{rank}.json") |
The approach:

1. Mark each line of the workload with _RecordFunctionFast (if we have the aten graph, we can also map each kernel to its aten op directly without this).
2. Map each cudaLaunchKernel to an aten op / record-function tag via its correlation ID.
3. Post-process the replay() profiler result JSON to attach the tag to each kernel event.
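For the aten-op variant mentioned in step 1, a minimal sketch could build the mapping straight from the Kineto trace, assuming aten ops appear as complete ("X") events with cat == "cpu_op" on the same thread as the cudaLaunchKernel calls. The helper name map_launches_to_aten_ops and the innermost-containment heuristic are illustrative, not part of the gist:

import json


def map_launches_to_aten_ops(trace_path):
    # Map each cudaLaunchKernel correlation ID to the innermost enclosing aten op,
    # instead of a manually inserted _RecordFunctionFast tag.
    with open(trace_path) as f:
        events = json.load(f).get("traceEvents", [])

    # Group aten op events by (pid, tid) so containment checks stay per-thread.
    aten_ops = {}
    for ev in events:
        if ev.get("ph") == "X" and ev.get("cat") == "cpu_op" and ev.get("name", "").startswith("aten::"):
            aten_ops.setdefault((ev["pid"], ev["tid"]), []).append(ev)

    correlation_to_op = {}
    for ev in events:
        if ev.get("ph") == "X" and ev.get("name") == "cudaLaunchKernel":
            corr = ev.get("args", {}).get("correlation")
            if corr is None:
                continue
            # Candidate aten ops whose time range contains this launch.
            candidates = [
                op for op in aten_ops.get((ev["pid"], ev["tid"]), [])
                if op["ts"] <= ev["ts"] <= op["ts"] + op["dur"]
            ]
            if candidates:
                # The innermost op is the enclosing op with the smallest duration.
                correlation_to_op[corr] = min(candidates, key=lambda op: op["dur"])["name"]
    return correlation_to_op

The resulting correlation-to-op dict would slot in where correlation_to_line is built in the script above; the rest of the post-processing of the replay trace stays the same.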