Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Created February 18, 2026 07:22
Show Gist options
  • Select an option

  • Save msaroufim/5ca55e4ed153a2418bdcdfa0e0af4bcd to your computer and use it in GitHub Desktop.

Select an option

Save msaroufim/5ca55e4ed153a2418bdcdfa0e0af4bcd to your computer and use it in GitHub Desktop.
"""
Benchmark: hl.dot_scaled vs nvfp4_matmul vs torch._scaled_mm vs torch.matmul
=============================================================================
Compares:
1. torch.matmul (fp16) — cuBLAS baseline
2. torch._scaled_mm (e4m3) — PyTorch native FP8 scaled matmul
3. hl.dot_scaled (fp16) — hardware tl.dot_scaled, fp16 format
4. hl.dot_scaled (e4m3) — hardware tl.dot_scaled, FP8 format
5. nvfp4_matmul (sw dequant) — existing helion example, software FP4
Run on B200 (SM 10.0+):
python benchmarks/bench_dot_scaled.py
"""
from __future__ import annotations

import functools
import json
import math
import sys

import torch
from torch import Tensor

import helion
import helion.language as hl
from helion.autotuner.benchmarking import compute_repeat, interleaved_bench
SCALE_FACTOR = 32  # uint8 e8m0 scale: 1 scale per 32 elements along K
DEVICE = "cuda"
# Fixed tile sizes for the hl.dot_scaled kernels (hand-picked, no autotuning).
BLOCK_M = 128
BLOCK_N = 64
BLOCK_K = 512  # must be multiple of SCALE_FACTOR
# ---------------------------------------------------------------------------
# dot_scaled kernels with proper K tiling
# ---------------------------------------------------------------------------
def _make_dot_scaled_kernel(fmt: str):
    """Build a helion kernel that runs hl.dot_scaled with the given format.

    The fp16 and e4m3 factories below were byte-identical except for the
    tl.dot_scaled format string, so the shared body lives here. `fmt` is
    applied to both operands; scales are uint8 e8m0, one per SCALE_FACTOR
    elements along K. Uses the fixed BLOCK_M/BLOCK_N/BLOCK_K config.
    """

    @helion.kernel(
        config=helion.Config(
            block_sizes=[BLOCK_M, BLOCK_N, BLOCK_K], loop_orders=[[1, 0]]
        ),
        static_shapes=True,
    )
    def kernel(
        x: torch.Tensor,
        x_scale: torch.Tensor,
        y: torch.Tensor,
        y_scale: torch.Tensor,
    ) -> torch.Tensor:
        m, k = x.size()
        _, n = y.size()
        out = torch.empty([m, n], dtype=torch.float32, device=x.device)
        for tile_m, tile_n in hl.tile([m, n]):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                # Scale tensors carry one column per SCALE_FACTOR K elements,
                # so map the K tile onto a SCALE_FACTOR-times-smaller slice.
                sk_begin = tile_k.begin // SCALE_FACTOR
                sk_end = sk_begin + tile_k.block_size // SCALE_FACTOR
                acc = hl.dot_scaled(
                    x[tile_m, tile_k],
                    x_scale[tile_m, sk_begin:sk_end],
                    fmt,
                    y[tile_k, tile_n],
                    y_scale[tile_n, sk_begin:sk_end],
                    fmt,
                    acc=acc,
                )
            out[tile_m, tile_n] = acc
        return out

    return kernel


def make_dot_scaled_fp16_kernel(M: int, K: int, N: int):
    """hl.dot_scaled kernel with fp16 operands (M/K/N kept for API parity)."""
    return _make_dot_scaled_kernel("fp16")


def make_dot_scaled_e4m3_kernel(M: int, K: int, N: int):
    """hl.dot_scaled kernel with FP8 e4m3 operands (M/K/N kept for API parity)."""
    return _make_dot_scaled_kernel("e4m3")
# ---------------------------------------------------------------------------
# nvfp4_matmul with FIXED config (bypasses broken autotuner)
# ---------------------------------------------------------------------------
# Copied from examples/nvfp4_gemm.py but with a fixed config to avoid
# autotuner hitting Triton compiler bugs.
# Fixed tile sizes for nvfp4_matmul_fixed (chosen by hand, not autotuned).
NVFP4_BLOCK_MN = 64
NVFP4_BLOCK_K_PACKED = 64  # K_packed block size (unpacked = 128)
@helion.kernel(
    config=helion.Config(
        block_sizes=[NVFP4_BLOCK_MN, NVFP4_BLOCK_MN, NVFP4_BLOCK_K_PACKED]
    ),
    static_shapes=False,
)
def nvfp4_matmul_fixed(A: Tensor, B_packed: Tensor) -> Tensor:
    """Software-dequant NVFP4 GEMM: bf16 A @ packed-e2m1 B -> bf16 C.

    A is (M, K) bf16; B_packed is (K // 2, N) int8 holding two 4-bit e2m1
    codes per byte (low nibble = even K index, high nibble = odd K index,
    matching pack_fp4). Dequantizes B in-kernel and accumulates in fp32.
    Uses a fixed config instead of autotuning.
    """
    M, K = A.shape
    _, N = B_packed.shape
    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
    # Register the packed-K block size so the inner hl.tile loop uses it.
    block_size_k_packed = hl.register_block_size(K // 2)
    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
            # Each packed K index covers two unpacked K elements of A.
            a_tile_begin = tile_k_packed.begin * 2
            a_tile_len = block_size_k_packed * 2
            a_tile = A[tile_m, a_tile_begin : (a_tile_begin + a_tile_len)].to(
                torch.float32
            )
            b_tile = B_packed[tile_k_packed, tile_n]
            # Split each byte into its two 4-bit codes. (Bitwise ops on the
            # int8 tile are safe here: the 0xF masks discard sign-extension.)
            b_lo = b_tile & 0xF
            b_hi = (b_tile >> 4) & 0xF
            # Decode e2m1: bit 3 = sign, bits 0-2 = magnitude index u.
            sign_lo = ((b_lo >> 3) & 1).to(torch.float32)
            u_lo = (b_lo & 0x7).to(torch.float32)
            sign_hi = ((b_hi >> 3) & 1).to(torch.float32)
            u_hi = (b_hi & 0x7).to(torch.float32)
            # Magnitude table {0, .5, 1, 1.5, 2, 3, 4, 6} expressed piecewise:
            #   u < 4 -> u * 0.5 ; u in {4,5} -> u - 2 ; u in {6,7} -> 2u - 8.
            abs_lo = torch.where(
                u_lo < 4,
                u_lo * 0.5,
                torch.where(u_lo < 6, u_lo - 2.0, u_lo * 2.0 - 8.0),
            )
            abs_hi = torch.where(
                u_hi < 4,
                u_hi * 0.5,
                torch.where(u_hi < 6, u_hi - 2.0, u_hi * 2.0 - 8.0),
            )
            # Apply sign: (1 - 2*sign) is +1 for sign=0, -1 for sign=1.
            b_lo_f = abs_lo * (1.0 - 2.0 * sign_lo)
            b_hi_f = abs_hi * (1.0 - 2.0 * sign_hi)
            # Interleave lo/hi back into unpacked K order: (2*Kp, N).
            b_stacked = torch.stack([b_lo_f, b_hi_f], dim=1)
            b_unpacked = b_stacked.reshape(
                tile_k_packed.block_size * 2, tile_n.block_size
            )
            # Manual matmul: broadcast (Mt,Kt,1) * (1,Kt,Nt), reduce over K.
            a_tile = a_tile.unsqueeze(2)
            b_unpacked = b_unpacked.unsqueeze(0)
            acc = acc + (a_tile * b_unpacked).sum(dim=1)
        C[tile_m, tile_n] = acc.to(torch.bfloat16)
    return C
# ---------------------------------------------------------------------------
# nvfp4 helpers (from examples/nvfp4_gemm.py)
# ---------------------------------------------------------------------------
def quantize_fp4_e2m1(x: Tensor) -> Tensor:
    """Map each float to a 4-bit e2m1 code (bit 3 = sign, bits 0-2 = magnitude).

    Magnitudes are clamped to 6.0 (the largest e2m1 value) and binned via
    midpoint thresholds between consecutive representable magnitudes.
    """
    magnitude = x.abs().clamp(max=6.0)
    # Midpoints between consecutive e2m1 magnitudes {0, .5, 1, 1.5, 2, 3, 4, 6}.
    midpoints = torch.tensor(
        [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
        device=x.device,
        dtype=magnitude.dtype,
    )
    code = torch.bucketize(magnitude, midpoints).to(torch.uint8)
    # Fold the sign into bit 3 of the code.
    return torch.where(x < 0, code | 8, code)
def pack_fp4(indices: Tensor) -> Tensor:
    """Pack pairs of 4-bit codes along K into one int8 byte each.

    Even-K codes land in the low nibble, odd-K codes in the high nibble,
    halving the K dimension: (K, N) -> (K // 2, N) int8.
    """
    rows, _ = indices.shape
    assert rows % 2 == 0
    low = indices[0::2, :] & 0xF
    high = indices[1::2, :] << 4
    return (low | high).to(torch.int8)
# ---------------------------------------------------------------------------
# Baselines
# ---------------------------------------------------------------------------
def torch_matmul_fp16(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Plain dense matmul baseline (dispatches to cuBLAS for CUDA fp16)."""
    return x @ y
def torch_scaled_mm_e4m3(
    x: torch.Tensor, y: torch.Tensor, x_scale: torch.Tensor, y_scale: torch.Tensor
) -> torch.Tensor:
    """FP8 e4m3 scaled matmul via torch._scaled_mm, returning fp32.

    y is expected column-major (see make_scaled_mm_inputs); x_scale/y_scale
    are single-element fp32 tensors, i.e. per-tensor scaling.
    """
    return torch._scaled_mm(x, y, scale_a=x_scale, scale_b=y_scale, out_dtype=torch.float32)
# ---------------------------------------------------------------------------
# Input factories
# ---------------------------------------------------------------------------
def make_fp16_inputs(M: int, K: int, N: int):
    """Random fp16 operands plus constant uint8 scale tensors for hl.dot_scaled.

    Scales are one column per SCALE_FACTOR elements along K, filled with 127
    (presumably the e8m0 encoding of 1.0 — TODO confirm).
    """
    scale_cols = K // SCALE_FACTOR
    x = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
    x_scale = torch.full((M, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    y = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
    y_scale = torch.full((N, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    return x, x_scale, y, y_scale
def make_e4m3_inputs(M: int, K: int, N: int):
    """Random FP8 e4m3 operands plus constant uint8 scale tensors.

    Values are scaled by 0.5 before the fp8 cast to stay well inside e4m3
    range; scale tensors mirror make_fp16_inputs (127 per entry).
    """
    scale_cols = K // SCALE_FACTOR
    x = (torch.randn(M, K, device=DEVICE) * 0.5).to(torch.float8_e4m3fn)
    x_scale = torch.full((M, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    y = (torch.randn(K, N, device=DEVICE) * 0.5).to(torch.float8_e4m3fn)
    y_scale = torch.full((N, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    return x, x_scale, y, y_scale
def make_scaled_mm_inputs(M: int, K: int, N: int):
    """Inputs for torch._scaled_mm: row-major A, column-major B, scalar scales."""
    a = (torch.randn(M, K, device=DEVICE) * 0.5).to(torch.float8_e4m3fn)
    # torch._scaled_mm requires B column-major: materialize (N, K) row-major,
    # then hand back its transposed view.
    b = (torch.randn(N, K, device=DEVICE) * 0.5).to(torch.float8_e4m3fn).t()
    a_scale = torch.ones(1, device=DEVICE, dtype=torch.float32)
    b_scale = torch.ones(1, device=DEVICE, dtype=torch.float32)
    return a, b, a_scale, b_scale
def make_nvfp4_inputs(M: int, K: int, N: int):
    """bf16 activations plus an FP4-quantized, byte-packed weight matrix."""
    activations = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
    weights = torch.randn(K, N, dtype=torch.bfloat16, device=DEVICE)
    packed = pack_fp4(quantize_fp4_e2m1(weights))
    return activations, packed
# ---------------------------------------------------------------------------
# Benchmark harness
# ---------------------------------------------------------------------------
# Latency keys in bench_one_size results, in the order rows are printed.
ALL_KEYS = [
    "torch_matmul_fp16_ms",
    "torch_scaled_mm_e4m3_ms",
    "dot_scaled_fp16_ms",
    "dot_scaled_e4m3_ms",
    "nvfp4_matmul_ms",
]
def bench_one_size(M: int, K: int, N: int) -> dict:
    """Benchmark every implementation at one problem size.

    Returns a dict holding the size plus one latency (ms) per ALL_KEYS entry.
    Each backend is benchmarked independently: a failure is logged to stderr
    and recorded as NaN instead of aborting the whole sweep. The cuBLAS fp16
    baseline also determines the repeat count used for all other backends.
    """
    results: dict[str, float] = {"M": M, "K": K, "N": N}
    repeat = None
    # --- torch.matmul fp16 baseline (sets repeat count) ---
    try:
        x_base = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
        y_base = torch.randn(K, N, device=DEVICE, dtype=torch.float16)
        fn_base = functools.partial(torch_matmul_fp16, x_base, y_base)
        fn_base()  # warmup
        repeat = compute_repeat(fn_base)
        results["torch_matmul_fp16_ms"] = interleaved_bench([fn_base], repeat=repeat)[0]
    except Exception as e:
        print(f" torch.matmul failed at {M}x{K}x{N}: {e}", file=sys.stderr)
        results["torch_matmul_fp16_ms"] = float("nan")
    if repeat is None:
        repeat = 20  # fallback when the baseline itself failed

    def _bench(key: str, label: str, make_fn) -> None:
        # One try/except per backend (the original had five copy-pasted
        # versions of this). Inputs and callable are built inside the try
        # so setup failures are also recorded as NaN for this key only.
        try:
            fn = make_fn()
            fn()  # warmup / first compile
            results[key] = interleaved_bench([fn], repeat=repeat)[0]
        except Exception as e:
            print(f" {label} failed at {M}x{K}x{N}: {e}", file=sys.stderr)
            results[key] = float("nan")

    _bench(
        "torch_scaled_mm_e4m3_ms",
        "torch._scaled_mm",
        lambda: functools.partial(torch_scaled_mm_e4m3, *make_scaled_mm_inputs(M, K, N)),
    )
    _bench(
        "dot_scaled_fp16_ms",
        "dot_scaled fp16",
        lambda: functools.partial(make_dot_scaled_fp16_kernel(M, K, N), *make_fp16_inputs(M, K, N)),
    )
    _bench(
        "dot_scaled_e4m3_ms",
        "dot_scaled e4m3",
        lambda: functools.partial(make_dot_scaled_e4m3_kernel(M, K, N), *make_e4m3_inputs(M, K, N)),
    )
    _bench(
        "nvfp4_matmul_ms",
        "nvfp4_matmul",
        lambda: functools.partial(nvfp4_matmul_fixed, *make_nvfp4_inputs(M, K, N)),
    )
    return results
def compute_tflops(M: int, K: int, N: int, ms: float) -> float:
    """Convert a GEMM latency in milliseconds to TFLOP/s.

    Returns 0.0 for non-positive or NaN timings so failed benchmarks plot
    as zero instead of propagating NaN/inf into the charts. (Replaces the
    obscure `ms != ms` NaN check with math.isnan.)
    """
    if math.isnan(ms) or ms <= 0:
        return 0.0
    # A GEMM does 2*M*K*N flops; ms -> seconds -> TFLOP/s.
    return 2.0 * M * K * N / (ms * 1e-3) / 1e12
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Run the benchmark sweep over square sizes, print results, save JSON/charts."""
    print("=" * 80)
    print("Benchmark: hl.dot_scaled vs nvfp4_matmul vs torch._scaled_mm vs torch.matmul")
    print(f"dot_scaled config: BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, BLOCK_K={BLOCK_K}")
    print(f"nvfp4 config: BLOCK_MN={NVFP4_BLOCK_MN}, BLOCK_K_PACKED={NVFP4_BLOCK_K_PACKED}")
    print("=" * 80)
    sizes = [2**p for p in range(8, 15)]  # 256..16384
    all_results = []
    for sz in sizes:
        M, K, N = sz, sz, sz
        # Sizes smaller than the largest fixed block dimension are skipped;
        # the hard-coded dot_scaled config would not fit.
        min_dim = max(BLOCK_M, BLOCK_N, BLOCK_K)
        if sz < min_dim:
            print(f"\n--- M={M}, K={K}, N={N} --- SKIPPED (size < {min_dim})")
            continue
        print(f"\n--- M={M}, K={K}, N={N} ---")
        res = bench_one_size(M, K, N)
        all_results.append(res)
        for key in ALL_KEYS:
            ms = res.get(key, float("nan"))
            tflops = compute_tflops(M, K, N, ms)
            print(f" {key:30s}: {ms:8.4f} ms ({tflops:7.2f} TFLOP/s)")
    out_path = "/tmp/dot_scaled_bench_results.json"
    with open(out_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to {out_path}")
    # Charts are best-effort: a missing/broken matplotlib must not fail the run.
    try:
        generate_charts(all_results)
    except Exception as e:
        print(f"Could not generate charts: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
def generate_charts(results: list[dict]):
    """Render latency, throughput, and speedup charts to /tmp as PNGs.

    matplotlib is imported lazily so the benchmark itself runs without it.
    The fp16 dot_scaled series is collected but omitted from all charts
    (per the inline notes, due to a Triton bug).
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: write files, never open a window
    import matplotlib.pyplot as plt
    import numpy as np
    sizes = [r["M"] for r in results]
    labels = [str(s) for s in sizes]
    # Latency series (ms), NaN where a backend failed.
    torch_fp16 = [r.get("torch_matmul_fp16_ms", float("nan")) for r in results]
    torch_smm = [r.get("torch_scaled_mm_e4m3_ms", float("nan")) for r in results]
    ds_fp16 = [r.get("dot_scaled_fp16_ms", float("nan")) for r in results]
    ds_e4m3 = [r.get("dot_scaled_e4m3_ms", float("nan")) for r in results]
    nvfp4 = [r.get("nvfp4_matmul_ms", float("nan")) for r in results]
    def tflops_list(key):
        # Latency series -> TFLOP/s series (failed entries become 0.0).
        return [compute_tflops(r["M"], r["K"], r["N"], r.get(key, float("nan"))) for r in results]
    torch_fp16_tf = tflops_list("torch_matmul_fp16_ms")
    torch_smm_tf = tflops_list("torch_scaled_mm_e4m3_ms")
    ds_fp16_tf = tflops_list("dot_scaled_fp16_ms")  # computed but not charted
    ds_e4m3_tf = tflops_list("dot_scaled_e4m3_ms")
    nvfp4_tf = tflops_list("nvfp4_matmul_ms")
    x = np.arange(len(labels))
    width = 0.16
    # Chart 1: Latency (log scale) — skip fp16 dot_scaled (Triton bug)
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(x - 1.5 * width, torch_fp16, width, label="torch.matmul (fp16)", color="#FF9800")
    ax.bar(x - 0.5 * width, torch_smm, width, label="torch._scaled_mm (e4m3)", color="#9C27B0")
    ax.bar(x + 0.5 * width, ds_e4m3, width, label="hl.dot_scaled (e4m3)", color="#4CAF50")
    ax.bar(x + 1.5 * width, nvfp4, width, label="nvfp4_matmul (sw dequant)", color="#F44336")
    ax.set_xlabel("Matrix Size (M=K=N)", fontsize=12)
    ax.set_ylabel("Latency (ms)", fontsize=12)
    ax.set_title("GEMM Latency on B200 (lower is better)", fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend(fontsize=9)
    ax.set_yscale("log")
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig("/tmp/dot_scaled_latency.png", dpi=150)
    print("Saved /tmp/dot_scaled_latency.png")
    plt.close()
    # Chart 2: Throughput — skip fp16 dot_scaled (Triton bug)
    fig2, ax2 = plt.subplots(figsize=(14, 6))
    ax2.plot(labels, torch_fp16_tf, "^-", label="torch.matmul (fp16)", color="#FF9800", linewidth=2, markersize=8)
    ax2.plot(labels, torch_smm_tf, "v-", label="torch._scaled_mm (e4m3)", color="#9C27B0", linewidth=2, markersize=8)
    ax2.plot(labels, ds_e4m3_tf, "s-", label="hl.dot_scaled (e4m3)", color="#4CAF50", linewidth=2, markersize=8)
    ax2.plot(labels, nvfp4_tf, "D-", label="nvfp4_matmul (sw dequant)", color="#F44336", linewidth=2, markersize=8)
    ax2.set_xlabel("Matrix Size (M=K=N)", fontsize=12)
    ax2.set_ylabel("Throughput (TFLOP/s)", fontsize=12)
    ax2.set_title("GEMM Throughput on B200 (higher is better)", fontsize=14)
    ax2.legend(fontsize=9)
    ax2.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig("/tmp/dot_scaled_throughput.png", dpi=150)
    print("Saved /tmp/dot_scaled_throughput.png")
    plt.close()
    # Chart 3: Speedup over nvfp4_matmul
    fig3, ax3 = plt.subplots(figsize=(14, 6))
    def speedup(baseline, target):
        # baseline/target ratio; 0 when either side is NaN or non-positive.
        return [b / t if t > 0 and t == t and b > 0 and b == b else 0 for t, b in zip(target, baseline)]
    sp_torch = speedup(nvfp4, torch_fp16)
    sp_smm = speedup(nvfp4, torch_smm)
    sp_ds_e4m3 = speedup(nvfp4, ds_e4m3)
    w2 = 0.25
    ax3.bar(x - w2, sp_torch, w2, label="torch.matmul (fp16)", color="#FF9800")
    ax3.bar(x, sp_smm, w2, label="torch._scaled_mm (e4m3)", color="#9C27B0")
    ax3.bar(x + w2, sp_ds_e4m3, w2, label="hl.dot_scaled (e4m3)", color="#4CAF50")
    ax3.axhline(y=1.0, color="red", linestyle="--", alpha=0.7, label="nvfp4_matmul parity")
    ax3.set_xlabel("Matrix Size (M=K=N)", fontsize=12)
    ax3.set_ylabel("Speedup over nvfp4_matmul (higher is better)", fontsize=12)
    ax3.set_title("Speedup over nvfp4_matmul (software FP4 dequant) on B200", fontsize=14)
    ax3.set_xticks(x)
    ax3.set_xticklabels(labels)
    ax3.legend(fontsize=9)
    ax3.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig("/tmp/dot_scaled_speedup.png", dpi=150)
    print("Saved /tmp/dot_scaled_speedup.png")
    plt.close()
# Script entry point.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment