Last active
August 8, 2024 19:05
-
-
Save functionstackx/7a6583953ec340ba13a7a95b3c9a3503 to your computer and use it in GitHub Desktop.
matmul sweep
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Shell equivalents of the os.environ presets below:
#DISABLE_ADDMM_HIP_LT=0
#PYTORCH_TUNABLEOP_ENABLED=1
# set env flag — must be done BEFORE `import torch`, so torch picks the
# values up when it initializes.
import os
os.environ["DISABLE_ADDMM_HIP_LT"] = "0"  # ROCm addmm/hipBLASLt path switch — NOTE(review): confirm exact semantics
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"  # enable PyTorch TunableOp GEMM autotuning
import time
import torch
import torch.utils.benchmark as benchmark
from triton.testing import do_bench
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Time the forward pass of an arbitrary callable with torch.utils.benchmark.

    Runs ``fn(*inputs, **kwinputs)`` ``repeats`` times under a
    ``benchmark.Timer`` and returns ``(timer, measurement)``.
    """
    if verbose:
        print(desc, '- Forward pass')
    # Hand the callable and its arguments to the Timer via its globals dict;
    # the stmt string re-applies them on every timed run.
    timer_globals = {'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs}
    timer = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
# Matmul TFLOPS sweep: time bf16 GEMMs on CUDA with both
# torch.utils.benchmark and triton's do_bench, and print achieved TFLOPS.
torch.manual_seed(0)
dtype = torch.bfloat16
device = 'cuda'
verbose = False
# (m, n, k) GEMM problem sizes to sweep.
test_case = [(4352, 3840, 13568), (4352, 13568, 3840),
             (6144, 17920, 2816), (6144, 2816, 17920),
             (8192, 8192, 8192)
             ]
for repeats in [30, 100, 1000]:
    for distribution in ["zero", "randn"]:
        for m, n, k in test_case:
            # b is built as (n, k) then transposed so the second operand is
            # a non-contiguous (k, n) view, as in the original benchmark.
            if distribution == "zero":
                a = torch.zeros(m, k, device=device, dtype=dtype)
                b = torch.zeros(n, k, device=device, dtype=dtype).t()
            elif distribution == "randn":
                a = torch.randn(m, k, device=device, dtype=dtype)
                b = torch.randn(n, k, device=device, dtype=dtype).t()
            # Fix: removed unused `C = torch.empty(m, n, ...)` — it was never
            # written to and only wasted GPU memory each iteration.
            nFLOPS_matmul = 2 * m * n * k  # FLOPs of one (m,k)x(k,n) matmul
            time.sleep(3)  # to reduce power throttling
            timing = benchmark_forward(torch.matmul, a, b, desc='', verbose=verbose, repeats=repeats)[1]
            tflops_matmul = nFLOPS_matmul / timing.mean * 1e-12
            print(f'[torch.utils.benchmark] torch.matmul, {repeats = }, {distribution = }, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul:.1f} TFLOPS')
            time.sleep(3)  # to reduce power throttling
            # NOTE(review): triton's do_bench interprets `rep` as milliseconds
            # of total benchmarking time, not an iteration count — confirm the
            # author intended repeats-as-ms here.
            ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
            tflops_matmul1 = nFLOPS_matmul / ms * 1e-9  # ms -> TFLOPS conversion
            print(f'[triton.test.do_bench] torch.matmul, {repeats = }, {distribution = }, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1:.1f} TFLOPS')
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
8192x8192x8192 bf16 randn distribution
H100
MI300X