import time

import torch
import matplotlib.pyplot as plt
import numpy as np
# patch of https://github.com/triton-lang/triton/pull/4493
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
    """
    Benchmark the runtime of the provided function.

    Unlike the upstream Triton helper, this patched version ignores
    :code:`quantiles` and :code:`return_mode` and returns the raw list of
    per-iteration times (in ms), so the caller can compute whatever
    statistics it needs.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param rep: Number of timed repetitions
    :type rep: int
    :param grad_to_none: Reset the gradients of the provided tensors to None
    :type grad_to_none: list[torch.Tensor], optional
    :param quantiles: Performance percentiles to return in addition to the median (unused in this patch)
    :type quantiles: list[float], optional
    :param fast_flush: Use a faster kernel to flush the L2 cache between measurements
    :type fast_flush: bool, default is True
    :param return_mode: The statistical measure to return: "min", "max", "mean", "median", or "all" (unused in this patch)
    :type return_mode: str, default is "mean"
    """
    # kept for signature compatibility with the upstream do_bench
    assert return_mode in ["min", "max", "mean", "median", "all"]
    fn()  # one untimed call to trigger any lazy initialization / compilation
    torch.cuda.synchronize()
    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2 cache
    # doesn't contain any input data before the run
    cache_size = 256 * 1024 * 1024
    if fast_flush:
        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')
    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5
    # Derive the number of warmup iterations from the runtime estimate;
    # unlike upstream, this patch uses `rep` directly as the repetition count
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = rep
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass, so we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Wait for all recorded events to complete before reading the clocks
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
    return times.tolist()
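
# A minimal usage sketch (an illustration, not part of the original gist):
# since the patched do_bench returns one time per iteration, summary
# statistics are left to the caller. Commented out because `a` and `b` are
# only defined in the driver code further below.
#
#   times = torch.tensor(do_bench(lambda: torch.matmul(a, b), warmup=25, rep=100))
#   print(f'median {times.median():.3f} ms, '
#         f'p20/p80 {times.quantile(0.2):.3f}/{times.quantile(0.8):.3f} ms')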
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, nFLOPS_matmul=0, window_size=1000, **kwinputs):
    latencies = do_bench(lambda: fn(*inputs, **kwinputs), warmup=0, rep=repeats, fast_flush=True)
    print(f'{desc}: collected {len(latencies)} latency samples')
    # Convert each per-iteration latency to TFLOP/s: elapsed_time reports
    # milliseconds, so convert to seconds, then divide by 1e12 for tera
    tflops = []
    for latency in latencies:
        tflops.append(nFLOPS_matmul / (latency * 1e-3) / 1e12)
    # Iteration indices for the x-axis
    repeat_indices = list(range(1, len(tflops) + 1))
    # Calculate the moving window average (requires repeats >= window_size)
    moving_avg = np.convolve(tflops, np.ones(window_size) / window_size, mode='valid')
    # Plotting the data
    plt.figure(figsize=(50, 20))
    plt.plot(repeat_indices, tflops, marker='o', linestyle='-', color='b', label='TFLOP/s')
    # Plotting the moving average (aligned to start at iteration `window_size`)
    plt.plot(repeat_indices[window_size - 1:], moving_avg, color='r', linestyle='--', label=f'{window_size}-iter Moving Avg')
    plt.xlabel('Iteration')
    plt.ylabel('TFLOP/s')
    plt.title('TFLOP/s vs Iteration with Moving Average')
    plt.grid(True)
    plt.legend()
    plt.savefig('tflops_with_moving_avg.png')
    # Save raw data as CSV with header "iter, tflops"
    with open('tflops.csv', 'w') as f:
        f.write('iter, tflops\n')
        for i, tflop in enumerate(tflops):
            f.write(f'{i+1}, {tflop}\n')
torch.manual_seed(0)
repeats = 2000
dtype = torch.bfloat16
device = 'cuda'
verbose = False

m, n = 8192, 8192
for k in [8192]:
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    c = torch.empty(m, n, device=device, dtype=dtype)
    nFLOPS_matmul = 2 * m * n * k
    time.sleep(2)  # to reduce power throttling
    benchmark_forward(torch.matmul, a, b, desc='torch.matmul', verbose=verbose, repeats=repeats, nFLOPS_matmul=nFLOPS_matmul, out=c)
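
# Optional post-processing sketch (an addition, not from the original gist):
# once the run above has written tflops.csv, the per-iteration throughput can
# be reloaded for further analysis with numpy alone
data = np.genfromtxt('tflops.csv', delimiter=',', skip_header=1)
print(f'mean {data[:, 1].mean():.1f} TFLOP/s, median {np.median(data[:, 1]):.1f} TFLOP/s')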