# Source: GitHub Gist Codys12/588e4128076c1366f27f7a6af00954e7
# Author: @Codys12, created December 3, 2025 21:18
import math
import random
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
@dataclass(frozen=True)
class MatmulSchedule:
    """Frozen description of a blocked matmul schedule.

    Instances are produced by MatmulDSL.parse and rendered back into DSL
    syntax with ``to_string``.
    """

    # Tile sizes per axis ("i", "j", "k"), listed from the outermost level inward.
    tiles: Dict[str, Tuple[int, ...]]
    # Loop names such as "i0" or "k1", listed from the outermost loop inward.
    order: Tuple[str, ...]

    def to_string(self) -> str:
        """Render as DSL text: one tile(...) statement per axis i, j, k, then order(...)."""
        statements = [
            "tile({},{})".format(axis, ",".join(map(str, self.tiles[axis])))
            for axis in ("i", "j", "k")
        ]
        statements.append("order({})".format(",".join(self.order)))
        return "; ".join(statements)
class MatmulDSL:
    """
    A tiny DSL for exploring tiling/subtiling and loop reordering for:
        C[M,N] += A[M,K] @ B[K,N]

    Schedule syntax (whitespace ignored):
        tile(i,<outer>,<inner>,...); tile(j,...); tile(k,...); order(i0,j0,k0,i1,j1,k1,...)

    Semantics:
      - Each axis has hierarchical tile sizes; constraints:
            extent % tile[0] == 0
            tile[l-1] % tile[l] == 0 for l>0
      - Each tile level adds a loop var: i0,i1,... j0,j1,... k0,k1,...
      - order(...) must contain exactly all loop vars once, and preserve per-axis nesting:
            i0 before i1 before i2..., similarly for j and k.
      - The innermost tile sizes define the micro-op:
            C[i:i+ti, j:j+tj] += A[i:i+ti, k:k+tk] @ B[k:k+tk, j:j+tj]
      - An axis with no tile(...) statement gets one tile spanning its whole
        extent; a missing order(...) defaults to level-major order
        (all level-0 loops, then all level-1 loops, ...).
    """

    AXES = ("i", "j", "k")

    # Statement grammars, matched after all whitespace has been stripped.
    # These restricted grammars, together with _validate's exact loop-name
    # check, keep the source generated by compile() free of arbitrary tokens.
    _TILE_RE = re.compile(r"^tile\(([ijk]),([0-9,]+)\)$")
    _ORDER_RE = re.compile(r"^order\(([a-z0-9,]+)\)$")

    def __init__(self, m: int = 1024, n: int = 1024, k: int = 1024, schedule: Optional[str] = None):
        """Create a DSL instance for fixed problem extents.

        Args:
            m, n, k: Extents; A is (m, k), B is (k, n), C is (m, n).
            schedule: Optional schedule string. If omitted, ``default_schedule()`` is used.

        Raises:
            ValueError: If an extent is negative or the schedule is invalid.
        """
        self.m = int(m)
        self.n = int(n)
        self.k = int(k)
        if min(self.m, self.n, self.k) < 0:
            raise ValueError(f"Extents must be non-negative, got m={self.m}, n={self.n}, k={self.k}")
        self.extents = {"i": self.m, "j": self.n, "k": self.k}
        # Compiled kernels keyed by (canonical schedule string, debug_checks, use_matmul_out).
        self._kernel_cache: Dict[Tuple[str, bool, bool], Any] = {}
        self.schedule = self.parse(schedule if schedule is not None else self.default_schedule())

    @staticmethod
    def _largest_divisor_leq(x: int, target: int) -> int:
        """Return a divisor of ``x`` that is <= ``target``.

        The largest power-of-two divisor greater than 1 that fits is preferred.
        Otherwise the largest divisor of any kind that fits is returned, with 1
        as the last resort. A non-positive ``x`` yields 1.
        """
        if x <= 0:
            return 1
        target = min(target, x)
        p = 1
        while p * 2 <= target:
            p *= 2
        # Stop at 2: 1 divides everything, so letting this scan reach it would
        # make the general fallback unreachable and give odd extents 1-wide tiles.
        while p > 1:
            if x % p == 0:
                return p
            p //= 2
        for d in range(target, 0, -1):
            if x % d == 0:
                return d
        return 1

    def default_schedule(self) -> str:
        """Build a default schedule string for this instance's extents.

        Each axis gets an outer tile of at most 256 and an inner tile of at most
        64, both chosen by ``_largest_divisor_leq``. The inner level is dropped
        when it would equal the outer one. Loops are ordered level-major.
        """

        def chain(ext: int, outer_target: int, inner_target: int) -> Tuple[int, ...]:
            outer = self._largest_divisor_leq(ext, outer_target)
            inner = self._largest_divisor_leq(outer, min(inner_target, outer))
            return (outer,) if inner == outer else (outer, inner)

        tiles = {axis: chain(self.extents[axis], 256, 64) for axis in self.AXES}
        order = self._default_order_from_tiles({a: list(t) for a, t in tiles.items()})
        return MatmulSchedule(tiles=tiles, order=tuple(order)).to_string()

    @staticmethod
    def _strip_ws(s: str) -> str:
        """Remove every whitespace character from ``s``."""
        return re.sub(r"\s+", "", s)

    def _default_order_from_tiles(self, tiles: Dict[str, List[int]]) -> List[str]:
        """Return level-major loop order: all level-0 loops (i, j, k), then level 1, and so on."""
        max_levels = max(len(tiles[a]) for a in self.AXES)
        out = []
        for lvl in range(max_levels):
            for a in self.AXES:
                if lvl < len(tiles[a]):
                    out.append(f"{a}{lvl}")
        return out

    def _validate(self, sched: "MatmulSchedule") -> None:
        """Raise ValueError unless ``sched`` obeys the tiling and loop-order rules."""
        for axis in self.AXES:
            # .get: a schedule missing an axis is reported as ValueError, not KeyError.
            sizes = list(sched.tiles.get(axis, ()))
            if not sizes:
                raise ValueError(f"Axis {axis}: missing tiles")
            if any((not isinstance(x, int)) or x <= 0 for x in sizes):
                raise ValueError(f"Axis {axis}: non-positive tile in {sizes}")
            ext = self.extents[axis]
            if ext % sizes[0] != 0:
                raise ValueError(f"Axis {axis}: outer tile {sizes[0]} does not divide extent {ext}")
            for l in range(1, len(sizes)):
                if sizes[l - 1] % sizes[l] != 0:
                    raise ValueError(f"Axis {axis}: tile {sizes[l]} must divide parent tile {sizes[l - 1]}")

        expected_loops = []
        for axis in self.AXES:
            expected_loops.extend([f"{axis}{l}" for l in range(len(sched.tiles[axis]))])
        # Multiset comparison: rejects missing, extra and duplicated loop names.
        if sorted(sched.order) != sorted(expected_loops):
            raise ValueError(f"Order must contain exactly loops {expected_loops}, got {list(sched.order)}")

        pos = {name: i for i, name in enumerate(sched.order)}
        for axis in self.AXES:
            for l in range(1, len(sched.tiles[axis])):
                if pos[f"{axis}{l - 1}"] > pos[f"{axis}{l}"]:
                    raise ValueError(f"Invalid order: {axis}{l - 1} must appear before {axis}{l}")

    def parse(self, schedule: str) -> "MatmulSchedule":
        """Parse and validate a schedule string.

        Statements are separated by ';'. A later statement of the same kind
        replaces an earlier one. Untiled axes get one tile spanning their extent,
        and a missing order(...) gets level-major order.

        Raises:
            ValueError: On a non-string or blank input, an unrecognized
                statement, an empty tile list, or a schedule that fails validation.
        """
        if not isinstance(schedule, str) or not schedule.strip():
            raise ValueError("schedule must be a non-empty string")
        s = self._strip_ws(schedule)
        stmts = [p for p in s.split(";") if p]
        tiles: Dict[str, List[int]] = {}
        order: Optional[List[str]] = None
        for stmt in stmts:
            m_tile = self._TILE_RE.match(stmt)
            if m_tile:
                axis = m_tile.group(1)
                sizes = [int(x) for x in m_tile.group(2).split(",") if x]
                if not sizes:
                    raise ValueError(f"Empty tile list for axis {axis}")
                tiles[axis] = sizes
                continue
            m_order = self._ORDER_RE.match(stmt)
            if m_order:
                order = [x for x in m_order.group(1).split(",") if x]
                continue
            raise ValueError(f"Unrecognized statement: {stmt!r}")
        for axis in self.AXES:
            if axis not in tiles:
                tiles[axis] = [self.extents[axis]]
        if order is None:
            order = self._default_order_from_tiles(tiles)
        sched = MatmulSchedule(tiles={a: tuple(tiles[a]) for a in self.AXES}, order=tuple(order))
        self._validate(sched)
        return sched

    def canonicalize(self, schedule: str) -> str:
        """Return the canonical string form of ``schedule`` (parsed, validated, re-rendered)."""
        return self.parse(schedule).to_string()

    def compile(
        self,
        schedule: Optional[str] = None,
        *,
        return_source: bool = False,
        debug_checks: bool = True,
        use_matmul_out: bool = True,
    ) -> Any:
        """
        Compile the schedule to a Python kernel ``kernel(A, B, C)`` that accumulates A @ B into C.
        Uses codegen + exec (safe because parsing/validation restricts tokens).

        Args:
            schedule: Schedule string. Defaults to the instance's schedule.
            return_source: If true, also return the generated source. This path
                neither reads nor fills the kernel cache.
            debug_checks: Emit shape ``assert``s in the kernel. Note that Python
                strips these under ``-O``.
            use_matmul_out: Compute each micro-op into one reused scratch buffer
                via ``np.matmul(..., out=tmp)`` instead of allocating a temporary.

        Returns:
            The kernel, or ``(kernel, source)`` when ``return_source`` is true.
        """
        sched = self.parse(schedule) if schedule is not None else self.schedule
        key = (sched.to_string(), debug_checks, use_matmul_out)
        if not return_source and key in self._kernel_cache:
            return self._kernel_cache[key]

        tiles = {a: list(sched.tiles[a]) for a in self.AXES}
        ext = self.extents
        innermost_var = {a: f"{a}{len(tiles[a]) - 1}" for a in self.AXES}
        ti, tj, tk = tiles["i"][-1], tiles["j"][-1], tiles["k"][-1]

        def loop_header(loop_name: str) -> str:
            # Level 0 walks the whole extent; level l walks one tile of level l-1.
            axis = loop_name[0]
            lvl = int(loop_name[1:])
            if lvl == 0:
                return f"for {loop_name} in range(0, {ext[axis]}, {tiles[axis][0]}):"
            parent = f"{axis}{lvl - 1}"
            return f"for {loop_name} in range({parent}, {parent}+{tiles[axis][lvl - 1]}, {tiles[axis][lvl]}):"

        lines = ["def kernel(A, B, C):"]
        indent = " "
        if debug_checks:
            lines.append(indent + f"assert A.shape == ({ext['i']}, {ext['k']})")
            lines.append(indent + f"assert B.shape == ({ext['k']}, {ext['j']})")
            lines.append(indent + f"assert C.shape == ({ext['i']}, {ext['j']})")
        if use_matmul_out:
            lines.append(indent + f"tmp = np.empty(({ti}, {tj}), dtype=C.dtype)")
        cur = indent
        for ln in sched.order:
            lines.append(cur + loop_header(ln))
            cur += indent
        # The innermost loop var of each axis is the start of the current micro-tile.
        iv, jv, kv = innermost_var["i"], innermost_var["j"], innermost_var["k"]
        lines.append(cur + f"i = {iv}; j = {jv}; k = {kv}")
        if use_matmul_out:
            lines.append(cur + f"np.matmul(A[i:i+{ti}, k:k+{tk}], B[k:k+{tk}, j:j+{tj}], out=tmp)")
            lines.append(cur + f"C[i:i+{ti}, j:j+{tj}] += tmp")
        else:
            lines.append(cur + f"C[i:i+{ti}, j:j+{tj}] += A[i:i+{ti}, k:k+{tk}] @ B[k:k+{tk}, j:j+{tj}]")
        src = "\n".join(lines)

        loc: Dict[str, Any] = {}
        exec(src, {"np": np}, loc)
        kernel = loc["kernel"]
        if not return_source:
            self._kernel_cache[key] = kernel
            return kernel
        return kernel, src

    @staticmethod
    def _randn(rng: np.random.Generator, shape: Tuple[int, ...], dtype) -> np.ndarray:
        """Draw standard-normal samples of ``dtype``.

        Falls back to casting from the default dtype when ``standard_normal``
        rejects ``dtype`` with a TypeError.
        """
        try:
            return rng.standard_normal(shape, dtype=dtype)
        except TypeError:
            return rng.standard_normal(shape).astype(dtype, copy=False)

    def run(
        self,
        schedule: Optional[str] = None,
        *,
        A: Optional[np.ndarray] = None,
        B: Optional[np.ndarray] = None,
        dtype=np.float32,
        seed: int = 0,
        debug_checks: bool = False,
        use_matmul_out: bool = True,
    ) -> np.ndarray:
        """Compute C for the DSL's (m,n,k) using the given schedule.

        Missing operands are drawn from a seeded normal distribution (A first,
        then B). C is allocated as zeros of ``dtype``.

        Raises:
            ValueError: If, with ``debug_checks`` off, ``A`` is not (m, k) or
                ``B`` is not (k, n). With ``debug_checks`` on, the kernel's own
                assertions check shapes instead.
        """
        sched = self.parse(schedule) if schedule is not None else self.schedule
        kernel = self.compile(sched.to_string(), debug_checks=debug_checks, use_matmul_out=use_matmul_out)
        rng = np.random.default_rng(seed)
        if A is None:
            A = self._randn(rng, (self.m, self.k), dtype=dtype)
        if B is None:
            B = self._randn(rng, (self.k, self.n), dtype=dtype)
        if not debug_checks:
            # Without debug_checks the kernel never checks shapes, and slicing
            # would silently truncate an oversized operand.
            for name, arr, want in (("A", A, (self.m, self.k)), ("B", B, (self.k, self.n))):
                if np.shape(arr) != want:
                    raise ValueError(f"{name} must have shape {want}, got {np.shape(arr)}")
        C = np.zeros((self.m, self.n), dtype=dtype)
        kernel(A, B, C)
        return C

    def evaluate(
        self,
        schedule: Optional[str] = None,
        *,
        dtype=np.float32,
        seed: int = 0,
        n_runs: int = 3,
        warmup: int = 1,
        debug_checks: bool = False,
        use_matmul_out: bool = True,
        check_correctness: bool = False,
        correctness_atol: float = 1e-3,
        correctness_rtol: float = 1e-3,
    ) -> Dict[str, float]:
        """Time the scheduled matmul on random inputs and return seconds + GFLOP/s.

        C is zeroed before every warmup and timed run. ``n_runs`` is clamped to
        at least 1 and ``warmup`` to at least 0. With ``check_correctness``, the
        last timed result is compared to ``A @ B`` and AssertionError is raised
        on mismatch.

        Returns:
            A dict with keys ``seconds_best``, ``seconds_mean``, ``gflops_best``
            and ``gflops_mean``, counting 2*m*n*k flops per run.
        """
        sched = self.parse(schedule) if schedule is not None else self.schedule
        kernel = self.compile(sched.to_string(), debug_checks=debug_checks, use_matmul_out=use_matmul_out)
        rng = np.random.default_rng(seed)
        A = self._randn(rng, (self.m, self.k), dtype=dtype)
        B = self._randn(rng, (self.k, self.n), dtype=dtype)
        C = np.zeros((self.m, self.n), dtype=dtype)
        for _ in range(max(0, warmup)):
            C.fill(0)
            kernel(A, B, C)
        times = []
        for _ in range(max(1, n_runs)):
            C.fill(0)
            t0 = time.perf_counter()
            kernel(A, B, C)
            times.append(time.perf_counter() - t0)
        if check_correctness:
            C_ref = A @ B
            if not np.allclose(C, C_ref, atol=correctness_atol, rtol=correctness_rtol):
                err = float(np.max(np.abs(C - C_ref)))
                raise AssertionError(f"Correctness check failed. max|err|={err}")
        ops = 2.0 * self.m * self.n * self.k
        best = min(times)
        mean = sum(times) / len(times)
        return {
            "seconds_best": float(best),
            "seconds_mean": float(mean),
            "gflops_best": float(ops / best / 1e9),
            "gflops_mean": float(ops / mean / 1e9),
        }

    def mutate(
        self,
        schedule: Optional[str] = None,
        *,
        seed: Optional[int] = None,
        n_mutations: int = 1,
        max_tries: int = 200,
        min_tile: int = 4,
        max_levels: int = 5,
    ) -> str:
        """
        Randomly mutate a schedule string, returning a NEW valid schedule string.

        Each of ``max(1, n_mutations)`` rounds applies one randomly chosen
        mutation that changes the schedule:
          - swapping two adjacent loops in order (if still valid)
          - splitting the innermost tile of an axis (adds a new level)
          - fusing the innermost tile of an axis (removes a level)
          - retiling the innermost tile to a different divisor of its parent
            tile (or of the extent, for a single-level axis)
        A candidate that fails validation, or would leave the schedule
        unchanged, is rejected and redrawn.

        Raises:
            RuntimeError: If a round sees ``max_tries`` rejected candidates in a row.
        """
        rng = random.Random(seed)
        sched = self.parse(schedule) if schedule is not None else self.schedule

        def divisors(x: int) -> List[int]:
            """All positive divisors of x in ascending order (empty for x == 0)."""
            ds = []
            for d in range(1, math.isqrt(x) + 1):
                if x % d == 0:
                    ds.append(d)
                    if d * d != x:
                        ds.append(x // d)
            return sorted(ds)

        for _ in range(max(1, n_mutations)):
            for _attempt in range(max_tries):
                tiles = {a: list(sched.tiles[a]) for a in self.AXES}
                order = list(sched.order)
                can_split_axes = [
                    a for a in self.AXES
                    if len(tiles[a]) < max_levels and tiles[a][-1] >= 2 * min_tile
                ]
                can_fuse_axes = [a for a in self.AXES if len(tiles[a]) > 1]

                ops, weights = [], []
                if len(order) >= 2:
                    ops.append("swap_adjacent")
                    weights.append(0.55)
                if can_split_axes:
                    ops.append("split_innermost")
                    weights.append(0.20)
                if can_fuse_axes:
                    ops.append("fuse_innermost")
                    weights.append(0.15)
                ops.append("retile_innermost")
                weights.append(0.10)
                op = rng.choices(ops, weights=weights, k=1)[0]

                try:
                    if op == "swap_adjacent":
                        idx = rng.randrange(0, len(order) - 1)
                        order[idx], order[idx + 1] = order[idx + 1], order[idx]
                    elif op == "split_innermost":
                        axis = rng.choice(can_split_axes)
                        last = tiles[axis][-1]
                        ds = [d for d in divisors(last) if min_tile <= d < last]
                        if not ds:
                            raise ValueError("cannot split")
                        tiles[axis].append(rng.choice(ds))
                        new_level = len(tiles[axis]) - 1
                        # Nest the new loop directly inside its parent loop.
                        parent_idx = order.index(f"{axis}{new_level - 1}")
                        order.insert(parent_idx + 1, f"{axis}{new_level}")
                    elif op == "fuse_innermost":
                        axis = rng.choice(can_fuse_axes)
                        removed_loop = f"{axis}{len(tiles[axis]) - 1}"
                        tiles[axis].pop()
                        if removed_loop in order:
                            order.remove(removed_loop)
                    elif op == "retile_innermost":
                        axis = rng.choice(list(self.AXES))
                        parent = tiles[axis][-2] if len(tiles[axis]) > 1 else self.extents[axis]
                        candidates = [d for d in divisors(parent) if d >= min_tile] or [parent]
                        # Exclude the current size so this op always changes the schedule.
                        candidates = [d for d in candidates if d != tiles[axis][-1]]
                        if not candidates:
                            raise ValueError("retile would not change the schedule")
                        tiles[axis][-1] = rng.choice(candidates)
                    mutated = MatmulSchedule(tiles={a: tuple(tiles[a]) for a in self.AXES}, order=tuple(order))
                    self._validate(mutated)
                except ValueError:
                    # Rejected candidate (invalid tiling/order, or a no-op); draw again.
                    continue
                sched = mutated
                break
            else:
                raise RuntimeError("Failed to produce a valid mutation (too many rejected tries).")
        return sched.to_string()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment