import time
import argparse
from typing import Callable, Dict, Any, Optional, Tuple, List

import torch
import triton
import triton.language as tl

# -------------------------
# Environment helpers
# -------------------------
def env_info() -> str:
    try:
        dev = torch.cuda.get_device_name(0)
        cap = torch.cuda.get_device_capability(0)
        cuda_ver = getattr(torch.version, 'cuda', 'unknown')
    except Exception:
        dev, cap, cuda_ver = "<unknown>", (0, 0), 'unknown'
    return (
        f"Torch {torch.__version__}, CUDA {cuda_ver}, "
        f"Triton {triton.__version__}, GPU {dev}, CC {cap}"
    )

# -------------------------
# Kernels
# -------------------------
@triton.jit
def u64_or(a, b):
    return a | b
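
# Baseline: each program loads a (BLOCK_M, BLOCK_N, BLOCK_K) tile and folds it
# along axis 2 with tl.reduce. The combine function must itself be @triton.jit
# compiled. Bitwise OR is associative and commutative, so whatever reduction
# order the compiler picks cannot change the result.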
@triton.jit
def standard_reduce_kernel(
    input_ptr, output_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k = tl.arange(0, BLOCK_K)
    offs = (m[:, None, None] * N * K +
            n[None, :, None] * K +
            k[None, None, :])
    mask3 = ((m[:, None, None] < M) &
             (n[None, :, None] < N) &
             (k[None, None, :] < K))
    tile = tl.load(input_ptr + offs, mask=mask3, other=0)
    out = tl.reduce(tile, axis=2, combine_fn=u64_or)
    out_offs = m[:, None] * N + n[None, :]
    out_mask = (m[:, None] < M) & (n[None, :] < N)
    tl.store(output_ptr + out_offs, out, mask=out_mask)
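
# Control kernel: materializes a zero tile with the same shape as the
# reduction tile and sums it, but issues no global loads. This isolates the
# cost of simply holding the 3D block in registers/shared memory.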
@triton.jit
def allocate_only_sram_kernel(
    input_ptr, output_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Allocate a local 3D tile to mirror the reduction tile shape
    tmp = tl.zeros((BLOCK_M, BLOCK_N, BLOCK_K), dtype=tl.int64)
    # Consume to avoid DCE; result remains zero
    res = tl.sum(tmp, axis=2)
    out_offs = m[:, None] * N + n[None, :]
    out_mask = (m[:, None] < M) & (n[None, :] < N)
    tl.store(output_ptr + out_offs, res, mask=out_mask)
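
# Control kernel: issues the same 3D global loads as the baseline but skips
# the reduction, writing out only the k=0 plane. Caveat (an assumption, not
# verified against the generated PTX): `tile` is never consumed, so the
# compiler may dead-code-eliminate the 3D loads; treat this timing as a lower
# bound on preload cost rather than an exact measurement.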
@triton.jit
def preload_no_reduce_kernel(
    input_ptr, output_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k = tl.arange(0, BLOCK_K)
    offs = (m[:, None, None] * N * K +
            n[None, :, None] * K +
            k[None, None, :])
    mask3 = ((m[:, None, None] < M) &
             (n[None, :, None] < N) &
             (k[None, None, :] < K))
    tile = tl.load(input_ptr + offs, mask=mask3, other=0)
    # Produce output from k=0 via direct load (avoid local slicing issues)
    m2 = m[:, None]
    n2 = n[None, :]
    mask2 = (m2 < M) & (n2 < N)
    offs2 = m2 * N * K + n2 * K + 0
    res = tl.load(input_ptr + offs2, mask=mask2, other=0)
    out_offs = m2 * N + n2
    tl.store(output_ptr + out_offs, res, mask=mask2)
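
# Code generation: for a given K, emit a kernel whose body is K explicit,
# OR-chained scalar-offset global loads (no tl.reduce, no 3D tile). The
# source is written to a real file because @triton.jit retrieves kernel
# source via inspect, which fails for code exec'd from an in-memory string.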
def build_manual_unroll_kernel(K: int) -> Callable:
    """Dynamically build a manually unrolled kernel (global loads) for a given K."""
    if K < 1:
        raise ValueError("K must be >= 1")
    lines = ["res = tl.load(input_ptr + input_base + 0, mask=mask, other=0)"]
    for i in range(1, K):
        lines.append(
            f"res = res | tl.load(input_ptr + input_base + {i}, mask=mask, other=0)"
        )
    body = "    " + "\n    ".join(lines)
    src = f"""
@triton.jit
def manual_unroll_k{K}_kernel(
    input_ptr, output_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    m2 = m[:, None]
    n2 = n[None, :]
    mask = (m2 < M) & (n2 < N)
    input_base = m2 * N * K + n2 * K
    # Manual unroll (global loads)
{body}
    out_offs = m2 * N + n2
    tl.store(output_ptr + out_offs, res, mask=mask)
"""
    # Write to a temp file so Triton can retrieve source for JIT
    import importlib.util, sys, os
    mod_name = f"manual_unroll_k{K}_kernel_mod"
    file_path = os.path.abspath(f"_manual_unroll_k{K}_kernel.py")
    with open(file_path, "w") as f:
        f.write("import triton\nimport triton.language as tl\n\n")
        f.write(src)
    spec = importlib.util.spec_from_file_location(mod_name, file_path)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = mod
    assert spec.loader is not None
    spec.loader.exec_module(mod)
    return getattr(mod, f"manual_unroll_k{K}_kernel")

# -------------------------
# Wrappers
# -------------------------
def _grid(M: int, N: int, BM: int, BN: int) -> Tuple[int, int]:
    return (triton.cdiv(M, BM), triton.cdiv(N, BN))


def run_standard_reduce(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
    M, N, K = x.shape
    out = torch.empty((M, N), dtype=x.dtype, device=x.device)
    standard_reduce_kernel[_grid(M, N, BM, BN)](
        x, out, M, N, K,
        BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
        num_warps=num_warps, num_stages=num_stages,
    )
    return out
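
# Note: every wrapper below launches with BLOCK_K=K, so each program covers
# the full K extent in a single tile with no K-loop. Since tl.arange(0, BLOCK_K)
# requires a power-of-two extent, this assumes K is a power of two (true for
# all shapes exercised in this script).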

_MANUAL_KERNEL_CACHE: Dict[int, Callable] = {}


def run_manual_unroll(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
    M, N, K = x.shape
    assert int(K) == K, "K must be an integer"
    k_int = int(K)
    kern = _MANUAL_KERNEL_CACHE.get(k_int)
    if kern is None:
        kern = build_manual_unroll_kernel(k_int)
        _MANUAL_KERNEL_CACHE[k_int] = kern
    out = torch.empty((M, N), dtype=x.dtype, device=x.device)
    kern[_grid(M, N, BM, BN)](
        x, out, M, N, K,
        BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
        num_warps=num_warps, num_stages=num_stages,
    )
    return out

def run_allocate_only(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
    M, N, K = x.shape
    out = torch.empty((M, N), dtype=x.dtype, device=x.device)
    allocate_only_sram_kernel[_grid(M, N, BM, BN)](
        x, out, M, N, K,
        BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
        num_warps=num_warps, num_stages=num_stages,
    )
    return out


def run_preload_only(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
    M, N, K = x.shape
    out = torch.empty((M, N), dtype=x.dtype, device=x.device)
    preload_no_reduce_kernel[_grid(M, N, BM, BN)](
        x, out, M, N, K,
        BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
        num_warps=num_warps, num_stages=num_stages,
    )
    return out

# -------------------------
# Timing helpers
# -------------------------
def time_kernel(fn: Callable[[], None], runs: int = 300) -> float:
    torch.cuda.synchronize()
    # Warmup
    for _ in range(50):
        fn()
    torch.cuda.synchronize()
    # Measure
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) * 1000.0 / runs
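
# The wall-clock timing above includes host-side launch overhead amortized
# over `runs`. An alternative (not used here) is CUDA events, which time
# on-device activity only; a minimal sketch:
#
#   start = torch.cuda.Event(enable_timing=True)
#   end = torch.cuda.Event(enable_timing=True)
#   start.record()
#   for _ in range(runs):
#       fn()
#   end.record()
#   torch.cuda.synchronize()
#   ms_per_run = start.elapsed_time(end) / runs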

def format_table(headers: List[str], rows: List[List[Any]]) -> str:
    cols = list(zip(*([headers] + rows)))
    widths = [max(len(str(x)) for x in col) for col in cols]

    def fmt_row(r):
        return " | ".join(str(v).rjust(w) for v, w in zip(r, widths))

    lines = [fmt_row(headers), "-+-".join("-" * w for w in widths)]
    lines += [fmt_row(row) for row in rows]
    return "\n".join(lines)

# -------------------------
# Runners
# -------------------------
def run_repro(M=256, N=256, K=16, dtype=torch.int64) -> None:
    print("=" * 80)
    print("3D->2D Bitwise-OR Reduction (Repro)")
    print("=" * 80)
    print(env_info())
    print(f"Config: M={M}, N={N}, K={K}")
    torch.manual_seed(42)
    x = torch.randint(0, 2**16, (M, N, K), dtype=dtype, device='cuda')
    # Prepare wrappers
    std = lambda: run_standard_reduce(x)
    man = lambda: run_manual_unroll(x)
    alloc = lambda: run_allocate_only(x)
    pre = lambda: run_preload_only(x)
    # Run once to materialize kernels
    out_std = std()
    out_man = man()
    torch.cuda.synchronize()
    # Correctness
    if not torch.equal(out_std.cpu(), out_man.cpu()):
        raise RuntimeError("Mismatch between tl.reduce and manual unroll outputs")
    # Timings
    t_std = time_kernel(std)
    t_man = time_kernel(man)
    t_alloc = time_kernel(alloc)
    t_pre = time_kernel(pre)
    rows = [
        ["tl.reduce", f"{t_std:.5f}", "1.00x"],
        ["manual (unroll)", f"{t_man:.5f}", f"{t_std/t_man:.2f}x"],
        ["allocate-only", f"{t_alloc:.5f}", f"{t_std/t_alloc:.2f}x"],
        ["preload-no-reduce", f"{t_pre:.5f}", f"{t_std/t_pre:.2f}x"],
    ]
    print("\nSpeed Comparison (ms)")
    print(format_table(["Variant", "Time", "Speedup vs tl.reduce"], rows))

def run_bench(tests: Optional[List[Tuple[int, int, int]]] = None) -> None:
    if tests is None:
        tests = [
            (64, 128, 4), (128, 256, 8), (256, 512, 16),
            (512, 1024, 4), (1024, 2048, 8), (2048, 4096, 16),
        ]
    print("\n" + "=" * 80)
    print("Benchmark: tl.reduce vs manual unroll")
    print("=" * 80)
    rows = []
    for M, N, K in tests:
        torch.manual_seed(0)
        x = torch.randint(0, 2**16, (M, N, K), dtype=torch.int64, device='cuda')
        std = lambda: run_standard_reduce(x)
        man = lambda: run_manual_unroll(x)
        # Prime (compile kernels and warm caches) before timing
        std(); man(); torch.cuda.synchronize()
        elems = M * N * K
        runs = 100 if elems <= 50_000_000 else 10
        t_std = time_kernel(std, runs=runs)
        t_man = time_kernel(man, runs=runs)
        rows.append([M, N, K, f"{t_std:.3f}", f"{t_man:.3f}", f"{t_std/t_man:.2f}x"])
    print(format_table(["M", "N", "K", "tl.reduce (ms)", "manual (ms)", "speedup"], rows))
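
# Example invocations (the script name below is a placeholder for whatever
# this file is saved as):
#   python reduce_or_repro.py repro --M 512 --N 512 --K 16
#   python reduce_or_repro.py bench --shapes "2048,2048,16;4096,4096,16"
#   python reduce_or_repro.py test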
def main():
    p = argparse.ArgumentParser(description="3D->2D bitwise-OR reduction (Triton)")
    p.add_argument("mode", choices=["repro", "bench", "test"], nargs="?", default="repro")
    p.add_argument("--M", type=int, default=256)
    p.add_argument("--N", type=int, default=256)
    p.add_argument("--K", type=int, default=16)
    p.add_argument("--shapes", type=str, default="",
                   help="Semicolon-separated list of M,N,K triples, e.g. '2048,2048,16;4096,4096,16'")
    args = p.parse_args()
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is not available")
    if args.mode == "repro":
        run_repro(M=args.M, N=args.N, K=args.K)
    elif args.mode == "bench":
        tests = None
        if args.shapes:
            tests = []
            for part in args.shapes.split(";"):
                part = part.strip()
                if not part:
                    continue
                m, n, k = map(int, part.split(","))
                tests.append((m, n, k))
        run_bench(tests)
    else:
        # Comprehensive correctness tests
        print("=" * 80)
        print("Correctness Tests (tl.reduce vs manual unroll)")
        print("=" * 80)
        ok = True

        def cpu_ref(t: torch.Tensor) -> torch.Tensor:
            import numpy as np
            return torch.from_numpy(np.bitwise_or.reduce(t.cpu().numpy(), axis=-1)).to(t.device)

        # Basic shapes
        basic_shapes = [
            (32, 64, 4),
            (64, 128, 8),
            (128, 256, 16),
            (1, 1, 4),
            (7, 13, 8),
        ]
        for M, N, K in basic_shapes:
            torch.manual_seed(0)
            x = torch.randint(0, 2**16, (M, N, K), dtype=torch.int64, device='cuda')
            a = run_standard_reduce(x)
            b = run_manual_unroll(x)
            c = cpu_ref(x)
            same_ab = torch.equal(a.cpu(), b.cpu())
            same_ac = torch.equal(a.cpu(), c.cpu())
            print(f"  Basic M={M} N={N} K={K}: std==manual={same_ab}, std==cpu={same_ac}")
            ok &= same_ab and same_ac
        # Data types
        dtypes = [torch.int32, torch.int64, torch.uint8]
        for dt in dtypes:
            K = 8
            torch.manual_seed(1)
            high = 256 if dt is torch.uint8 else 2**16
            x = torch.randint(0, high, (32, 32, K), dtype=dt, device='cuda')
            a = run_standard_reduce(x)
            b = run_manual_unroll(x)
            c = cpu_ref(x)
            same_ab = torch.equal(a.cpu(), b.cpu())
            same_ac = torch.equal(a.cpu(), c.cpu())
            print(f"  DType {str(dt).split('.')[-1]}: std==manual={same_ab}, std==cpu={same_ac}")
            ok &= same_ab and same_ac
        # Bitwise properties
        K = 4
        torch.manual_seed(2)
        base = torch.randint(0, 2**8, (16, 16, 1), dtype=torch.int64, device='cuda')
        x = base.repeat(1, 1, K)
        a = run_standard_reduce(x)
        b = run_manual_unroll(x)
        expected = base.squeeze(-1)
        idem_ok = torch.equal(a.cpu(), expected.cpu()) and torch.equal(b.cpu(), expected.cpu())
        print(f"  Idempotence x|x=x: {idem_ok}")
        ok &= idem_ok
        torch.manual_seed(3)
        y = torch.randint(0, 2**8, (8, 8, K), dtype=torch.int64, device='cuda')
        # Shuffle along K per (i, j) element; OR is order-independent
        ys = y.clone()
        for i in range(y.shape[0]):
            for j in range(y.shape[1]):
                perm = torch.randperm(K)
                ys[i, j, :] = y[i, j, perm]
        ao = run_manual_unroll(y)
        bo = run_manual_unroll(ys)
        comm_ok = torch.equal(ao.cpu(), bo.cpu())
        print(f"  Commutativity (shuffle K): {comm_ok}")
        ok &= comm_ok
        print("\nRESULT:", "PASS" if ok else "FAIL")
        if not ok:
            raise SystemExit(1)

if __name__ == "__main__":
    main()