Skip to content

Instantly share code, notes, and snippets.

@knwng
Created July 23, 2025 19:31
Show Gist options
  • Select an option

  • Save knwng/ab85bc87555683ad4597744a6329febe to your computer and use it in GitHub Desktop.

Select an option

Save knwng/ab85bc87555683ad4597744a6329febe to your computer and use it in GitHub Desktop.
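A script that parses the profiling data (proton `.hatchet` files) generated for OpenAI's open-source MoE Triton kernels and draws the roofline plots of several runs in a single graph. The metadata file passed with `-m` is a JSON list of `[name, log_dir, draw_roofline]` entries. Usage: `python3 parse_moe_kernel_perf.py -m log_metadata.json -o output.png -d fp8mx4`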
from pathlib import Path
import matplotlib.pyplot as plt
import triton.profiler as proton
from triton.profiler import viewer
import torch
from triton_kernels.target_info import get_cdna_version
from dataclasses import dataclass
import argparse
import json
@dataclass
class PerfData:
    """Aggregated measurements of one benchmark case plus roofline helpers.

    ``time`` is expected in nanoseconds (``bench_mlp`` fills it from the
    ``"time (ns)"`` profile column). Consequently ``flops / time`` is in
    GFLOP/s and ``bytes / time`` in GB/s, and the ``1e-3`` factors below
    convert those to TFLOP/s and TB/s.
    """

    time: float  # total kernel time (nanoseconds, as supplied by bench_mlp)
    flops: float  # total floating-point operations
    bytes: float  # total bytes read/written by the kernels
    bitwidth: int  # operand bit width used to look up the peak FLOP rate
    device_type: str  # device-type key passed to proton.specs
    device_info: dict  # needs "arch", "bus_width", "memory_clock_rate", "num_sms", "clock_rate"

    @property
    def tflops(self):
        """Achieved compute throughput in TFLOP/s."""
        flops_per_ns = self.flops / self.time
        return flops_per_ns * 1e-3

    @property
    def tbps(self):
        """Achieved memory throughput in TB/s."""
        bytes_per_ns = self.bytes / self.time
        return bytes_per_ns * 1e-3

    @property
    def opint(self):
        """Operational intensity: FLOPs performed per byte moved."""
        assert self.bytes > 0
        return self.flops / self.bytes

    @property
    def max_tbps(self):
        """Peak memory bandwidth of the device in TB/s, per proton's spec table."""
        info = self.device_info
        peak_bps = proton.specs.max_bps(self.device_type, info["arch"], info["bus_width"],
                                        info["memory_clock_rate"])
        return peak_bps * 1e-12

    @property
    def max_tflops(self):
        """Peak compute rate of the device in TFLOP/s for ``bitwidth``-bit operands."""
        info = self.device_info
        peak_flops = proton.specs.max_flops(self.device_type, info["arch"], self.bitwidth,
                                            info["num_sms"], info["clock_rate"])
        return peak_flops * 1e-12

    @property
    def util(self) -> float:
        """Fraction of the roofline bound achieved: best-case time over measured time.

        Only 8- and 16-bit operands are supported.
        """
        assert self.bitwidth in (8, 16)
        # Lower-bound runtimes (ns) if limited purely by compute or purely by bandwidth.
        compute_floor_ns = self.flops / self.max_tflops * 1e-3
        memory_floor_ns = self.bytes / self.max_tbps * 1e-3
        return max(compute_floor_ns, memory_floor_ns) / self.time
def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name, log_root):
    """Load a saved proton profile for one MLP case and aggregate its matmul kernels.

    Nothing is benchmarked here despite the name: the profile is read from
    ``{log_root}/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet``.

    Args:
        batch: batch size used in the profile file name.
        dim1, n_expts_act: not used here; kept for signature compatibility.
        dim2: must be divisible by ``TP``.
        n_expts_tot: total number of experts; must be divisible by ``EP``.
        x_dtype: activation dtype string, one of "fp16", "bf16", "fp8".
        w_dtype: weight dtype string (only used in the profile path).
        TP, EP: tensor / expert parallel degrees (used in the path and checks).
        name: benchmark case name (path component).
        log_root: root directory of the profile logs.

    Returns:
        PerfData with summed time/flops/bytes of all matmul leaf kernels.

    Raises:
        AssertionError: if ``n_expts_tot % EP`` or ``dim2 % TP`` is non-zero.
        KeyError: if ``x_dtype`` is not a supported dtype string.
        ValueError: if the profile contains no matmul leaf kernels.
    """
    assert n_expts_tot % EP == 0
    assert dim2 % TP == 0
    # The path uses the dtype *strings*, so build it before x_dtype is rebound to a torch dtype.
    fpath = Path(f"{log_root}/{name}/{x_dtype}-{w_dtype}-TP{TP}-EP{EP}/profiles/batch-{batch}.hatchet")
    x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
    # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
    if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
        x_dtype = torch.float8_e4m3fnuz

    gf, _, _, info = viewer.read(fpath)
    # Keep only leaf nodes (i.e., kernels) whose name contains "matmul".
    matmuls = gf.filter("MATCH ('*', c) WHERE c.'name' =~ '.*matmul.*' AND c IS LEAF").dataframe
    if matmuls.empty:
        # Without this check the .iloc[0] lookups below fail with an opaque IndexError.
        raise ValueError(f"no matmul kernels found in profile {fpath}")

    total_bytes = matmuls["bytes"].sum()
    flop_cols = [c for c in ("flops8", "flops16") if c in matmuls.columns]
    total_flops = sum(matmuls[flop_cols].sum())
    total_time = matmuls["time (ns)"].sum()
    # Device identity is taken from the first matmul row.
    device_type = matmuls["device_type"].iloc[0]
    device_id = matmuls["device_id"].iloc[0]
    device_info = info[device_type][device_id]
    return PerfData(time=total_time, flops=total_flops, bytes=total_bytes, bitwidth=x_dtype.itemsize * 8,
                    device_type=device_type, device_info=device_info)
def plot_roofline(ax, xs, perfs, name):
    """Draw the roofline ceiling for one series of measurements onto ``ax``.

    The attainable throughput at each measured point is
    ``min(opint * max_tbps, max_tflops)``. Points whose attainable value is
    below the compute ceiling form the dashed bandwidth-bound segment, which is
    extended to the first compute-bound point so the two segments meet. The
    remaining points form the flat compute-bound segment.

    Machine limits are read from ``perfs[0]``, so all entries are assumed to
    share the same device and bit width. The split assumes operational
    intensity increases along ``xs`` (TODO: confirm for every caller).

    Args:
        ax: axes object with a ``plot(x, y, fmt, label=...)`` method.
        xs: x coordinate for each entry of ``perfs``.
        perfs: objects exposing ``opint``, ``max_tbps`` and ``max_tflops``.
        name: prefix for the legend labels; empty string for no prefix.
    """
    if not perfs:
        return
    max_tbps = perfs[0].max_tbps
    max_tflops = perfs[0].max_tflops
    attainable = [min(p.opint * max_tbps, max_tflops) for p in perfs]
    bw_idx = [i for i, y in enumerate(attainable) if y < max_tflops]
    comp_idx = [i for i, y in enumerate(attainable) if y >= max_tflops]
    prefix = f"{name} " if name else ""

    if bw_idx:
        x_bw = [xs[i] for i in bw_idx]
        y_bw = [attainable[i] for i in bw_idx]
        if comp_idx:
            # Join the bandwidth segment to the compute ceiling.
            x_bw.append(xs[comp_idx[0]])
            y_bw.append(max_tflops)
        ax.plot(x_bw, y_bw, "--", label=f"{prefix}BW-bound ({max_tbps:.1f} TB/s)")
    if comp_idx:
        x_comp = [xs[i] for i in comp_idx]
        ax.plot(x_comp, [max_tflops] * len(x_comp), "--",
                label=f"{prefix}Compute-bound ({max_tflops:.0f} TFLOP/s)")
# One entry per profiled run: (name, log_dir, draw_roofline). When loaded from JSON
# each entry arrives as a 3-element list with the same layout.
LOG_METADATA_TYPE = list[tuple[str, str, bool]]
def roofline_mlp(log_metadata: LOG_METADATA_TYPE, output_name: str, batch_ranges, dim1, dim2, n_expts_tot,
                 n_expts_act, x_dtype, w_dtype, TP=1, EP=1, name="", verbose=True):
    """Load perf data for every (batch, run) pair and save an overlaid roofline plot.

    Args:
        log_metadata: entries ``(log_name, log_dir, draw_roofline)``; each run is
            scattered on the plot, and its roofline is drawn when ``draw_roofline``
            is truthy.
        output_name: path of the image file to write.
        batch_ranges: iterable of ``range(...)`` argument tuples; their ranges are
            concatenated into the list of batch sizes.
        dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name:
            forwarded to ``bench_mlp`` for each batch and run.
        verbose: print per-point utilization and throughput.

    Raises:
        ValueError: if ``log_metadata`` is empty or ``batch_ranges`` yields no batches.
    """
    from itertools import chain

    if not log_metadata:
        raise ValueError("log_metadata must contain at least one entry")
    batches = list(chain.from_iterable(range(*r) for r in batch_ranges))
    if not batches:
        raise ValueError("batch_ranges produced no batch sizes")

    # collect performance data: one list of PerfData per run, in batch order
    perfs = [[] for _ in log_metadata]
    bench_case = f"{name} ({x_dtype}x{w_dtype}, TP={TP}, EP={EP})"
    print(f"Benchmarking {bench_case}...")
    print("===============================================================")
    for batch in batches:
        for series, (log_name, log_dir, _) in zip(perfs, log_metadata):
            perf = bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, name, log_dir)
            series.append(perf)
            if verbose:
                print(f"[{log_name}]Batch: {batch}; Util: {perf.util}; TFLOPS: {perf.tflops}; TBPS: {perf.tbps}")
    print("===============================================================")

    # Machine limit taken from the first measurement; all runs are assumed to share a device.
    max_tflops = perfs[0][0].max_tflops

    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
    ax.set_xlabel("batch size (toks/expt)")
    ax.set_ylabel("performance [TFLOP/s]")
    ax.set_title(f"{bench_case} roofline")

    xs = [batch * n_expts_act / n_expts_tot for batch in batches]
    # add a tiny margin so points are not flush with the frame
    xmin, xmax = min(xs), max(xs)
    dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
    ax.set_xlim(xmin - dx, xmax + dx)
    # NOTE(review): points below 100 TFLOP/s fall outside this fixed y-range.
    ax.set_ylim(100, max_tflops + 500)

    # plot rooflines and data; legend labels get no run prefix when only one roofline is drawn
    only_one_roofline = sum(entry[2] for entry in log_metadata) == 1
    for series, (log_name, _, draw_roofline) in zip(perfs, log_metadata):
        if draw_roofline:
            plot_roofline(ax, xs, series, "" if only_one_roofline else log_name)
        ax.scatter(xs, [p.tflops for p in series], marker="+", label=log_name)

    ax.legend(frameon=False, loc="lower right")
    ax.grid(True, which="both", ls=":", lw=0.5)
    fig.tight_layout()
    fig.savefig(Path(output_name))
    # Release the figure; pyplot otherwise keeps it alive for the rest of the process.
    plt.close(fig)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Parse MoE Triton kernel proton profiles and draw overlaid roofline plots.")
    parser.add_argument('--log-metadata', '-m', type=str, required=True,
                        help='JSON file: list of [name, log_dir, draw_roofline] entries')
    parser.add_argument('--output', '-o', type=str, default='perf.png')
    parser.add_argument('--dtype', '-d', type=str, choices=['fp8fp8', 'fp8mx4'], default='fp8fp8')
    args = parser.parse_args()

    with open(args.log_metadata, 'r', encoding='utf-8') as f:
        log_metadata = json.load(f)

    batch_ranges_moe = [(128, 512, 32), (512, 32000, 128)]
    if args.dtype == 'fp8fp8':
        dtypes = ["fp8", "fp8"]  # (x_dtype, w_dtype)
    else:
        # argparse `choices` guarantees this branch is 'fp8mx4'. mx4 weights are paired with fp8
        # activations when the GPU's major compute capability is >= 10 or it is CDNA4;
        # otherwise bf16 activations are used.
        has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
        dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
    roofline_mlp(log_metadata, args.output, batch_ranges_moe, 5120, 8192, 128, 4, *dtypes, TP=1, EP=1,
                 name="llama4-maverick")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment