import time
import torch
import torch.utils.benchmark as benchmark
import matplotlib.pyplot as plt
import numpy as np
# patch of https://github.com/triton-lang/triton/pull/4493
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
:type quantiles: list[float], optional
:param fast_flush: Use faster kernel to flush L2 cache between measurements
:type fast_flush: bool, default is True
:param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str
"""
    assert return_mode in ["min", "max", "mean", "median", "all"]

    fn()
    torch.cuda.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2 cache
    # doesn't contain any input data before the run
    cache_size = 256 * 1024 * 1024
    if fast_flush:
        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5
    # compute the number of warmup iterations; `rep` is used directly as the repeat count
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = rep
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
    return times.tolist()
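
# Example usage (sketch; the 4096-square shapes below are illustrative, not part of the
# benchmark driver further down). do_bench can time any CUDA callable on its own and
# returns a list of per-iteration latencies in milliseconds:
#
#   x = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
#   times_ms = do_bench(lambda: x @ x, warmup=25, rep=100)  # 100 latencies, in ms
#   print(min(times_ms), sum(times_ms) / len(times_ms))
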
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, nFLOPS_matmul=0, window_size=1000, **kwinputs):
    latencies = do_bench(lambda: fn(*inputs, **kwinputs), warmup=0, rep=repeats, fast_flush=True)
    print(f'{desc}: {len(latencies)} timed iterations')
    # latencies are in ms: TFLOP/s = FLOPs / (latency_ms * 1e-3) / 1e12 = FLOPs / latency_ms * 1e-9
    tflops = []
    for latency in latencies:
        tflops.append(nFLOPS_matmul / latency * 1e-9)
    # Repeat values
    repeat_indices = list(range(1, len(tflops) + 1))
    # Calculate moving window average
    moving_avg = np.convolve(tflops, np.ones(window_size) / window_size, mode='valid')
    # Plotting the data
    plt.figure(figsize=(50, 20))
    plt.plot(repeat_indices, tflops, marker='o', linestyle='-', color='b', label='TFLOP/s')
    # Plotting the moving average (aligned to start at iteration `window_size`)
    plt.plot(repeat_indices[window_size - 1:], moving_avg, color='r', linestyle='--', label=f'{window_size}-iter Moving Avg')
    plt.xlabel('Iteration')
    plt.ylabel('TFLOP/s')
    plt.title('TFLOP/s vs Iteration with Moving Average')
    plt.grid(True)
    plt.legend()
    plt.savefig('tflops_with_moving_avg.png')
    # save raw data as csv with header "iter, tflops"
    with open('tflops.csv', 'w') as f:
        f.write('iter, tflops\n')
        for i, tflop in enumerate(tflops):
            f.write(f'{i+1}, {tflop}\n')
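
# Rough sanity check on the numbers below (sketch, using an assumed example latency):
# with m = n = k = 8192, nFLOPS_matmul = 2 * 8192**3 ≈ 1.10e12 FLOPs, so a measured
# latency of 1.4 ms would correspond to about 1.10e12 / 1.4e-3 / 1e12 ≈ 785 TFLOP/s.
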
torch.manual_seed(0)
repeats = 2000
dtype = torch.bfloat16
device = 'cuda'
verbose = False
m, n = 8192, 8192
tflops_matmul = {}
tflops_matmul1 = {}
for k in [8192]:
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    c = torch.empty(m, n, device=device, dtype=dtype)
    nFLOPS_matmul = 2 * m * n * k
    time.sleep(2)  # to reduce power throttling
    benchmark_forward(torch.matmul, a, b, desc='torch.matmul', verbose=verbose, repeats=repeats, nFLOPS_matmul=nFLOPS_matmul, out=c)
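
# Each run writes two artifacts: 'tflops_with_moving_avg.png' and 'tflops.csv'.
# To sweep additional GEMM sizes, one could extend the list above, e.g.
# `for k in [4096, 8192, 16384]:` (illustrative values); note that every size would
# currently overwrite the same two output files.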