Created
February 18, 2026 07:22
-
-
Save msaroufim/5ca55e4ed153a2418bdcdfa0e0af4bcd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Benchmark: hl.dot_scaled vs nvfp4_matmul vs torch._scaled_mm vs torch.matmul | |
| ============================================================================= | |
| Compares: | |
| 1. torch.matmul (fp16) — cuBLAS baseline | |
| 2. torch._scaled_mm (e4m3) — PyTorch native FP8 scaled matmul | |
| 3. hl.dot_scaled (fp16) — hardware tl.dot_scaled, fp16 format | |
| 4. hl.dot_scaled (e4m3) — hardware tl.dot_scaled, FP8 format | |
| 5. nvfp4_matmul (sw dequant) — existing helion example, software FP4 | |
| Run on B200 (SM 10.0+): | |
| python benchmarks/bench_dot_scaled.py | |
| """ | |
from __future__ import annotations

import functools
import json
import math
import sys

import torch
from torch import Tensor

import helion
import helion.language as hl
from helion.autotuner.benchmarking import compute_repeat, interleaved_bench
SCALE_FACTOR = 32  # uint8 e8m0 scale: 1 scale per 32 elements along K
DEVICE = "cuda"  # benchmarks target CUDA (B200 per the module docstring)
# Fixed tile sizes for the hl.dot_scaled kernels (autotuning bypassed).
BLOCK_M = 128
BLOCK_N = 64
BLOCK_K = 512  # must be multiple of SCALE_FACTOR
| # --------------------------------------------------------------------------- | |
| # dot_scaled kernels with proper K tiling | |
| # --------------------------------------------------------------------------- | |
def make_dot_scaled_fp16_kernel(M: int, K: int, N: int):
    """Build a helion kernel computing ``x @ y`` via hl.dot_scaled in "fp16" format.

    ``M``/``K``/``N`` are taken so each problem size gets its own kernel object
    (``static_shapes=True`` specializes per shape); they are not read directly.

    The returned kernel takes (x, x_scale, y, y_scale) where the scales are
    uint8 tensors with one entry per SCALE_FACTOR elements along K, and
    returns a float32 [m, n] output.
    """

    @helion.kernel(
        config=helion.Config(block_sizes=[BLOCK_M, BLOCK_N, BLOCK_K], loop_orders=[[1, 0]]),
        static_shapes=True,
    )
    def kernel(
        x: torch.Tensor,
        x_scale: torch.Tensor,
        y: torch.Tensor,
        y_scale: torch.Tensor,
    ) -> torch.Tensor:
        m, k = x.size()
        _, n = y.size()
        out = torch.empty([m, n], dtype=torch.float32, device=x.device)
        for tile_m, tile_n in hl.tile([m, n]):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                # One scale per SCALE_FACTOR elements along K, so the scale
                # slice is the K tile divided by SCALE_FACTOR (was a hard-coded
                # 32, now kept consistent with the module constant).
                sk_begin = tile_k.begin // SCALE_FACTOR
                sk_end = sk_begin + tile_k.block_size // SCALE_FACTOR
                acc = hl.dot_scaled(
                    x[tile_m, tile_k],
                    x_scale[tile_m, sk_begin:sk_end],
                    "fp16",
                    y[tile_k, tile_n],
                    y_scale[tile_n, sk_begin:sk_end],
                    "fp16",
                    acc=acc,
                )
            out[tile_m, tile_n] = acc
        return out

    return kernel
def make_dot_scaled_e4m3_kernel(M: int, K: int, N: int):
    """Build a helion kernel computing ``x @ y`` via hl.dot_scaled in "e4m3" format.

    ``M``/``K``/``N`` are taken so each problem size gets its own kernel object
    (``static_shapes=True`` specializes per shape); they are not read directly.

    The returned kernel takes (x, x_scale, y, y_scale) with FP8 e4m3 operands
    and uint8 scales (one per SCALE_FACTOR elements along K), and returns a
    float32 [m, n] output.
    """

    @helion.kernel(
        config=helion.Config(block_sizes=[BLOCK_M, BLOCK_N, BLOCK_K], loop_orders=[[1, 0]]),
        static_shapes=True,
    )
    def kernel(
        x: torch.Tensor,
        x_scale: torch.Tensor,
        y: torch.Tensor,
        y_scale: torch.Tensor,
    ) -> torch.Tensor:
        m, k = x.size()
        _, n = y.size()
        out = torch.empty([m, n], dtype=torch.float32, device=x.device)
        for tile_m, tile_n in hl.tile([m, n]):
            acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
            for tile_k in hl.tile(k):
                # One scale per SCALE_FACTOR elements along K (was a hard-coded
                # 32, now kept consistent with the module constant).
                sk_begin = tile_k.begin // SCALE_FACTOR
                sk_end = sk_begin + tile_k.block_size // SCALE_FACTOR
                acc = hl.dot_scaled(
                    x[tile_m, tile_k],
                    x_scale[tile_m, sk_begin:sk_end],
                    "e4m3",
                    y[tile_k, tile_n],
                    y_scale[tile_n, sk_begin:sk_end],
                    "e4m3",
                    acc=acc,
                )
            out[tile_m, tile_n] = acc
        return out

    return kernel
| # --------------------------------------------------------------------------- | |
| # nvfp4_matmul with FIXED config (bypasses broken autotuner) | |
| # --------------------------------------------------------------------------- | |
| # Copied from examples/nvfp4_gemm.py but with a fixed config to avoid | |
| # autotuner hitting Triton compiler bugs. | |
# Fixed tile sizes for nvfp4_matmul_fixed (avoids the autotuner entirely).
NVFP4_BLOCK_MN = 64
NVFP4_BLOCK_K_PACKED = 64  # K_packed block size (unpacked = 128)
@helion.kernel(
    config=helion.Config(block_sizes=[NVFP4_BLOCK_MN, NVFP4_BLOCK_MN, NVFP4_BLOCK_K_PACKED]),
    static_shapes=False,
)
def nvfp4_matmul_fixed(A: Tensor, B_packed: Tensor) -> Tensor:
    """Matmul of A against FP4 (e2m1) B_packed, dequantizing B in software.

    B_packed stores two 4-bit codes per byte along its first (packed-K)
    dimension: low nibble first, high nibble second. The accumulation runs in
    float32 and the result is returned as bfloat16.
    """
    M, K = A.shape
    _, N = B_packed.shape
    C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
    block_size_k_packed = hl.register_block_size(K // 2)
    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
            # Each packed index covers two unpacked K positions.
            a_tile_begin = tile_k_packed.begin * 2
            a_tile_len = block_size_k_packed * 2
            a_tile = A[tile_m, a_tile_begin : (a_tile_begin + a_tile_len)].to(
                torch.float32
            )
            b_tile = B_packed[tile_k_packed, tile_n]
            # Split each byte into its two 4-bit codes (low nibble, high nibble).
            b_lo = b_tile & 0xF
            b_hi = (b_tile >> 4) & 0xF
            # e2m1 decode: bit 3 = sign, bits 0-2 = magnitude index u.
            sign_lo = ((b_lo >> 3) & 1).to(torch.float32)
            u_lo = (b_lo & 0x7).to(torch.float32)
            sign_hi = ((b_hi >> 3) & 1).to(torch.float32)
            u_hi = (b_hi & 0x7).to(torch.float32)
            # Magnitude table {0, 0.5, 1, 1.5, 2, 3, 4, 6} expressed piecewise:
            # u < 4 -> u*0.5; u in {4, 5} -> u-2; u in {6, 7} -> 2u-8.
            abs_lo = torch.where(
                u_lo < 4,
                u_lo * 0.5,
                torch.where(u_lo < 6, u_lo - 2.0, u_lo * 2.0 - 8.0),
            )
            abs_hi = torch.where(
                u_hi < 4,
                u_hi * 0.5,
                torch.where(u_hi < 6, u_hi - 2.0, u_hi * 2.0 - 8.0),
            )
            # (1 - 2*sign) is +1 for sign bit 0, -1 for sign bit 1.
            b_lo_f = abs_lo * (1.0 - 2.0 * sign_lo)
            b_hi_f = abs_hi * (1.0 - 2.0 * sign_hi)
            # Interleave lo/hi codes back into unpacked K order.
            b_stacked = torch.stack([b_lo_f, b_hi_f], dim=1)
            b_unpacked = b_stacked.reshape(
                tile_k_packed.block_size * 2, tile_n.block_size
            )
            # Broadcast multiply + reduce over K: a software GEMM (no tl.dot).
            a_tile = a_tile.unsqueeze(2)
            b_unpacked = b_unpacked.unsqueeze(0)
            acc = acc + (a_tile * b_unpacked).sum(dim=1)
        C[tile_m, tile_n] = acc.to(torch.bfloat16)
    return C
| # --------------------------------------------------------------------------- | |
| # nvfp4 helpers (from examples/nvfp4_gemm.py) | |
| # --------------------------------------------------------------------------- | |
def quantize_fp4_e2m1(x: Tensor) -> Tensor:
    """Quantize *x* to 4-bit FP4 (e2m1) codes: bit 3 = sign, bits 0-2 = magnitude index."""
    magnitude = x.abs().clamp(max=6.0)
    # Midpoints between adjacent e2m1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    midpoints = torch.tensor(
        [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], device=x.device, dtype=magnitude.dtype
    )
    code = torch.bucketize(magnitude, midpoints).to(torch.uint8)
    negative = (x < 0).to(torch.uint8)
    return code | (negative << 3)
def pack_fp4(indices: Tensor) -> Tensor:
    """Pack pairs of 4-bit codes along dim 0 into bytes (even K index -> low nibble)."""
    rows, cols = indices.shape
    assert rows % 2 == 0
    pairs = indices.reshape(rows // 2, 2, cols)
    low = pairs[:, 0, :] & 0xF
    high = pairs[:, 1, :] << 4
    return (low | high).to(torch.int8)
| # --------------------------------------------------------------------------- | |
| # Baselines | |
| # --------------------------------------------------------------------------- | |
| def torch_matmul_fp16(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
| return torch.matmul(x, y) | |
def torch_scaled_mm_e4m3(
    x: torch.Tensor, y: torch.Tensor, x_scale: torch.Tensor, y_scale: torch.Tensor
) -> torch.Tensor:
    """PyTorch-native FP8 (e4m3) scaled matmul with a float32 output."""
    result = torch._scaled_mm(
        x, y, scale_a=x_scale, scale_b=y_scale, out_dtype=torch.float32
    )
    return result
| # --------------------------------------------------------------------------- | |
| # Input factories | |
| # --------------------------------------------------------------------------- | |
def make_fp16_inputs(M: int, K: int, N: int):
    """Random fp16 operands plus uint8 scales for the dot_scaled fp16 kernel.

    Scales hold 127 everywhere — presumably the e8m0 encoding of 1.0
    (bias 127); one scale per SCALE_FACTOR elements along K.
    """
    lhs = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
    rhs = torch.randn(K, N, device=DEVICE, dtype=torch.float16)
    scale_cols = K // SCALE_FACTOR
    lhs_scale = torch.full((M, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    rhs_scale = torch.full((N, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    return lhs, lhs_scale, rhs, rhs_scale
def make_e4m3_inputs(M: int, K: int, N: int):
    """Random FP8 e4m3 operands plus uint8 scales for the dot_scaled e4m3 kernel.

    Values are halved before casting (presumably to stay well inside e4m3
    range); scales hold 127 — one per SCALE_FACTOR elements along K.
    """
    lhs = (torch.randn(M, K, device=DEVICE) * 0.5).to(torch.float8_e4m3fn)
    rhs = (torch.randn(K, N, device=DEVICE) * 0.5).to(torch.float8_e4m3fn)
    scale_cols = K // SCALE_FACTOR
    lhs_scale = torch.full((M, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    rhs_scale = torch.full((N, scale_cols), 127, device=DEVICE, dtype=torch.uint8)
    return lhs, lhs_scale, rhs, rhs_scale
def make_scaled_mm_inputs(M: int, K: int, N: int):
    """FP8 e4m3 operands and unit fp32 scales for torch._scaled_mm."""
    a = (torch.randn(M, K, device=DEVICE) * 0.5).to(torch.float8_e4m3fn)
    # torch._scaled_mm requires B in column-major layout, so build an (N, K)
    # row-major tensor and transpose it.
    b = (torch.randn(N, K, device=DEVICE) * 0.5).to(torch.float8_e4m3fn).t()
    scale_a = torch.ones(1, device=DEVICE, dtype=torch.float32)
    scale_b = torch.ones(1, device=DEVICE, dtype=torch.float32)
    return a, b, scale_a, scale_b
def make_nvfp4_inputs(M: int, K: int, N: int):
    """Random bf16 activations plus FP4-quantized, nibble-packed bf16 weights."""
    activations = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE)
    weights = torch.randn(K, N, dtype=torch.bfloat16, device=DEVICE)
    packed_weights = pack_fp4(quantize_fp4_e2m1(weights))
    return activations, packed_weights
| # --------------------------------------------------------------------------- | |
| # Benchmark harness | |
| # --------------------------------------------------------------------------- | |
# Result-dict latency keys, in the order they are printed per problem size.
ALL_KEYS = [
    "torch_matmul_fp16_ms",
    "torch_scaled_mm_e4m3_ms",
    "dot_scaled_fp16_ms",
    "dot_scaled_e4m3_ms",
    "nvfp4_matmul_ms",
]
def bench_one_size(M: int, K: int, N: int) -> dict:
    """Benchmark all five implementations at one (M, K, N) problem size.

    Returns a dict holding the problem size plus one latency entry in ms per
    implementation; a variant that fails records NaN instead of aborting the
    sweep. The repeat count is derived from the torch.matmul baseline and
    reused for every other variant so timings are directly comparable.
    """
    results: dict[str, float] = {"M": M, "K": K, "N": N}

    def _measure(key: str, label: str, make_fn, repeat):
        """Build, warm up, and time one variant.

        Records results[key] (latency or NaN). Returns the repeat count used,
        or None on failure; computes it via compute_repeat when passed None.
        """
        try:
            fn = make_fn()
            fn()  # warmup (and first-call kernel compile for helion variants)
            if repeat is None:
                repeat = compute_repeat(fn)
            times = interleaved_bench([fn], repeat=repeat)
            results[key] = times[0]
            return repeat
        except Exception as e:
            print(f" {label} failed at {M}x{K}x{N}: {e}", file=sys.stderr)
            results[key] = float("nan")
            return None

    # --- torch.matmul fp16 baseline (also sets the shared repeat count) ---
    def _make_matmul():
        x = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
        y = torch.randn(K, N, device=DEVICE, dtype=torch.float16)
        return functools.partial(torch_matmul_fp16, x, y)

    repeat = _measure("torch_matmul_fp16_ms", "torch.matmul", _make_matmul, None)
    if repeat is None:
        repeat = 20  # fallback when the baseline itself failed

    # --- torch._scaled_mm e4m3 ---
    _measure(
        "torch_scaled_mm_e4m3_ms",
        "torch._scaled_mm",
        lambda: functools.partial(torch_scaled_mm_e4m3, *make_scaled_mm_inputs(M, K, N)),
        repeat,
    )
    # --- hl.dot_scaled fp16 ---
    _measure(
        "dot_scaled_fp16_ms",
        "dot_scaled fp16",
        lambda: functools.partial(
            make_dot_scaled_fp16_kernel(M, K, N), *make_fp16_inputs(M, K, N)
        ),
        repeat,
    )
    # --- hl.dot_scaled e4m3 ---
    _measure(
        "dot_scaled_e4m3_ms",
        "dot_scaled e4m3",
        lambda: functools.partial(
            make_dot_scaled_e4m3_kernel(M, K, N), *make_e4m3_inputs(M, K, N)
        ),
        repeat,
    )
    # --- nvfp4_matmul (software FP4 dequant, fixed config) ---
    _measure(
        "nvfp4_matmul_ms",
        "nvfp4_matmul",
        lambda: functools.partial(nvfp4_matmul_fixed, *make_nvfp4_inputs(M, K, N)),
        repeat,
    )
    return results
def compute_tflops(M: int, K: int, N: int, ms: float) -> float:
    """Convert a GEMM latency in milliseconds to TFLOP/s (2*M*K*N flops).

    Returns 0.0 for non-positive or NaN latencies (failed benchmarks record
    NaN). The original code tested NaN via the obscure `ms != ms`
    self-comparison; math.isnan states the intent explicitly.
    """
    if ms <= 0 or math.isnan(ms):
        return 0.0
    return 2.0 * M * K * N / (ms * 1e-3) / 1e12
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Run the benchmark sweep over square sizes, print results, save JSON and charts."""
    print("=" * 80)
    print("Benchmark: hl.dot_scaled vs nvfp4_matmul vs torch._scaled_mm vs torch.matmul")
    print(f"dot_scaled config: BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, BLOCK_K={BLOCK_K}")
    print(f"nvfp4 config: BLOCK_MN={NVFP4_BLOCK_MN}, BLOCK_K_PACKED={NVFP4_BLOCK_K_PACKED}")
    print("=" * 80)
    sizes = [2**p for p in range(8, 15)]  # 256..16384
    # Sizes below the largest block dimension can't fill a single tile; skip
    # them. Hoisted out of the loop — it is loop-invariant.
    min_dim = max(BLOCK_M, BLOCK_N, BLOCK_K)
    all_results = []
    for sz in sizes:
        M, K, N = sz, sz, sz
        if sz < min_dim:
            print(f"\n--- M={M}, K={K}, N={N} --- SKIPPED (size < {min_dim})")
            continue
        print(f"\n--- M={M}, K={K}, N={N} ---")
        res = bench_one_size(M, K, N)
        all_results.append(res)
        for key in ALL_KEYS:
            ms = res.get(key, float("nan"))
            tflops = compute_tflops(M, K, N, ms)
            print(f" {key:30s}: {ms:8.4f} ms ({tflops:7.2f} TFLOP/s)")
    out_path = "/tmp/dot_scaled_bench_results.json"
    with open(out_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to {out_path}")
    # Chart generation is best-effort: a missing matplotlib (or any plotting
    # error) must not discard the benchmark run.
    try:
        generate_charts(all_results)
    except Exception as e:
        print(f"Could not generate charts: {e}", file=sys.stderr)
        import traceback

        traceback.print_exc()
def generate_charts(results: list[dict]):
    """Render latency, throughput, and speedup charts to /tmp as PNGs.

    hl.dot_scaled (fp16) is intentionally omitted from every chart (the
    original comments cite a Triton bug), so its timings are never extracted —
    this also removes the dead ds_fp16 / ds_fp16_tf locals the original
    computed but never plotted.
    """
    import matplotlib

    matplotlib.use("Agg")  # headless backend: write files, never open a window
    import matplotlib.pyplot as plt
    import numpy as np

    sizes = [r["M"] for r in results]
    labels = [str(s) for s in sizes]
    torch_fp16 = [r.get("torch_matmul_fp16_ms", float("nan")) for r in results]
    torch_smm = [r.get("torch_scaled_mm_e4m3_ms", float("nan")) for r in results]
    ds_e4m3 = [r.get("dot_scaled_e4m3_ms", float("nan")) for r in results]
    nvfp4 = [r.get("nvfp4_matmul_ms", float("nan")) for r in results]

    def tflops_list(key):
        # Per-size throughput for one implementation (0.0 where timing is NaN).
        return [
            compute_tflops(r["M"], r["K"], r["N"], r.get(key, float("nan")))
            for r in results
        ]

    torch_fp16_tf = tflops_list("torch_matmul_fp16_ms")
    torch_smm_tf = tflops_list("torch_scaled_mm_e4m3_ms")
    ds_e4m3_tf = tflops_list("dot_scaled_e4m3_ms")
    nvfp4_tf = tflops_list("nvfp4_matmul_ms")
    x = np.arange(len(labels))
    width = 0.16
    # Chart 1: Latency (log scale) — skip fp16 dot_scaled (Triton bug)
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(x - 1.5 * width, torch_fp16, width, label="torch.matmul (fp16)", color="#FF9800")
    ax.bar(x - 0.5 * width, torch_smm, width, label="torch._scaled_mm (e4m3)", color="#9C27B0")
    ax.bar(x + 0.5 * width, ds_e4m3, width, label="hl.dot_scaled (e4m3)", color="#4CAF50")
    ax.bar(x + 1.5 * width, nvfp4, width, label="nvfp4_matmul (sw dequant)", color="#F44336")
    ax.set_xlabel("Matrix Size (M=K=N)", fontsize=12)
    ax.set_ylabel("Latency (ms)", fontsize=12)
    ax.set_title("GEMM Latency on B200 (lower is better)", fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend(fontsize=9)
    ax.set_yscale("log")
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig("/tmp/dot_scaled_latency.png", dpi=150)
    print("Saved /tmp/dot_scaled_latency.png")
    plt.close()
    # Chart 2: Throughput — skip fp16 dot_scaled (Triton bug)
    fig2, ax2 = plt.subplots(figsize=(14, 6))
    ax2.plot(labels, torch_fp16_tf, "^-", label="torch.matmul (fp16)", color="#FF9800", linewidth=2, markersize=8)
    ax2.plot(labels, torch_smm_tf, "v-", label="torch._scaled_mm (e4m3)", color="#9C27B0", linewidth=2, markersize=8)
    ax2.plot(labels, ds_e4m3_tf, "s-", label="hl.dot_scaled (e4m3)", color="#4CAF50", linewidth=2, markersize=8)
    ax2.plot(labels, nvfp4_tf, "D-", label="nvfp4_matmul (sw dequant)", color="#F44336", linewidth=2, markersize=8)
    ax2.set_xlabel("Matrix Size (M=K=N)", fontsize=12)
    ax2.set_ylabel("Throughput (TFLOP/s)", fontsize=12)
    ax2.set_title("GEMM Throughput on B200 (higher is better)", fontsize=14)
    ax2.legend(fontsize=9)
    ax2.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig("/tmp/dot_scaled_throughput.png", dpi=150)
    print("Saved /tmp/dot_scaled_throughput.png")
    plt.close()
    # Chart 3: Speedup over nvfp4_matmul
    fig3, ax3 = plt.subplots(figsize=(14, 6))

    def speedup(baseline, target):
        # baseline/target ratio; 0 whenever either value is NaN or non-positive.
        return [b / t if t > 0 and t == t and b > 0 and b == b else 0 for t, b in zip(target, baseline)]

    sp_torch = speedup(nvfp4, torch_fp16)
    sp_smm = speedup(nvfp4, torch_smm)
    sp_ds_e4m3 = speedup(nvfp4, ds_e4m3)
    w2 = 0.25
    ax3.bar(x - w2, sp_torch, w2, label="torch.matmul (fp16)", color="#FF9800")
    ax3.bar(x, sp_smm, w2, label="torch._scaled_mm (e4m3)", color="#9C27B0")
    ax3.bar(x + w2, sp_ds_e4m3, w2, label="hl.dot_scaled (e4m3)", color="#4CAF50")
    ax3.axhline(y=1.0, color="red", linestyle="--", alpha=0.7, label="nvfp4_matmul parity")
    ax3.set_xlabel("Matrix Size (M=K=N)", fontsize=12)
    ax3.set_ylabel("Speedup over nvfp4_matmul (higher is better)", fontsize=12)
    ax3.set_title("Speedup over nvfp4_matmul (software FP4 dequant) on B200", fontsize=14)
    ax3.set_xticks(x)
    ax3.set_xticklabels(labels)
    ax3.legend(fontsize=9)
    ax3.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig("/tmp/dot_scaled_speedup.png", dpi=150)
    print("Saved /tmp/dot_scaled_speedup.png")
    plt.close()
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment