Last active
September 23, 2025 22:24
-
-
Save davidberard98/b97e834e36fa9ee49a016a38aee3f182 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import multiprocessing | |
| import os | |
| from time import sleep | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| def get_num_bytes(*args): | |
| num_bytes = sum( | |
| (x.numel() * x.element_size() for x in args if isinstance(x, torch.Tensor)) | |
| ) | |
| return num_bytes | |
@triton.autotune(
    [triton.Config({"ROW_BLOCK_SIZE": 8}, num_warps=2, num_stages=3)], key=[]
)
@triton.jit
def kernel_layernorm_2d(
    X,  # pointer to input; rows separated by `stride` elements
    Y,  # pointer to output; same layout as X
    stride,  # elements between consecutive rows
    M,  # number of rows
    N,  # number of columns (must satisfy N <= BLOCK_SIZE)
    eps,  # numerical-stability term added to the variance
    ROW_BLOCK_SIZE: tl.constexpr,  # rows processed per program (autotuned)
    BLOCK_SIZE: tl.constexpr,  # power-of-two tile spanning all N columns
):
    # Row-wise layer norm of a 2D tensor; each program normalizes
    # ROW_BLOCK_SIZE rows at once. No weight/bias (gamma/beta) is applied.
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    X = X + row * stride
    Y = Y + row * stride
    mask_row = row < M
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Load a (ROW_BLOCK_SIZE, BLOCK_SIZE) tile in fp32; masked lanes read 0.0.
    x = tl.load(
        X[:, None] + cols[None, :], mask=mask_row[:, None] & mask[None, :], other=0.0
    ).to(tl.float32)
    # Masked lanes hold 0.0 and the divisor is N, so the mean is exact.
    mean = tl.sum(x, axis=1) / N
    # BUG FIX: when N < BLOCK_SIZE the padded lanes would each contribute
    # (0 - mean)^2 to the variance sum (tl.sum reduces over all BLOCK_SIZE
    # lanes). Zero the padded diffs explicitly before squaring.
    diff = tl.where(mask[None, :], x - mean[:, None], 0.0)
    var = tl.sum(diff * diff, axis=1) / N
    rstd = 1 / tl.sqrt(var + eps)
    y = (x - mean[:, None]) * rstd[:, None]
    tl.store(Y[:, None] + cols[None, :], y, mask=mask_row[:, None] & mask[None, :])
def triton_layernorm_2d(x, eps, *, return_rstd=True, return_mean=True):
    # Layer-normalize each row of the 2D tensor `x` on the GPU via
    # kernel_layernorm_2d and return the normalized tensor (same shape/dtype).
    #
    # NOTE(review): despite the flag names, mean and rstd are never returned —
    # the kernel does not write them out and only `out` is produced. The
    # assert merely pins the flags to the one supported configuration.
    assert return_rstd and return_mean
    M, N = x.size()
    out = torch.empty_like(x)
    # One column tile must cover the whole row, so round N up to a power of 2.
    BLOCK_SIZE = triton.next_power_of_2(N)

    def grid(meta):
        # 1-D launch grid: one program per ROW_BLOCK_SIZE rows; the value is
        # chosen by the kernel's autotuner, hence the meta lookup.
        return (triton.cdiv(M, meta["ROW_BLOCK_SIZE"]),)

    kernel_layernorm_2d[grid](x, out, x.stride(0), M, N, eps, BLOCK_SIZE=BLOCK_SIZE)
    return out
def benchmark_under_load(fn, cache_clearer=False, warmup_reps=10000, timing_reps=10000):
    """Return the mean GPU latency of ``fn`` in milliseconds per call.

    ``fn`` is launched ``warmup_reps`` times untimed, then ``timing_reps``
    times bracketed by two CUDA events; the elapsed time between the events
    is averaged over the timed launches. ``cache_clearer`` must be falsy
    (clearing the L2 cache between launches is not implemented).
    """
    assert not cache_clearer  # not implemented
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    # Warm-up: absorb compilation/autotuning and steady-state clocks.
    for _ in range(warmup_reps):
        fn()
    start.record()
    for _ in range(timing_reps):
        fn()
    stop.record()
    # elapsed_time is only valid once both events have completed on-device.
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / timing_reps
def main():
    """Benchmark the Triton 2D layernorm on a large bf16 tensor and print
    the mean latency and achieved memory bandwidth."""
    M, K = 2**20, 512
    # Uniform values in [-50000, 50000) to exercise the fp32 accumulation.
    x = (torch.rand(M, K, device="cuda") - 0.5) * 2 * 50000
    x = x.to(torch.bfloat16)
    eps = 1e-3

    def fn():
        return triton_layernorm_2d(x, eps)

    ms = benchmark_under_load(fn)
    # FIX: compute the output once and reuse it; the original called fn()
    # inside gbps(), launching an extra kernel just to size the result.
    out = fn()

    def gbps(ms):
        # Bytes moved per call (read x + write out), over time, in GB/s.
        return get_num_bytes(x, out) / ms * 1e-6

    print(
        f"Perf: {ms:.3f} ms ({gbps(ms)} GB/s)"
    )
    print()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment