@juvi21
Created September 5, 2025 08:51
import time
import math
import argparse
from typing import Callable, Dict, Any, Tuple, List, Optional
import torch
import triton
import triton.language as tl
# -------------------------
# Environment helpers
# -------------------------
def env_info() -> str:
try:
dev = torch.cuda.get_device_name(0)
cap = torch.cuda.get_device_capability(0)
cuda_ver = getattr(torch.version, 'cuda', 'unknown')
except Exception:
dev, cap, cuda_ver = "<unknown>", (0, 0), 'unknown'
return (
f"Torch {torch.__version__}, CUDA {cuda_ver}, "
f"Triton {triton.__version__}, GPU {dev}, CC {cap}"
)
# -------------------------
# Kernels
# -------------------------
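# Four kernel variants are compared:
#   standard_reduce_kernel    - loads a (BLOCK_M, BLOCK_N, BLOCK_K) tile and
#                               reduces over K with tl.reduce + bitwise OR.
#   manual_unroll_k{K}_kernel - generated at runtime (see
#                               build_manual_unroll_kernel); replaces the
#                               reduction with K explicit OR-ed global loads.
#   allocate_only_sram_kernel - allocates the 3D tile but never loads the
#                               input (isolates tile-allocation cost).
#   preload_no_reduce_kernel  - loads the 3D tile but skips the reduction
#                               (isolates load cost).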
@triton.jit
def u64_or(a, b):
return a | b
@triton.jit
def standard_reduce_kernel(
input_ptr, output_ptr,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
k = tl.arange(0, BLOCK_K)
offs = (m[:, None, None] * N * K +
n[None, :, None] * K +
k[None, None, :])
mask3 = ((m[:, None, None] < M) &
(n[None, :, None] < N) &
(k[None, None, :] < K))
tile = tl.load(input_ptr + offs, mask=mask3, other=0)
out = tl.reduce(tile, axis=2, combine_fn=u64_or)
out_offs = m[:, None] * N + n[None, :]
out_mask = (m[:, None] < M) & (n[None, :] < N)
tl.store(output_ptr + out_offs, out, mask=out_mask)
@triton.jit
def allocate_only_sram_kernel(
input_ptr, output_ptr,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# Allocate a local 3D tile to mirror the reduction tile shape
tmp = tl.zeros((BLOCK_M, BLOCK_N, BLOCK_K), dtype=tl.int64)
# Consume to avoid DCE; result remains zero
res = tl.sum(tmp, axis=2)
out_offs = m[:, None] * N + n[None, :]
out_mask = (m[:, None] < M) & (n[None, :] < N)
tl.store(output_ptr + out_offs, res, mask=out_mask)
@triton.jit
def preload_no_reduce_kernel(
input_ptr, output_ptr,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
k = tl.arange(0, BLOCK_K)
offs = (m[:, None, None] * N * K +
n[None, :, None] * K +
k[None, None, :])
mask3 = ((m[:, None, None] < M) &
(n[None, :, None] < N) &
(k[None, None, :] < K))
tile = tl.load(input_ptr + offs, mask=mask3, other=0)
# Produce output from k=0 via direct load (avoid local slicing issues)
m2 = m[:, None]
n2 = n[None, :]
mask2 = (m2 < M) & (n2 < N)
offs2 = m2 * N * K + n2 * K + 0
res = tl.load(input_ptr + offs2, mask=mask2, other=0)
out_offs = m2 * N + n2
tl.store(output_ptr + out_offs, res, mask=mask2)
def build_manual_unroll_kernel(K: int) -> Callable:
"""Dynamically build a manual-unrolled kernel (global loads) for a given K."""
if K < 1:
raise ValueError("K must be >= 1")
lines = ["res = tl.load(input_ptr + input_base + 0, mask=mask, other=0)"]
for i in range(1, K):
lines.append(
f"res = res | tl.load(input_ptr + input_base + {i}, mask=mask, other=0)"
)
body = " " + "\n ".join(lines)
src = f"""
@triton.jit
def manual_unroll_k{K}_kernel(
input_ptr, output_ptr,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
m2 = m[:, None]
n2 = n[None, :]
mask = (m2 < M) & (n2 < N)
input_base = m2 * N * K + n2 * K
# Manual unroll (global loads)
{body}
out_offs = m2 * N + n2
tl.store(output_ptr + out_offs, res, mask=mask)
"""
# Write to a temp file so Triton can retrieve source for JIT
import importlib.util, sys, os
mod_name = f"manual_unroll_k{K}_kernel_mod"
file_path = os.path.abspath(f"_manual_unroll_k{K}_kernel.py")
with open(file_path, "w") as f:
f.write("import triton\nimport triton.language as tl\n\n")
f.write(src)
spec = importlib.util.spec_from_file_location(mod_name, file_path)
mod = importlib.util.module_from_spec(spec)
sys.modules[mod_name] = mod
assert spec.loader is not None
spec.loader.exec_module(mod)
return getattr(mod, f"manual_unroll_k{K}_kernel")
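# For illustration, build_manual_unroll_kernel(3) emits a kernel whose
# reduction body is (up to whitespace):
#   res = tl.load(input_ptr + input_base + 0, mask=mask, other=0)
#   res = res | tl.load(input_ptr + input_base + 1, mask=mask, other=0)
#   res = res | tl.load(input_ptr + input_base + 2, mask=mask, other=0)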
# -------------------------
# Wrappers
# -------------------------
def _grid(M: int, N: int, BM: int, BN: int) -> Tuple[int, int]:
return (triton.cdiv(M, BM), triton.cdiv(N, BN))
def run_standard_reduce(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
M, N, K = x.shape
out = torch.empty((M, N), dtype=x.dtype, device=x.device)
standard_reduce_kernel[_grid(M, N, BM, BN)](
x, out, M, N, K,
BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
num_warps=num_warps, num_stages=num_stages,
)
return out
_MANUAL_KERNEL_CACHE: Dict[int, Callable] = {}
def run_manual_unroll(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
M, N, K = x.shape
    k_int = int(K)  # K comes from x.shape, so it is already a Python int
kern = _MANUAL_KERNEL_CACHE.get(k_int)
if kern is None:
kern = build_manual_unroll_kernel(k_int)
_MANUAL_KERNEL_CACHE[k_int] = kern
out = torch.empty((M, N), dtype=x.dtype, device=x.device)
kern[_grid(M, N, BM, BN)](
x, out, M, N, K,
BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
num_warps=num_warps, num_stages=num_stages,
)
return out
def run_allocate_only(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
M, N, K = x.shape
out = torch.empty((M, N), dtype=x.dtype, device=x.device)
allocate_only_sram_kernel[_grid(M, N, BM, BN)](
x, out, M, N, K,
BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
num_warps=num_warps, num_stages=num_stages,
)
return out
def run_preload_only(x: torch.Tensor, *, BM=64, BN=64, num_warps=4, num_stages=2) -> torch.Tensor:
M, N, K = x.shape
out = torch.empty((M, N), dtype=x.dtype, device=x.device)
preload_no_reduce_kernel[_grid(M, N, BM, BN)](
x, out, M, N, K,
BLOCK_M=BM, BLOCK_N=BN, BLOCK_K=K,
num_warps=num_warps, num_stages=num_stages,
)
return out
# -------------------------
# Timing helpers
# -------------------------
def time_kernel(fn: Callable[[], None], runs: int = 300) -> float:
torch.cuda.synchronize()
# Warmup
for _ in range(50):
fn()
torch.cuda.synchronize()
# Measure
start = time.perf_counter()
for _ in range(runs):
fn()
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) * 1000.0 / runs
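# Illustrative alternative (an assumption, not used by the benchmark below):
# timing with CUDA events rather than host wall-clock time. Event timing
# measures GPU-side elapsed time between the two record() calls and returns
# milliseconds, so results may differ slightly from time_kernel above.
def time_kernel_events(fn: Callable[[], None], runs: int = 300) -> float:
    # Warmup
    for _ in range(50):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(runs):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / runs  # elapsed_time is in milliseconds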
def format_table(headers: List[str], rows: List[List[Any]]) -> str:
cols = list(zip(*([headers] + rows)))
widths = [max(len(str(x)) for x in col) for col in cols]
def fmt_row(r):
return " | ".join(str(v).rjust(w) for v, w in zip(r, widths))
lines = [fmt_row(headers), "-+-".join("-"*w for w in widths)]
lines += [fmt_row(row) for row in rows]
return "\n".join(lines)
# -------------------------
# Runners
# -------------------------
def run_repro(M=256, N=256, K=16, dtype=torch.int64) -> None:
print("=" * 80)
print("3D->2D Bitwise-OR Reduction (Repro)")
print("=" * 80)
print(env_info())
print(f"Config: M={M}, N={N}, K={K}")
torch.manual_seed(42)
x = torch.randint(0, 2**16, (M, N, K), dtype=dtype, device='cuda')
# Prepare wrappers
std = lambda: run_standard_reduce(x)
man = lambda: run_manual_unroll(x)
alloc = lambda: run_allocate_only(x)
pre = lambda: run_preload_only(x)
# Run once to materialize kernels
out_std = std()
out_man = man()
torch.cuda.synchronize()
# Correctness
if not torch.equal(out_std.cpu(), out_man.cpu()):
raise RuntimeError("Mismatch between tl.reduce and manual unroll outputs")
# Timings
t_std = time_kernel(lambda: std())
t_man = time_kernel(lambda: man())
t_alloc = time_kernel(lambda: alloc())
t_pre = time_kernel(lambda: pre())
rows = [
["tl.reduce", f"{t_std:.5f}", "1.00x"],
["manual (unroll)", f"{t_man:.5f}", f"{t_std/t_man:.2f}x"],
["allocate-only", f"{t_alloc:.5f}", f"{t_std/t_alloc:.2f}x"],
["preload-no-reduce", f"{t_pre:.5f}", f"{t_std/t_pre:.2f}x"],
]
print("\nSpeed Comparison (ms)")
print(format_table(["Variant", "Time", "Speedup vs tl.reduce"], rows))
def run_bench(tests: Optional[List[Tuple[int, int, int]]] = None) -> None:
if tests is None:
tests = [
(64, 128, 4), (128, 256, 8), (256, 512, 16),
(512, 1024, 4), (1024, 2048, 8), (2048, 4096, 16),
]
print("\n" + "=" * 80)
print("Benchmark: tl.reduce vs manual unroll")
print("=" * 80)
rows = []
for M, N, K in tests:
torch.manual_seed(0)
x = torch.randint(0, 2**16, (M, N, K), dtype=torch.int64, device='cuda')
std = lambda: run_standard_reduce(x)
man = lambda: run_manual_unroll(x)
# Prime
std(); man(); torch.cuda.synchronize()
elems = M * N * K
runs = 100 if elems <= 50_000_000 else 10
t_std = time_kernel(lambda: std(), runs=runs)
t_man = time_kernel(lambda: man(), runs=runs)
rows.append([M, N, K, f"{t_std:.3f}", f"{t_man:.3f}", f"{t_std/t_man:.2f}x"])
print(format_table(["M", "N", "K", "tl.reduce (ms)", "manual (ms)", "speedup"], rows))
def main():
p = argparse.ArgumentParser(description="3D->2D bitwise-OR reduction (Triton)")
p.add_argument("mode", choices=["repro", "bench", "test"], nargs="?", default="repro")
p.add_argument("--M", type=int, default=256)
p.add_argument("--N", type=int, default=256)
p.add_argument("--K", type=int, default=16)
p.add_argument("--shapes", type=str, default="",
help="Semicolon-separated list of M,N,K triples, e.g. '2048,2048,16;4096,4096,16'")
args = p.parse_args()
if not torch.cuda.is_available():
raise SystemExit("CUDA is not available")
if args.mode == "repro":
run_repro(M=args.M, N=args.N, K=args.K)
elif args.mode == "bench":
tests = None
if args.shapes:
tests = []
for part in args.shapes.split(";"):
part = part.strip()
if not part:
continue
m, n, k = map(int, part.split(","))
tests.append((m, n, k))
run_bench(tests)
else:
# Comprehensive correctness tests
print("=" * 80)
print("Correctness Tests (tl.reduce vs manual unroll)")
print("=" * 80)
ok = True
def cpu_ref(t: torch.Tensor) -> torch.Tensor:
import numpy as np
return torch.from_numpy(np.bitwise_or.reduce(t.cpu().numpy(), axis=-1)).to(t.device)
# Basic shapes
basic_shapes = [
(32, 64, 4),
(64, 128, 8),
(128, 256, 16),
(1, 1, 4),
(7, 13, 8),
]
for M, N, K in basic_shapes:
torch.manual_seed(0)
x = torch.randint(0, 2**16, (M, N, K), dtype=torch.int64, device='cuda')
a = run_standard_reduce(x)
b = run_manual_unroll(x)
c = cpu_ref(x)
same_ab = torch.equal(a.cpu(), b.cpu())
same_ac = torch.equal(a.cpu(), c.cpu())
print(f" Basic M={M} N={N} K={K}: std==manual={same_ab}, std==cpu={same_ac}")
ok &= same_ab and same_ac
# Data types
dtypes = [torch.int32, torch.int64, torch.uint8]
for dt in dtypes:
K = 8
torch.manual_seed(1)
high = 256 if dt is torch.uint8 else 2**16
x = torch.randint(0, high, (32, 32, K), dtype=dt, device='cuda')
a = run_standard_reduce(x)
b = run_manual_unroll(x)
c = cpu_ref(x)
same_ab = torch.equal(a.cpu(), b.cpu())
same_ac = torch.equal(a.cpu(), c.cpu())
print(f" DType {str(dt).split('.')[-1]}: std==manual={same_ab}, std==cpu={same_ac}")
ok &= same_ab and same_ac
# Bitwise properties
K = 4
torch.manual_seed(2)
base = torch.randint(0, 2**8, (16, 16, 1), dtype=torch.int64, device='cuda')
x = base.repeat(1, 1, K)
a = run_standard_reduce(x)
b = run_manual_unroll(x)
expected = base.squeeze(-1)
idem_ok = torch.equal(a.cpu(), expected.cpu()) and torch.equal(b.cpu(), expected.cpu())
print(f" Idempotence x|x=x: {idem_ok}")
ok &= idem_ok
torch.manual_seed(3)
y = torch.randint(0, 2**8, (8, 8, K), dtype=torch.int64, device='cuda')
# Shuffle K per element
ys = y.clone()
for i in range(y.shape[0]):
for j in range(y.shape[1]):
perm = torch.randperm(K)
ys[i, j, :] = y[i, j, perm]
ao = run_manual_unroll(y)
bo = run_manual_unroll(ys)
comm_ok = torch.equal(ao.cpu(), bo.cpu())
print(f" Commutativity (shuffle K): {comm_ok}")
ok &= comm_ok
print("\nRESULT:", "PASS" if ok else "FAIL")
if not ok:
raise SystemExit(1)
if __name__ == "__main__":
main()
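# Example invocations (assuming this file is saved as reduce_or_bench.py; the
# filename is illustrative):
#   python reduce_or_bench.py repro --M 256 --N 256 --K 16
#   python reduce_or_bench.py bench --shapes "2048,2048,16;4096,4096,16"
#   python reduce_or_bench.py test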