@namgyu-youn
Last active January 8, 2026 18:09
[TorchAO] GEMM perf: `torch._int_mm` vs. Triton
import torch
import torchao.kernel.int8_scaled_mm_triton  # noqa: F401
from torchao.kernel.intmm import int_scaled_matmul


def benchmark_in_ms(warmup, iters, f):
    """Benchmark a callable with CUDA events; return mean time per call in ms."""
    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        f()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def run_benchmark(m, k, n, warmup=10, iters=100):
    """Run benchmark for shape [m, k] @ [k, n]."""
    print(f"\n[{m}, {k}] @ [{k}, {n}]")

    # Prepare inputs: symmetric int8 quantization with per-row scales for A
    # and per-column scales for B
    a_fp = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b_fp = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
    scale_a = a_fp.abs().max(dim=1, keepdim=True)[0] / 127.0
    scale_b = b_fp.abs().max(dim=0, keepdim=True)[0] / 127.0
    a_int8 = (a_fp / scale_a).round().clamp(-128, 127).to(torch.int8)
    b_int8 = (b_fp / scale_b).round().clamp(-128, 127).to(torch.int8)

    # FP16 baseline
    fp16_time = benchmark_in_ms(warmup, iters, lambda: torch.mm(a_fp, b_fp))
    print(f" FP16: {fp16_time:.3f}ms")

    # PyTorch int_scaled_matmul (torch._int_mm under the hood)
    @torch.compile(mode="max-autotune")
    def pytorch_kernel():
        return int_scaled_matmul(a_int8, b_int8, scale_a)

    pytorch_kernel()  # warmup / compile
    pytorch_time = benchmark_in_ms(warmup, iters, pytorch_kernel)
    print(f" PyTorch: {pytorch_time:.3f}ms ({fp16_time / pytorch_time:.2f}x)")

    # Triton scaled_int8_mm
    triton_kernel_fn = torch.ops.torchao.scaled_int8_mm

    @torch.compile(mode="max-autotune")
    def triton_kernel():
        return triton_kernel_fn(a_int8, b_int8, scale_a, scale_b)

    triton_kernel()  # warmup / compile
    triton_time = benchmark_in_ms(warmup, iters, triton_kernel)
    print(f" Triton: {triton_time:.3f}ms ({fp16_time / triton_time:.2f}x)")
    print(f" -> Triton {pytorch_time / triton_time:.2f}x faster")


def main():
    scenarios = [
        (32, 4096, 4096),
        (4096, 4096, 4096),
        (16384, 8192, 5120),
    ]
    for m, k, n in scenarios:
        run_benchmark(m, k, n)


if __name__ == "__main__":
    main()
namgyu-youn commented Dec 31, 2025:

Result on A100:

[32, 4096] @ [4096, 4096]
  FP16:     0.174ms
  PyTorch:  0.751ms   (0.23x)
  Triton:   0.292ms   (0.59x)
  -> Triton 2.57x faster

[4096, 4096] @ [4096, 4096]
  FP16:     3.763ms
  PyTorch:  14.969ms  (0.25x)
  Triton:   4.098ms   (0.92x)
  -> Triton 3.65x faster

[16384, 8192] @ [8192, 5120]
  FP16:     36.921ms
  PyTorch:  143.929ms  (0.26x)
  Triton:   37.194ms   (0.99x)
  -> Triton 3.87x faster
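
Both int8 paths should reproduce the bf16 GEMM up to quantization error. A minimal sanity-check sketch to go along with the timings, assuming `int_scaled_matmul` applies only the per-row `scale_a` (so the per-column `scale_b` is applied afterwards to fully dequantize), and passing the scales to the Triton op in the same shapes as in the benchmark:

import torch
import torchao.kernel.int8_scaled_mm_triton  # noqa: F401
from torchao.kernel.intmm import int_scaled_matmul

# Sanity-check sketch (assumption: int_scaled_matmul applies only scale_a,
# so scale_b is applied afterwards; the Triton op takes both scales).
m, k, n = 32, 4096, 4096
a_fp = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b_fp = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
scale_a = a_fp.abs().max(dim=1, keepdim=True)[0] / 127.0  # per-row, [m, 1]
scale_b = b_fp.abs().max(dim=0, keepdim=True)[0] / 127.0  # per-column, [1, n]
a_int8 = (a_fp / scale_a).round().clamp(-128, 127).to(torch.int8)
b_int8 = (b_fp / scale_b).round().clamp(-128, 127).to(torch.int8)

ref = torch.mm(a_fp, b_fp).float()
out_pytorch = (int_scaled_matmul(a_int8, b_int8, scale_a) * scale_b).float()
out_triton = torch.ops.torchao.scaled_int8_mm(a_int8, b_int8, scale_a, scale_b).float()

print("PyTorch path max abs error:", (out_pytorch - ref).abs().max().item())
print("Triton path max abs error: ", (out_triton - ref).abs().max().item())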
