[TorchAO] GEMM perf: `torch._int_mm` vs. Triton
import torch

import torchao.kernel.int8_scaled_mm_triton  # noqa: F401
from torchao.kernel.intmm import int_scaled_matmul


def benchmark_in_ms(warmup, iters, f):
    """Benchmark a callable with CUDA events; returns mean time per call in ms."""
    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        f()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def run_benchmark(m, k, n, warmup=10, iters=100):
    """Run the benchmark for the shape [m, k] @ [k, n]."""
    print(f"\n[{m}, {k}] @ [{k}, {n}]")

    # Prepare inputs: bf16 reference matrices plus int8-quantized copies with
    # per-row scales for A and per-column scales for B.
    a_fp = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b_fp = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
    scale_a = a_fp.abs().max(dim=1, keepdim=True)[0] / 127.0
    scale_b = b_fp.abs().max(dim=0, keepdim=True)[0] / 127.0
    a_int8 = (a_fp / scale_a).round().clamp(-128, 127).to(torch.int8)
    b_int8 = (b_fp / scale_b).round().clamp(-128, 127).to(torch.int8)

    # BF16 baseline
    bf16_time = benchmark_in_ms(warmup, iters, lambda: torch.mm(a_fp, b_fp))
    print(f"  BF16:    {bf16_time:.3f}ms")

    # torchao int_scaled_matmul (dispatches to torch._int_mm)
    @torch.compile(mode="max-autotune")
    def pytorch_kernel():
        return int_scaled_matmul(a_int8, b_int8, scale_a)

    pytorch_kernel()  # warmup / trigger compilation
    pytorch_time = benchmark_in_ms(warmup, iters, pytorch_kernel)
    print(f"  PyTorch: {pytorch_time:.3f}ms ({bf16_time / pytorch_time:.2f}x)")

    # Triton scaled_int8_mm
    triton_kernel_fn = torch.ops.torchao.scaled_int8_mm

    @torch.compile(mode="max-autotune")
    def triton_kernel():
        return triton_kernel_fn(a_int8, b_int8, scale_a, scale_b)

    triton_kernel()  # warmup / trigger compilation
    triton_time = benchmark_in_ms(warmup, iters, triton_kernel)
    print(f"  Triton:  {triton_time:.3f}ms ({bf16_time / triton_time:.2f}x)")
    print(f"  -> Triton {pytorch_time / triton_time:.2f}x faster")


def main():
    scenarios = [
        (32, 4096, 4096),
        (4096, 4096, 4096),
        (16384, 8192, 5120),
    ]
    for m, k, n in scenarios:
        run_benchmark(m, k, n)


if __name__ == "__main__":
    main()
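The script above only measures speed. A minimal sketch of a numerical sanity check follows: it reconstructs the bf16 result from the int8 product by rescaling with the per-row and per-column scales. It assumes `torch._int_mm` (a private PyTorch int8 x int8 -> int32 matmul) is available and that the chosen shape meets that kernel's alignment constraints; it is not part of the gist.

import torch

m, k, n = 64, 4096, 4096
a_fp = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b_fp = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)

scale_a = a_fp.abs().max(dim=1, keepdim=True)[0] / 127.0  # per-row scale, [m, 1]
scale_b = b_fp.abs().max(dim=0, keepdim=True)[0] / 127.0  # per-column scale, [1, n]
a_int8 = (a_fp / scale_a).round().clamp(-128, 127).to(torch.int8)
b_int8 = (b_fp / scale_b).round().clamp(-128, 127).to(torch.int8)

# int8 x int8 with int32 accumulation, then rescale back to float
out_int32 = torch._int_mm(a_int8, b_int8)
out_q = out_int32.to(torch.float32) * scale_a.float() * scale_b.float()

# Compare against the float reference; quantization error should be small
ref = a_fp.float() @ b_fp.float()
rel_err = (out_q - ref).abs().max() / ref.abs().max()
print(f"max relative error: {rel_err.item():.4f}")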
Result on A100: