Created
December 3, 2025 21:18
-
-
Save Codys12/588e4128076c1366f27f7a6af00954e7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import random
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
@dataclass(frozen=True)
class MatmulSchedule:
    """Immutable representation of a blocked matmul schedule.

    Attributes:
        tiles: Maps each axis name ("i", "j", "k") to its tile sizes, in the
            order they are written in the ``tile(...)`` statement.
        order: Loop variable names (e.g. "i0", "j0", "k0", "i1") in the order
            they are written in the ``order(...)`` statement.
    """

    tiles: Dict[str, Tuple[int, ...]]
    order: Tuple[str, ...]

    def __hash__(self) -> int:
        """Hash by content so schedules can be used in sets and as dict keys.

        The hash that ``@dataclass(frozen=True)`` would generate hashes the
        raw fields and fails with ``TypeError`` because ``tiles`` is a dict.
        Axis entries are sorted so that two schedules comparing equal (dict
        equality ignores insertion order) also hash equal.
        """
        tiles_key = tuple(sorted((axis, tuple(sizes)) for axis, sizes in self.tiles.items()))
        return hash((tiles_key, tuple(self.order)))

    def to_string(self) -> str:
        """Render the schedule in canonical DSL syntax.

        Returns:
            ``"tile(i,...); tile(j,...); tile(k,...); order(...)"`` with the
            axes always emitted in i, j, k order.

        Raises:
            KeyError: If ``tiles`` has no entry for one of the three axes.
        """
        clauses = []
        for axis in ("i", "j", "k"):
            size_list = ",".join(str(size) for size in self.tiles[axis])
            clauses.append(f"tile({axis},{size_list})")
        clauses.append("order(" + ",".join(self.order) + ")")
        return "; ".join(clauses)
class MatmulDSL:
    """
    A tiny DSL for exploring tiling/subtiling and loop reordering for:

        C[M,N] += A[M,K] @ B[K,N]

    Schedule syntax (whitespace ignored):

        tile(i,<outer>,<inner>,...); tile(j,...); tile(k,...); order(i0,j0,k0,i1,j1,k1,...)

    Semantics:
      - Each axis has hierarchical tile sizes; constraints:
            extent % tile[0] == 0
            tile[l-1] % tile[l] == 0 for l>0
      - Each tile level adds a loop var: i0,i1,... j0,j1,... k0,k1,...
      - order(...) must contain exactly all loop vars once, and preserve per-axis nesting:
            i0 before i1 before i2..., similarly for j and k.
      - The innermost tile sizes define the micro-op:
            C[i:i+ti, j:j+tj] += A[i:i+ti, k:k+tk] @ B[k:k+tk, j:j+tj]
      - An axis without a tile(...) statement gets one tile spanning its whole
        extent; a missing order(...) defaults to level-major order
        (i0,j0,k0,i1,j1,k1,...).
    """

    AXES = ("i", "j", "k")

    # Statement grammars; they are applied after all whitespace is removed.
    _TILE_RE = re.compile(r"^tile\(([ijk]),([0-9,]+)\)$")
    _ORDER_RE = re.compile(r"^order\(([a-z0-9,]+)\)$")

    def __init__(self, m: int = 1024, n: int = 1024, k: int = 1024, schedule: Optional[str] = None):
        """Create a DSL instance for a fixed problem size.

        Args:
            m: Extent of the i axis (rows of A and C).
            n: Extent of the j axis (columns of B and C).
            k: Extent of the k axis (columns of A, rows of B).
            schedule: Initial schedule string; when None, the result of
                ``default_schedule()`` is used.

        Raises:
            ValueError: If a dimension is negative or the schedule is invalid.
        """
        self.m = int(m)
        self.n = int(n)
        self.k = int(k)
        # Fail here rather than later inside np.zeros with a less specific error.
        if min(self.m, self.n, self.k) < 0:
            raise ValueError(
                f"Matrix dimensions must be non-negative, got m={self.m}, n={self.n}, k={self.k}"
            )
        self.extents = {"i": self.m, "j": self.n, "k": self.k}
        # Compiled kernels keyed by (canonical schedule, debug_checks, use_matmul_out).
        self._kernel_cache: Dict[Tuple[str, bool, bool], Any] = {}
        initial = schedule if schedule is not None else self.default_schedule()
        self.schedule = self.parse(initial)

    @staticmethod
    def _largest_divisor_leq(x: int, target: int) -> int:
        """Prefer powers-of-two divisors; fallback to any divisor <= target.

        Returns the largest power of two greater than 1 that divides ``x`` and
        is at most ``min(target, x)``. If there is none, returns the largest
        divisor of ``x`` in that range, and 1 if the range is empty.
        """
        target = min(target, x)
        p = 1
        while p * 2 <= target:
            p *= 2
        # Stop before p == 1: 1 divides everything, so accepting it here would
        # make the general-divisor fallback below unreachable and give odd
        # extents (e.g. 15) a tile of 1 instead of 15.
        while p > 1:
            if x % p == 0:
                return p
            p //= 2
        for d in range(target, 0, -1):
            if x % d == 0:
                return d
        return 1

    def default_schedule(self) -> str:
        """Return a schedule string with up to two tile levels per axis.

        Each axis gets an outer tile of at most 256 and an inner tile of at
        most 64; the inner level is dropped when it would equal the outer one.
        Loops are ordered level-major.
        """

        def chain(ext: int, outer_target: int, inner_target: int) -> Tuple[int, ...]:
            outer = self._largest_divisor_leq(ext, outer_target)
            inner = self._largest_divisor_leq(outer, min(inner_target, outer))
            return (outer,) if inner == outer else (outer, inner)

        tiles = {axis: chain(self.extents[axis], 256, 64) for axis in self.AXES}
        order = self._default_order_from_tiles({axis: list(sizes) for axis, sizes in tiles.items()})
        return MatmulSchedule(tiles=tiles, order=tuple(order)).to_string()

    @staticmethod
    def _strip_ws(s: str) -> str:
        """Remove every whitespace character from ``s``."""
        return re.sub(r"\s+", "", s)

    def _default_order_from_tiles(self, tiles: Dict[str, List[int]]) -> List[str]:
        """Return the level-major loop order: all level-0 loops, then level-1, ..."""
        max_levels = max(len(tiles[a]) for a in self.AXES)
        return [
            f"{a}{lvl}"
            for lvl in range(max_levels)
            for a in self.AXES
            if lvl < len(tiles[a])
        ]

    def _validate(self, sched: "MatmulSchedule") -> None:
        """Check ``sched`` against this instance's extents.

        Raises:
            ValueError: On a missing/non-positive tile, a tile that does not
                divide its parent (or the extent), an order that is not
                exactly the expected loop set, or broken per-axis nesting.
        """
        for axis in self.AXES:
            sizes = list(sched.tiles[axis])
            if not sizes:
                raise ValueError(f"Axis {axis}: missing tiles")
            if any((not isinstance(x, int)) or x <= 0 for x in sizes):
                raise ValueError(f"Axis {axis}: non-positive tile in {sizes}")
            ext = self.extents[axis]
            if ext % sizes[0] != 0:
                raise ValueError(f"Axis {axis}: outer tile {sizes[0]} does not divide extent {ext}")
            for parent_size, child_size in zip(sizes, sizes[1:]):
                if parent_size % child_size != 0:
                    raise ValueError(f"Axis {axis}: tile {child_size} must divide parent tile {parent_size}")

        expected_loops = [
            f"{axis}{lvl}" for axis in self.AXES for lvl in range(len(sched.tiles[axis]))
        ]
        # Comparing sorted lists (not sets) also rejects duplicated loop names.
        if sorted(sched.order) != sorted(expected_loops):
            raise ValueError(f"Order must contain exactly loops {expected_loops}, got {list(sched.order)}")

        pos = {name: idx for idx, name in enumerate(sched.order)}
        for axis in self.AXES:
            for lvl in range(1, len(sched.tiles[axis])):
                if pos[f"{axis}{lvl - 1}"] > pos[f"{axis}{lvl}"]:
                    raise ValueError(f"Invalid order: {axis}{lvl - 1} must appear before {axis}{lvl}")

    def parse(self, schedule: str) -> "MatmulSchedule":
        """Parse and validate a schedule string.

        Statements are separated by ';'. Axes without a ``tile`` statement get
        a single tile equal to the extent; a missing ``order`` statement gets
        the level-major default. If a statement kind is repeated, the last
        occurrence wins.

        Raises:
            ValueError: If the input is not a non-empty string, a statement is
                not recognized, or the resulting schedule fails validation.
        """
        if not isinstance(schedule, str) or not schedule.strip():
            raise ValueError("schedule must be a non-empty string")
        stmts = [part for part in self._strip_ws(schedule).split(";") if part]

        tiles: Dict[str, List[int]] = {}
        order: Optional[List[str]] = None
        for stmt in stmts:
            m_tile = self._TILE_RE.match(stmt)
            if m_tile:
                axis = m_tile.group(1)
                sizes = [int(tok) for tok in m_tile.group(2).split(",") if tok]
                if not sizes:
                    raise ValueError(f"Empty tile list for axis {axis}")
                tiles[axis] = sizes
                continue
            m_order = self._ORDER_RE.match(stmt)
            if m_order:
                order = [tok for tok in m_order.group(1).split(",") if tok]
                continue
            raise ValueError(f"Unrecognized statement: {stmt!r}")

        for axis in self.AXES:
            tiles.setdefault(axis, [self.extents[axis]])
        if order is None:
            order = self._default_order_from_tiles(tiles)

        sched = MatmulSchedule(tiles={a: tuple(tiles[a]) for a in self.AXES}, order=tuple(order))
        self._validate(sched)
        return sched

    def canonicalize(self, schedule: str) -> str:
        """Parse ``schedule`` and return its canonical string form."""
        return self.parse(schedule).to_string()

    def compile(
        self,
        schedule: Optional[str] = None,
        *,
        return_source: bool = False,
        debug_checks: bool = True,
        use_matmul_out: bool = True,
    ) -> Any:
        """
        Compiles the schedule to a Python kernel: kernel(A, B, C).

        Uses codegen + exec (safe because parsing/validation restricts tokens).

        Args:
            schedule: Schedule string; None uses ``self.schedule``.
            return_source: When True, return ``(kernel, source)`` and bypass
                the kernel cache.
            debug_checks: When True, the kernel verifies the shapes of A, B
                and C and raises AssertionError on a mismatch.
            use_matmul_out: When True, each micro-op writes the product into a
                reusable scratch tile via ``np.matmul(..., out=tmp)`` before
                accumulating; otherwise it uses ``+=`` with the ``@`` operator.

        Returns:
            The kernel callable, or ``(kernel, source)`` if ``return_source``.
        """
        sched = self.parse(schedule) if schedule is not None else self.schedule
        key = (sched.to_string(), debug_checks, use_matmul_out)
        if not return_source and key in self._kernel_cache:
            return self._kernel_cache[key]

        tiles = {a: list(sched.tiles[a]) for a in self.AXES}
        ext = self.extents
        innermost_var = {a: f"{a}{len(tiles[a]) - 1}" for a in self.AXES}
        ti, tj, tk = tiles["i"][-1], tiles["j"][-1], tiles["k"][-1]

        def loop_header(loop_name: str) -> str:
            axis = loop_name[0]
            lvl = int(loop_name[1:])
            if lvl == 0:
                return f"for {loop_name} in range(0, {ext[axis]}, {tiles[axis][0]}):"
            # Deeper levels walk the span of the enclosing tile of the same axis.
            parent = f"{axis}{lvl - 1}"
            return f"for {loop_name} in range({parent}, {parent}+{tiles[axis][lvl - 1]}, {tiles[axis][lvl]}):"

        indent = "    "
        lines = ["def kernel(A, B, C):"]
        if debug_checks:
            expected_shapes = (
                ("A", (ext["i"], ext["k"])),
                ("B", (ext["k"], ext["j"])),
                ("C", (ext["i"], ext["j"])),
            )
            # An explicit raise is generated instead of `assert` so the checks
            # stay active when Python runs with -O.
            for name, shape in expected_shapes:
                message = f"{name} must have shape {shape}, got "
                lines.append(f"{indent}if {name}.shape != {shape!r}:")
                lines.append(f"{indent * 2}raise AssertionError({message!r} + str({name}.shape))")
        if use_matmul_out:
            lines.append(indent + f"tmp = np.empty(({ti}, {tj}), dtype=C.dtype)")

        cur = indent
        for loop_name in sched.order:
            lines.append(cur + loop_header(loop_name))
            cur += indent

        iv, jv, kv = innermost_var["i"], innermost_var["j"], innermost_var["k"]
        lines.append(cur + f"i = {iv}; j = {jv}; k = {kv}")
        if use_matmul_out:
            lines.append(cur + f"np.matmul(A[i:i+{ti}, k:k+{tk}], B[k:k+{tk}, j:j+{tj}], out=tmp)")
            lines.append(cur + f"C[i:i+{ti}, j:j+{tj}] += tmp")
        else:
            lines.append(cur + f"C[i:i+{ti}, j:j+{tj}] += A[i:i+{ti}, k:k+{tk}] @ B[k:k+{tk}, j:j+{tj}]")

        src = "\n".join(lines)
        namespace: Dict[str, Any] = {}
        exec(src, {"np": np}, namespace)
        kernel = namespace["kernel"]
        if return_source:
            return kernel, src
        self._kernel_cache[key] = kernel
        return kernel

    @staticmethod
    def _randn(rng: np.random.Generator, shape: Tuple[int, ...], dtype) -> np.ndarray:
        """Draw standard-normal samples of ``shape`` in ``dtype``.

        If ``standard_normal`` rejects ``dtype`` with a TypeError, samples are
        drawn with the default dtype and then cast.
        """
        try:
            return rng.standard_normal(shape, dtype=dtype)
        except TypeError:
            return rng.standard_normal(shape).astype(dtype, copy=False)

    def run(
        self,
        schedule: Optional[str] = None,
        *,
        A: Optional[np.ndarray] = None,
        B: Optional[np.ndarray] = None,
        dtype=np.float32,
        seed: int = 0,
        debug_checks: bool = False,
        use_matmul_out: bool = True,
    ) -> np.ndarray:
        """Compute C for the DSL's (m,n,k) using the given schedule.

        Missing inputs are generated from ``seed``; C starts as zeros of
        ``dtype`` and is returned after the kernel has accumulated into it.
        """
        sched = self.parse(schedule) if schedule is not None else self.schedule
        kernel = self.compile(sched.to_string(), debug_checks=debug_checks, use_matmul_out=use_matmul_out)
        rng = np.random.default_rng(seed)
        if A is None:
            A = self._randn(rng, (self.m, self.k), dtype=dtype)
        if B is None:
            B = self._randn(rng, (self.k, self.n), dtype=dtype)
        C = np.zeros((self.m, self.n), dtype=dtype)
        kernel(A, B, C)
        return C

    def evaluate(
        self,
        schedule: Optional[str] = None,
        *,
        dtype=np.float32,
        seed: int = 0,
        n_runs: int = 3,
        warmup: int = 1,
        debug_checks: bool = False,
        use_matmul_out: bool = True,
        check_correctness: bool = False,
        correctness_atol: float = 1e-3,
        correctness_rtol: float = 1e-3,
    ) -> Dict[str, float]:
        """Time the scheduled matmul on random inputs and return seconds + GFLOP/s.

        Runs ``warmup`` untimed calls, then at least one timed call
        (``max(1, n_runs)``), zeroing C before each.

        Returns:
            Dict with keys ``seconds_best``, ``seconds_mean``, ``gflops_best``
            and ``gflops_mean``.

        Raises:
            AssertionError: If ``check_correctness`` is set and the last
                result differs from ``A @ B`` beyond the given tolerances.
        """
        sched = self.parse(schedule) if schedule is not None else self.schedule
        kernel = self.compile(sched.to_string(), debug_checks=debug_checks, use_matmul_out=use_matmul_out)
        rng = np.random.default_rng(seed)
        A = self._randn(rng, (self.m, self.k), dtype=dtype)
        B = self._randn(rng, (self.k, self.n), dtype=dtype)
        C = np.zeros((self.m, self.n), dtype=dtype)

        for _ in range(max(0, warmup)):
            C.fill(0)
            kernel(A, B, C)

        times = []
        for _ in range(max(1, n_runs)):
            C.fill(0)
            t0 = time.perf_counter()
            kernel(A, B, C)
            times.append(time.perf_counter() - t0)

        if check_correctness:
            C_ref = A @ B
            if not np.allclose(C, C_ref, atol=correctness_atol, rtol=correctness_rtol):
                err = float(np.max(np.abs(C - C_ref)))
                raise AssertionError(f"Correctness check failed. max|err|={err}")

        ops = 2.0 * self.m * self.n * self.k
        best = min(times)
        mean = sum(times) / len(times)

        def gflops(seconds: float) -> float:
            # A zero reading (timer resolution on a tiny problem) must not
            # abort the benchmark with ZeroDivisionError.
            return float(ops / seconds / 1e9) if seconds > 0 else float("inf")

        return {
            "seconds_best": float(best),
            "seconds_mean": float(mean),
            "gflops_best": gflops(best),
            "gflops_mean": gflops(mean),
        }

    def mutate(
        self,
        schedule: Optional[str] = None,
        *,
        seed: Optional[int] = None,
        n_mutations: int = 1,
        max_tries: int = 200,
        min_tile: int = 4,
        max_levels: int = 5,
    ) -> str:
        """
        Randomly mutate a schedule string, returning a NEW valid schedule string.

        Mutations include:
          - swapping two adjacent loops in order (if still valid)
          - splitting the innermost tile (adds a new level)
          - fusing the innermost tile (removes a level)
          - retile the innermost tile to another divisor of its parent tile

        Args:
            schedule: Starting schedule; None uses ``self.schedule``.
            seed: Seed for the private ``random.Random`` instance.
            n_mutations: Number of successive mutations (at least one is applied).
            max_tries: Attempts allowed per mutation before giving up.
            min_tile: Smallest tile size a split or retile may introduce.
            max_levels: An axis is split only while it has fewer levels than this.

        Raises:
            RuntimeError: If ``max_tries`` attempts in a row are all rejected.
        """
        rng = random.Random(seed)
        base = self.parse(schedule) if schedule is not None else self.schedule

        def divisors(x: int) -> List[int]:
            found = []
            for d in range(1, int(math.isqrt(x)) + 1):
                if x % d == 0:
                    found.append(d)
                    if d * d != x:
                        found.append(x // d)
            return sorted(found)

        def remove_loop(order: List[str], loop: str) -> None:
            if loop in order:
                order.remove(loop)

        def insert_loop_after(order: List[str], after: str, loop: str) -> None:
            order.insert(order.index(after) + 1, loop)

        sched = base
        for _ in range(max(1, n_mutations)):
            for _attempt in range(max_tries):
                tiles = {a: list(sched.tiles[a]) for a in self.AXES}
                order = list(sched.order)

                can_split_axes = [
                    a for a in self.AXES
                    if len(tiles[a]) < max_levels and tiles[a][-1] >= 2 * min_tile
                ]
                can_fuse_axes = [a for a in self.AXES if len(tiles[a]) > 1]

                ops, weights = [], []
                if len(order) >= 2:
                    ops.append("swap_adjacent")
                    weights.append(0.55)
                if can_split_axes:
                    ops.append("split_innermost")
                    weights.append(0.20)
                if can_fuse_axes:
                    ops.append("fuse_innermost")
                    weights.append(0.15)
                ops.append("retile_innermost")
                weights.append(0.10)
                op = rng.choices(ops, weights=weights, k=1)[0]

                try:
                    if op == "swap_adjacent":
                        idx = rng.randrange(0, len(order) - 1)
                        order[idx], order[idx + 1] = order[idx + 1], order[idx]
                    elif op == "split_innermost":
                        axis = rng.choice(can_split_axes)
                        last = tiles[axis][-1]
                        candidates = [d for d in divisors(last) if min_tile <= d < last]
                        if not candidates:
                            raise ValueError("cannot split")
                        tiles[axis].append(rng.choice(candidates))
                        new_loop = f"{axis}{len(tiles[axis]) - 1}"
                        after_loop = f"{axis}{len(tiles[axis]) - 2}"
                        # The new level starts directly inside its parent loop.
                        insert_loop_after(order, after_loop, new_loop)
                    elif op == "fuse_innermost":
                        axis = rng.choice(can_fuse_axes)
                        removed_level = len(tiles[axis]) - 1
                        tiles[axis].pop()
                        remove_loop(order, f"{axis}{removed_level}")
                    elif op == "retile_innermost":
                        axis = rng.choice(list(self.AXES))
                        parent = tiles[axis][-2] if len(tiles[axis]) > 1 else self.extents[axis]
                        candidates = [d for d in divisors(parent) if d >= min_tile]
                        tiles[axis][-1] = rng.choice(candidates) if candidates else parent

                    mutated = MatmulSchedule(
                        tiles={a: tuple(tiles[a]) for a in self.AXES}, order=tuple(order)
                    )
                    self._validate(mutated)
                except ValueError:
                    # ValueError is the only rejection signal (invalid order,
                    # non-dividing tile, nothing to split); any other
                    # exception is a real bug and is allowed to propagate.
                    continue
                if mutated == sched:
                    # A retile may pick the size already in place; that is
                    # not a new schedule, so try again.
                    continue
                sched = mutated
                break
            else:
                raise RuntimeError("Failed to produce a valid mutation (too many rejected tries).")
        return sched.to_string()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment