Created
July 23, 2025 19:31
-
-
Save knwng/ab85bc87555683ad4597744a6329febe to your computer and use it in GitHub Desktop.
A script that parses the performance data generated by OpenAI's open-source MoE Triton kernels and draws the rooflines of all runs in a single figure. Usage: python3 parse_moe_kernel_perf.py -m log_metadata.json -o output.png -d fp8mx4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import triton.profiler as proton | |
| from triton.profiler import viewer | |
| import torch | |
| from triton_kernels.target_info import get_cdna_version | |
| from dataclasses import dataclass | |
| import argparse | |
| import json | |
@dataclass
class PerfData:
    """Aggregated measurements of one profiled run plus the device description.

    The derived properties turn the raw totals into achieved throughput and
    compare them against the device limits reported by ``proton.specs``.
    """

    # Total kernel time; ``bench_mlp`` fills this from the "time (ns)" column.
    time: float
    # Total floating-point operations counted by the profiler.
    flops: float
    # Total bytes moved, as counted by the profiler.
    bytes: float
    # Bit width of the activation dtype; forwarded to ``proton.specs.max_flops``.
    bitwidth: int
    # Device type string and its info dict, both forwarded to ``proton.specs``.
    device_type: str
    device_info: dict

    @property
    def tflops(self):
        """Achieved compute throughput.

        With ``time`` in nanoseconds, ``flops / time`` is GFLOP/s, so the
        1e-3 factor yields TFLOP/s.
        """
        per_time_unit = self.flops / self.time
        return per_time_unit * 1e-3

    @property
    def tbps(self):
        """Achieved memory throughput (TB/s when ``time`` is in nanoseconds)."""
        per_time_unit = self.bytes / self.time
        return per_time_unit * 1e-3

    @property
    def opint(self):
        """Operational intensity: FLOPs performed per byte moved."""
        assert self.bytes > 0
        moved = self.bytes
        return self.flops / moved

    @property
    def max_tbps(self):
        """Peak memory bandwidth from ``proton.specs.max_bps``, scaled by 1e-12."""
        info = self.device_info
        peak_bps = proton.specs.max_bps(
            self.device_type,
            info["arch"],
            info["bus_width"],
            info["memory_clock_rate"],
        )
        return peak_bps * 1e-12

    @property
    def max_tflops(self):
        """Peak compute rate from ``proton.specs.max_flops``, scaled by 1e-12."""
        info = self.device_info
        peak_flops = proton.specs.max_flops(
            self.device_type,
            info["arch"],
            self.bitwidth,
            info["num_sms"],
            info["clock_rate"],
        )
        return peak_flops * 1e-12

    @property
    def util(self) -> float:
        """Ratio of the roofline-minimum runtime to the measured runtime.

        The minimum runtime is the larger of the compute-limited and the
        bandwidth-limited lower bounds; both are expressed in the same unit
        as ``time`` through the 1e-3 factor.
        """
        assert self.bitwidth in (8, 16)
        compute_bound_time = self.flops / self.max_tflops * 1e-3
        bandwidth_bound_time = self.bytes / self.max_tbps * 1e-3
        best_possible_time = max(compute_bound_time, bandwidth_bound_time)
        return best_possible_time / self.time
def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name, log_root):
    """Read one recorded profile and summarise its matmul kernels.

    Despite the name, nothing is executed here: the function only loads the
    ``.hatchet`` profile stored under ``log_root`` and aggregates it.

    Args:
        batch: Batch size; selects the ``batch-{batch}.hatchet`` file.
        dim1: Unused; kept for signature compatibility.
        dim2: Only checked for divisibility by ``TP``.
        n_expts_tot: Total number of experts; must be divisible by ``EP``.
        n_expts_act: Unused; kept for signature compatibility.
        x_dtype: Activation dtype name, one of ``"fp16"``, ``"bf16"``, ``"fp8"``.
        w_dtype: Weight dtype name; only used to build the profile path.
        TP: Tensor-parallel degree (part of the profile path).
        EP: Expert-parallel degree (part of the profile path).
        name: Benchmark name (part of the profile path).
        log_root: Root directory holding the recorded profiles.

    Returns:
        A ``PerfData`` with the summed time, FLOPs and bytes of all matmul
        leaf kernels and the info of the device that ran the first of them.

    Raises:
        AssertionError: If ``n_expts_tot % EP`` or ``dim2 % TP`` is non-zero.
        KeyError: If ``x_dtype`` is not one of the supported names.
        ValueError: If the profile contains no matmul kernels.
    """
    assert n_expts_tot % EP == 0
    assert dim2 % TP == 0
    # The directory layout uses the dtype *names*, so build the path before
    # x_dtype is replaced by the torch dtype object below.
    fpath = Path(f"{log_root}/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet")
    x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
    # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
    if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
        x_dtype = torch.float8_e4m3fnuz
    # -- analyze --
    gf, _, _, info = viewer.read(fpath)
    # Keep only leaf nodes (i.e., kernels) whose name marks them as matmuls.
    matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
    if len(matmuls) == 0:
        # Without this check the failure would surface as an opaque
        # IndexError from ``.iloc[0]`` further down.
        raise ValueError(f"no matmul kernels found in profile {fpath}")
    total_bytes = matmuls["bytes"].sum()
    # A profile may carry 8-bit FLOPs, 16-bit FLOPs, or both; add up whichever exist.
    flop_columns = [c for c in ["flops8", "flops16"] if c in matmuls.columns]
    total_flops = sum(matmuls[flop_columns].sum())
    total_time = matmuls["time (ns)"].sum()
    # The first row's device is taken as the device for the whole run.
    device_type = matmuls["device_type"].iloc[0]
    device_id = matmuls["device_id"].iloc[0]
    device_info = info[device_type][device_id]
    return PerfData(time=total_time, flops=total_flops, bytes=total_bytes, bitwidth=x_dtype.itemsize * 8,
                    device_type=device_type, device_info=device_info)
def plot_roofline(ax, xs, perfs, name):
    """Draw the bandwidth-bound and compute-bound roofline segments on ``ax``.

    Args:
        ax: Object with a matplotlib-style ``plot(x, y, fmt, label=...)`` method.
        xs: X coordinate of every measurement, parallel to ``perfs``.
        perfs: Measurements exposing ``opint``, ``max_tbps`` and ``max_tflops``.
            The device limits are read from the first entry. The operational
            intensities are assumed to be sorted ascending (``bisect`` needs
            that) — NOTE(review): confirm against the caller.
        name: Prefix used in the two legend labels.

    The bandwidth segment runs from the first point up to the last point that
    is still below the ridge ``max_tflops / max_tbps``; the compute segment is
    flat at ``max_tflops`` from that point onwards. When no point lies below
    the ridge only the compute segment is drawn, and nothing is drawn for an
    empty ``perfs``.
    """
    from bisect import bisect_left

    if not perfs:
        return
    max_tbps = perfs[0].max_tbps
    max_tflops = perfs[0].max_tflops
    opints = [p.opint for p in perfs]
    # Index of the first measurement at or beyond the ridge point.
    first_compute_bound = bisect_left(opints, max_tflops / max_tbps)
    # The knee sits on the last bandwidth-bound measurement. Clamp at 0: a
    # negative index would silently slice from the end, and slicing an empty
    # prefix would make the endpoint lookup fail.
    knee = max(first_compute_bound - 1, 0)
    x_comp = xs[knee:]
    y_comp = [max_tflops] * len(x_comp)
    if first_compute_bound > 0:
        # At least one measurement is bandwidth-bound: draw the sloped part.
        x_bw = [xs[0], xs[knee]]
        y_bw = [opints[0] * max_tbps, max_tflops]
        ax.plot(x_bw, y_bw, "--", label=f"{name} BW-bound ({max_tbps:.1f} TB/s)")
    ax.plot(x_comp, y_comp, "--", label=f"{name} Compute-bound ({max_tflops:.0f} TFLOP/s)")
# One entry per log set: (display name, log root directory, draw-roofline flag).
# When loaded from JSON each entry arrives as a three-element list.
LOG_METADATA_TYPE = list[tuple[str, str, bool]]


def roofline_mlp(log_metadata: LOG_METADATA_TYPE, output_name: str, batch_ranges, dim1, dim2, n_expts_tot,
                 n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="", verbose=True):
    """Plot measured TFLOP/s of several log sets, with rooflines, into one image.

    Args:
        log_metadata: Entries of ``(log_name, log_dir, draw_roofline)``. Every
            entry is scattered; a roofline is drawn for those whose flag is truthy.
        output_name: Path of the image file to write.
        batch_ranges: Iterable of ``range`` argument tuples; their values are
            concatenated to form the batch sizes to load.
        dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name:
            Forwarded to ``bench_mlp``. ``n_expts_act / n_expts_tot`` also
            scales the batch size into the x coordinate.
        verbose: Print utilisation and throughput for every loaded profile.

    Raises:
        ValueError: If ``log_metadata`` is empty or ``batch_ranges`` yields no
            batch size, since there would be nothing to plot.
    """
    from itertools import chain

    batches = list(chain.from_iterable(range(*r) for r in batch_ranges))
    if not log_metadata or not batches:
        raise ValueError("roofline_mlp needs at least one log entry and one batch size")
    # collect performance data: one list of measurements per log set
    perfs = [[] for _ in log_metadata]
    bench_case = f"{name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})"
    print(f"Benchmarking {bench_case}...")
    print("===============================================================")
    for batch in batches:
        for i, (log_name, log_dir, _) in enumerate(log_metadata):
            latest = bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name, log_dir)
            perfs[i].append(latest)
            if verbose:
                print(f"[{log_name}]Batch: {batch}; Util: {latest.util}; TFLOPS: {latest.tflops}; TBPS: {latest.tbps}")
    print("===============================================================")
    # machine limit, taken from the first measurement of the first log set
    max_tflops = perfs[0][0].max_tflops
    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
    ax.set_xlabel("batch size (toks/expt)")
    ax.set_ylabel("performance [TFLOP/s]")
    ax.set_title(f"{bench_case} roofline")
    xs = [batch * n_expts_act / n_expts_tot for batch in batches]
    perf = [[p.tflops for p in series] for series in perfs]
    # add a tiny margin so points are not flush with the frame
    xmin, xmax = min(xs), max(xs)
    dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
    ax.set_xlim(xmin - dx, xmax + dx)
    ax.set_ylim(100, max_tflops + 500)
    # With a single roofline the legend label does not need the log name.
    only_one_roofline = sum(entry[2] for entry in log_metadata) == 1
    for i, (log_name, _, draw_roofline) in enumerate(log_metadata):
        # plot roofline
        if draw_roofline:
            plot_roofline(ax, xs, perfs[i], '' if only_one_roofline else log_name)
        # plot data
        ax.scatter(xs, perf[i], marker='+', label=log_name)
    ax.legend(frameon=False, loc="lower right")
    ax.grid(True, which="both", ls=":", lw=0.5)
    fig.tight_layout()
    fpath = Path(output_name)
    # Save through the figure object instead of pyplot's implicit current figure.
    fig.savefig(fpath)
if __name__ == "__main__":
    # Command-line entry point: load the log metadata and draw the roofline
    # figure for the llama4-maverick MoE configuration.
    parser = argparse.ArgumentParser(
        description="Parse recorded MoE kernel profiles and draw their rooflines in one figure.")
    # Required: without it the script would fail later with open(None).
    parser.add_argument('--log-metadata', '-m', type=str, required=True,
                        help="JSON file listing [name, log_dir, draw_roofline] entries")
    parser.add_argument('--output', '-o', type=str, default='perf.png', help="path of the image to write")
    parser.add_argument('--dtype', '-d', type=str, choices=['fp8fp8', 'fp8mx4'], default='fp8fp8',
                        help="activation/weight dtype combination of the recorded run")
    args = parser.parse_args()

    # Queries the local GPU: compute capability major >= 10, or AMD CDNA4.
    has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
    batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
    dense_dtypes = ["fp8", "fp8"]
    # Without native mx4 support the activations fall back to bf16.
    quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]

    with open(args.log_metadata, 'r', encoding='utf-8') as f:
        log_metadata = json.load(f)

    # argparse's ``choices`` guarantees that args.dtype is one of these keys.
    dtypes_by_mode = {'fp8fp8': dense_dtypes, 'fp8mx4': quantized_dtypes}
    x_dtype, w_dtype = dtypes_by_mode[args.dtype]
    roofline_mlp(log_metadata, args.output, batch_ranges_moe, 5120, 8192, 128, 4, x_dtype, w_dtype, TP=1, EP=1,
                 name="llama4-maverick")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment