Skip to content

Instantly share code, notes, and snippets.

@functionstackx
Last active August 8, 2024 19:05
Show Gist options
  • Select an option

  • Save functionstackx/7a6583953ec340ba13a7a95b3c9a3503 to your computer and use it in GitHub Desktop.

Select an option

Save functionstackx/7a6583953ec340ba13a7a95b3c9a3503 to your computer and use it in GitHub Desktop.
matmul sweep
#DISABLE_ADDMM_HIP_LT=0
#PYTORCH_TUNABLEOP_ENABLED=1
# set env flag
# NOTE(review): both flags are written to the process environment before
# `import torch` runs further down; presumably this ordering is deliberate so
# the values are already present when PyTorch initializes — confirm against
# the PyTorch/ROCm documentation before moving these lines.
import os
os.environ["DISABLE_ADDMM_HIP_LT"] = "0"
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
import time
import torch
import torch.utils.benchmark as benchmark
from triton.testing import do_bench
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Time the forward pass of an arbitrary callable with torch.utils.benchmark.

    Args:
        fn: Callable to benchmark; it is invoked as ``fn(*inputs, **kwinputs)``.
        *inputs: Positional arguments forwarded to ``fn``.
        repeats: Value passed to ``Timer.timeit``, i.e. how many times the
            statement is executed for the measurement.
        desc: Label printed before the measurement when ``verbose`` is true.
        verbose: When true, print the label and the resulting measurement.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        A ``(timer, measurement)`` tuple: the ``benchmark.Timer`` that was
        built and the object returned by its ``timeit`` call.
    """
    if verbose:
        print(desc, '- Forward pass')
    # The statement is evaluated inside the Timer's own namespace, so the
    # callable and its arguments are handed over through `globals`.
    timer = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
# Sweep configuration: seed, operand dtype/device, and the (m, n, k) shapes.
torch.manual_seed(0)
dtype = torch.bfloat16
device = 'cuda'
verbose = False
# Each entry is (m, n, k): `a` is (m, k) and `b` is a (k, n) view.
test_case = [
    (4352, 3840, 13568), (4352, 13568, 3840),
    (6144, 17920, 2816), (6144, 2816, 17920),
    (8192, 8192, 8192),
]

for repeats in [30, 100, 1000]:
    for distribution in ["zero", "randn"]:
        for m, n, k in test_case:
            # `b` is allocated as (n, k) and transposed, so matmul receives a
            # (k, n) view of that storage rather than a fresh (k, n) tensor.
            if distribution == "zero":
                a = torch.zeros(m, k, device=device, dtype=dtype)
                b = torch.zeros(n, k, device=device, dtype=dtype).t()
            elif distribution == "randn":
                a = torch.randn(m, k, device=device, dtype=dtype)
                b = torch.randn(n, k, device=device, dtype=dtype).t()
            else:
                # Fail loudly instead of silently reusing the previous operands.
                raise ValueError(f"unknown distribution: {distribution!r}")

            # One multiply and one add per (m, n, k) triple.
            nFLOPS_matmul = 2 * m * n * k

            time.sleep(3) # to reduce power throttling
            timing = benchmark_forward(torch.matmul, a, b, desc='', verbose=verbose, repeats=repeats)[1]
            # timing.mean is in seconds: FLOPs / s * 1e-12 -> TFLOPS.
            tflops_matmul = nFLOPS_matmul / timing.mean * 1e-12
            print(f'[torch.utils.benchmark] torch.matmul, {repeats = }, {distribution = }, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul:.1f} TFLOPS')

            time.sleep(3) # to reduce power throttling
            # NOTE(review): triton documents do_bench's `warmup` and `rep` as
            # time budgets in milliseconds rather than iteration counts, so
            # `rep=repeats` is roughly "measure for `repeats` ms" — confirm
            # against the installed triton version. Kept as-is so results stay
            # comparable with previously published numbers.
            ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
            # ms is in milliseconds: FLOPs / ms * 1e-9 -> TFLOPS.
            tflops_do_bench = nFLOPS_matmul / ms * 1e-9
            print(f'[triton.test.do_bench] torch.matmul, {repeats = }, {distribution = }, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_do_bench:.1f} TFLOPS')
@functionstackx
Copy link
Author

functionstackx commented Aug 8, 2024

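Summary: 8192x8192x8192 bf16 matmul, randn distribution. Values are TFLOPS taken from the [triton.test.do_bench] rows in the logs below (truncated to whole numbers).

| Repeats | 30  | 100 | 1000 |
|---------|-----|-----|------|
| H100    | 760 | 731 | 679  |
| MI300X  | 593 | 592 | 589  |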

H100

[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 4352, n = 3840, k = 13568: 0.501ms, 904.5 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 4352, n = 3840, k = 13568: 0.502ms, 903.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 4352, n = 13568, k = 3840: 0.547ms, 828.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 4352, n = 13568, k = 3840: 0.518ms, 875.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 6144, n = 17920, k = 2816: 0.784ms, 790.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 6144, n = 17920, k = 2816: 0.765ms, 810.7 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 6144, n = 2816, k = 17920: 0.709ms, 874.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 6144, n = 2816, k = 17920: 0.657ms, 943.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 8192, n = 8192, k = 8192: 1.310ms, 839.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 8192, n = 8192, k = 8192: 1.300ms, 845.9 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 4352, n = 3840, k = 13568: 0.577ms, 786.5 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 4352, n = 3840, k = 13568: 0.577ms, 785.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 4352, n = 13568, k = 3840: 0.595ms, 762.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 4352, n = 13568, k = 3840: 0.600ms, 756.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 6144, n = 17920, k = 2816: 0.843ms, 735.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 6144, n = 17920, k = 2816: 0.860ms, 720.9 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 6144, n = 2816, k = 17920: 0.764ms, 811.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 6144, n = 2816, k = 17920: 0.772ms, 803.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 8192, n = 8192, k = 8192: 1.455ms, 755.6 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 8192, n = 8192, k = 8192: 1.446ms, 760.1 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 4352, n = 3840, k = 13568: 0.512ms, 885.1 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 4352, n = 3840, k = 13568: 0.510ms, 889.3 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 4352, n = 13568, k = 3840: 0.524ms, 866.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 4352, n = 13568, k = 3840: 0.514ms, 882.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 6144, n = 17920, k = 2816: 0.780ms, 795.0 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 6144, n = 17920, k = 2816: 0.766ms, 809.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 6144, n = 2816, k = 17920: 0.690ms, 898.5 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 6144, n = 2816, k = 17920: 0.672ms, 923.1 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 8192, n = 8192, k = 8192: 1.323ms, 831.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 8192, n = 8192, k = 8192: 1.309ms, 840.0 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 4352, n = 3840, k = 13568: 0.596ms, 761.0 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 4352, n = 3840, k = 13568: 0.609ms, 744.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 4352, n = 13568, k = 3840: 0.619ms, 732.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 4352, n = 13568, k = 3840: 0.619ms, 732.6 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 6144, n = 17920, k = 2816: 0.902ms, 687.3 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 6144, n = 17920, k = 2816: 0.902ms, 687.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 6144, n = 2816, k = 17920: 0.800ms, 774.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 6144, n = 2816, k = 17920: 0.813ms, 762.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 8192, n = 8192, k = 8192: 1.485ms, 740.4 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 8192, n = 8192, k = 8192: 1.503ms, 731.3 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 4352, n = 3840, k = 13568: 0.520ms, 872.0 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 4352, n = 3840, k = 13568: 0.496ms, 914.0 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 4352, n = 13568, k = 3840: 0.529ms, 857.6 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 4352, n = 13568, k = 3840: 0.513ms, 884.1 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 6144, n = 17920, k = 2816: 0.776ms, 799.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 6144, n = 17920, k = 2816: 0.760ms, 816.0 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 6144, n = 2816, k = 17920: 0.689ms, 900.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 6144, n = 2816, k = 17920: 0.660ms, 940.2 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 8192, n = 8192, k = 8192: 1.292ms, 850.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 8192, n = 8192, k = 8192: 1.323ms, 831.1 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 4352, n = 3840, k = 13568: 0.634ms, 715.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 4352, n = 3840, k = 13568: 0.662ms, 685.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 4352, n = 13568, k = 3840: 0.655ms, 692.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 4352, n = 13568, k = 3840: 0.683ms, 663.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 6144, n = 17920, k = 2816: 0.946ms, 655.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 6144, n = 17920, k = 2816: 0.959ms, 646.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 6144, n = 2816, k = 17920: 0.836ms, 741.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 6144, n = 2816, k = 17920: 0.876ms, 707.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 8192, n = 8192, k = 8192: 1.629ms, 674.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 8192, n = 8192, k = 8192: 1.618ms, 679.5 TFLOPS

MI300X

[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 4352, n = 3840, k = 13568: 0.640ms, 708.1 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 4352, n = 3840, k = 13568: 0.848ms, 534.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 4352, n = 13568, k = 3840: 0.676ms, 670.4 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 4352, n = 13568, k = 3840: 0.708ms, 640.6 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 6144, n = 17920, k = 2816: 0.944ms, 656.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 6144, n = 17920, k = 2816: 0.962ms, 644.8 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 6144, n = 2816, k = 17920: 1.037ms, 597.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 6144, n = 2816, k = 17920: 1.111ms, 558.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'zero',  m = 8192, n = 8192, k = 8192: 1.377ms, 798.4 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'zero', m = 8192, n = 8192, k = 8192: 1.424ms, 772.2 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 4352, n = 3840, k = 13568: 0.838ms, 540.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 4352, n = 3840, k = 13568: 1.075ms, 421.9 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 4352, n = 13568, k = 3840: 0.885ms, 512.6 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 4352, n = 13568, k = 3840: 0.930ms, 487.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 6144, n = 17920, k = 2816: 1.222ms, 507.3 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 6144, n = 17920, k = 2816: 1.274ms, 486.6 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 6144, n = 2816, k = 17920: 1.288ms, 481.3 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 6144, n = 2816, k = 17920: 1.396ms, 444.3 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 30, distribution = 'randn',  m = 8192, n = 8192, k = 8192: 1.797ms, 611.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 30,  distribution = 'randn', m = 8192, n = 8192, k = 8192: 1.851ms, 593.9 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 4352, n = 3840, k = 13568: 0.624ms, 726.6 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 4352, n = 3840, k = 13568: 0.847ms, 535.1 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 4352, n = 13568, k = 3840: 0.651ms, 696.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 4352, n = 13568, k = 3840: 0.708ms, 640.3 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 6144, n = 17920, k = 2816: 0.902ms, 687.3 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 6144, n = 17920, k = 2816: 0.959ms, 646.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 6144, n = 2816, k = 17920: 0.997ms, 622.1 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 6144, n = 2816, k = 17920: 1.110ms, 558.7 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'zero',  m = 8192, n = 8192, k = 8192: 1.342ms, 819.1 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'zero', m = 8192, n = 8192, k = 8192: 1.422ms, 773.2 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 4352, n = 3840, k = 13568: 0.827ms, 548.3 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 4352, n = 3840, k = 13568: 1.076ms, 421.6 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 4352, n = 13568, k = 3840: 0.875ms, 518.5 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 4352, n = 13568, k = 3840: 0.933ms, 486.1 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 6144, n = 17920, k = 2816: 1.212ms, 511.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 6144, n = 17920, k = 2816: 1.270ms, 488.2 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 6144, n = 2816, k = 17920: 1.263ms, 491.1 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 6144, n = 2816, k = 17920: 1.406ms, 441.0 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 100, distribution = 'randn',  m = 8192, n = 8192, k = 8192: 1.786ms, 615.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 100,  distribution = 'randn', m = 8192, n = 8192, k = 8192: 1.855ms, 592.9 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 4352, n = 3840, k = 13568: 0.618ms, 733.3 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 4352, n = 3840, k = 13568: 0.848ms, 535.0 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 4352, n = 13568, k = 3840: 0.645ms, 702.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 4352, n = 13568, k = 3840: 0.708ms, 640.3 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 6144, n = 17920, k = 2816: 0.895ms, 693.0 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 6144, n = 17920, k = 2816: 0.962ms, 644.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 6144, n = 2816, k = 17920: 0.983ms, 630.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 6144, n = 2816, k = 17920: 1.110ms, 558.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'zero',  m = 8192, n = 8192, k = 8192: 1.339ms, 820.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'zero', m = 8192, n = 8192, k = 8192: 1.433ms, 767.2 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 4352, n = 3840, k = 13568: 0.825ms, 549.8 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 4352, n = 3840, k = 13568: 1.066ms, 425.2 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 4352, n = 13568, k = 3840: 0.874ms, 519.0 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 4352, n = 13568, k = 3840: 0.899ms, 504.4 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 6144, n = 17920, k = 2816: 1.214ms, 510.7 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 6144, n = 17920, k = 2816: 1.226ms, 505.9 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 6144, n = 2816, k = 17920: 1.257ms, 493.2 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 6144, n = 2816, k = 17920: 1.404ms, 441.5 TFLOPS
[torch.utils.benchmark] torch.matmul, repeats = 1000, distribution = 'randn',  m = 8192, n = 8192, k = 8192: 1.794ms, 612.9 TFLOPS
[triton.test.do_bench] torch.matmul, repeats = 1000,  distribution = 'randn', m = 8192, n = 8192, k = 8192: 1.866ms, 589.3 TFLOPS

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment