microgpt
#!/usr/bin/env python3
"""
microgpt2.py

A dependency-free, single-file GPT-style language model in pure Python.

Key properties:
- stdlib only
- explicit forward/backward kernels (no generic scalar autograd)
- flat float32 parameter buffers using array('f')
- fused QKV projection
- causal self-attention with optional ALiBi
- optional QK-Norm
- tied input/output embeddings by default
- ReLU^2 MLP by default
- AdamW optimizer
- train + sample + save/load in one file

The implementation is intentionally minimal in dependencies but not minimal in
capability. It is designed to be a real, runnable system rather than a
teaching-only artifact.
"""

from __future__ import annotations

import argparse
import json
import math
import os
import pickle
import random
import sys
import tempfile
import time
import urllib.request
from array import array
from dataclasses import dataclass, asdict
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

# -----------------------------
# Utilities
# -----------------------------

def f32(value: float) -> float:
    return float(value)

def prod(shape: Sequence[int]) -> int:
    out = 1
    for dim in shape:
        out *= dim
    return out

def atomic_write_bytes(path: str, payload: bytes) -> None:
    directory = os.path.dirname(os.path.abspath(path)) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(prefix=".tmp_", dir=directory)
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(payload)
        os.replace(tmp_path, path)
    except Exception:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

@dataclass
class TensorSpec:
    name: str
    shape: Tuple[int, ...]
    offset: int
    size: int

@dataclass
class Config:
    seed: int = 1337
    tokenizer: str = "byte"  # byte | char
    context: int = 128
    n_layer: int = 4
    d_model: int = 128
    n_head: int = 4
    d_ff: int = 512
    position_mode: str = "alibi"  # alibi | learned
    qk_norm: bool = True
    relu2: bool = True
    tie_embeddings: bool = True
    final_norm: bool = True
    residual_gate: bool = True
    init_std: float = 0.02
    lr: float = 2e-3
    min_lr: float = 2e-4
    weight_decay: float = 0.01
    beta1: float = 0.9
    beta2: float = 0.95
    adam_eps: float = 1e-8
    grad_clip: float = 1.0
    batch_docs: int = 1
    micro_batch_tokens: int = 0
    steps: int = 200
    warmup_steps: int = 20
    save_every: int = 50
    eval_every: int = 25
    train_ratio: float = 0.9
    temperature: float = 0.9
    top_k: int = 0
    sample_tokens: int = 128

    def validate(self) -> None:
        if self.n_head <= 0:
            raise ValueError("n_head must be positive")
        if self.d_model <= 0:
            raise ValueError("d_model must be positive")
        if self.d_model % self.n_head != 0:
            raise ValueError("d_model must be divisible by n_head")
        if self.context < 2:
            raise ValueError("context must be >= 2")
        if self.position_mode not in {"alibi", "learned"}:
            raise ValueError("position_mode must be 'alibi' or 'learned'")
        if self.tokenizer not in {"byte", "char"}:
            raise ValueError("tokenizer must be 'byte' or 'char'")
        if not (0.0 < self.train_ratio < 1.0):
            raise ValueError("train_ratio must be in (0, 1)")

# -----------------------------
# Tokenizers
# -----------------------------

class ByteTokenizer:
    bos_id = 256
    vocab_size = 257
    meta = {"type": "byte", "bos_id": 256, "vocab_size": 257}

    def encode(self, text: str) -> List[int]:
        return list(text.encode("utf-8", errors="replace"))

    def decode(self, tokens: Sequence[int]) -> str:
        data = bytes(t for t in tokens if 0 <= t <= 255)
        return data.decode("utf-8", errors="replace")

    def bos(self) -> int:
        return self.bos_id

    def to_meta(self) -> Dict[str, object]:
        return dict(self.meta)

    @staticmethod
    def from_meta(meta: Dict[str, object]) -> "ByteTokenizer":
        _ = meta
        return ByteTokenizer()

class CharTokenizer:
    def __init__(self, chars: Sequence[str]):
        self.chars = list(chars)
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        self.bos_id = len(self.chars)
        self.vocab_size = len(self.chars) + 1

    def encode(self, text: str) -> List[int]:
        out = []
        for ch in text:
            if ch not in self.stoi:
                raise ValueError(f"character {ch!r} not present in tokenizer vocabulary")
            out.append(self.stoi[ch])
        return out

    def decode(self, tokens: Sequence[int]) -> str:
        pieces = []
        for token in tokens:
            if token == self.bos_id:
                continue
            pieces.append(self.itos[token])
        return "".join(pieces)

    def bos(self) -> int:
        return self.bos_id

    def to_meta(self) -> Dict[str, object]:
        return {
            "type": "char",
            "chars": self.chars,
            "bos_id": self.bos_id,
            "vocab_size": self.vocab_size,
        }

    @staticmethod
    def from_meta(meta: Dict[str, object]) -> "CharTokenizer":
        chars = meta.get("chars")
        if not isinstance(chars, list):
            raise ValueError("invalid char tokenizer metadata")
        return CharTokenizer(chars)

def build_tokenizer(tokenizer_mode: str, docs: Sequence[str]):
    if tokenizer_mode == "byte":
        return ByteTokenizer()
    chars = sorted(set("".join(docs)))
    if not chars:
        raise ValueError("cannot build char tokenizer from an empty dataset")
    return CharTokenizer(chars)

def tokenizer_from_meta(meta: Dict[str, object]):
    tokenizer_type = meta.get("type")
    if tokenizer_type == "byte":
        return ByteTokenizer.from_meta(meta)
    if tokenizer_type == "char":
        return CharTokenizer.from_meta(meta)
    raise ValueError(f"unsupported tokenizer metadata type: {tokenizer_type!r}")

# -----------------------------
# Dataset
# -----------------------------

DEFAULT_NAMES_URL = "https://raw.githubusercontent.com/karpathy/makemore/988aa59/names.txt"
DEFAULT_FALLBACK_DOCS = [
    "emma", "olivia", "ava", "isabella", "sophia",
    "mia", "charlotte", "amelia", "harper", "evelyn",
    "liam", "noah", "oliver", "elijah", "james",
    "william", "benjamin", "lucas", "henry", "theodore",
]

def ensure_input_file(path: str) -> str:
    if os.path.exists(path):
        return path
    if os.path.basename(path) == "input.txt":
        print("input.txt not found; downloading the original tiny names dataset...", file=sys.stderr)
        try:
            urllib.request.urlretrieve(DEFAULT_NAMES_URL, path)
        except Exception as exc:  # pragma: no cover - network availability is environment-specific
            print(f"download failed ({exc}); writing embedded fallback dataset instead", file=sys.stderr)
            with open(path, "w", encoding="utf-8") as handle:
                handle.write("\n".join(DEFAULT_FALLBACK_DOCS) + "\n")
        return path
    raise FileNotFoundError(path)

def load_docs(path: str) -> List[str]:
    ensure_input_file(path)
    with open(path, "r", encoding="utf-8") as handle:
        docs = [line.strip("\n\r") for line in handle]
    docs = [doc for doc in docs if doc]
    if not docs:
        raise ValueError("dataset is empty after stripping blank lines")
    return docs

class DocumentDataset:
    def __init__(self, docs: Sequence[str], tokenizer, context: int, seed: int = 1337):
        if context < 2:
            raise ValueError("context must be >= 2")
        self.docs = list(docs)
        self.tokenizer = tokenizer
        self.context = context
        self.rng = random.Random(seed)
        self.encoded_docs: List[List[int]] = []
        bos = tokenizer.bos()
        for doc in self.docs:
            toks = [bos] + tokenizer.encode(doc) + [bos]
            if len(toks) >= 2:
                self.encoded_docs.append(toks)
        if not self.encoded_docs:
            raise ValueError("no valid documents after tokenization")

    def sample(self, num_docs: int = 1) -> Tuple[List[int], List[int], int]:
        if num_docs <= 0:
            raise ValueError("num_docs must be positive")
        inputs: List[int] = []
        targets: List[int] = []
        tokens_used = 0
        for _ in range(num_docs):
            seq = self.rng.choice(self.encoded_docs)
            if len(seq) <= self.context + 1:
                chunk = seq
            else:
                start = self.rng.randint(0, len(seq) - (self.context + 1))
                chunk = seq[start:start + self.context + 1]
            chunk_inputs = chunk[:-1]
            chunk_targets = chunk[1:]
            if len(chunk_inputs) < 1:
                continue
            inputs.extend(chunk_inputs)
            targets.extend(chunk_targets)
            tokens_used += len(chunk_inputs)
        if not inputs or not targets:
            raise RuntimeError("failed to sample a non-empty training sequence")
        return inputs, targets, tokens_used

    def split(self, train_ratio: float) -> Tuple["DocumentDataset", "DocumentDataset"]:
        n = len(self.docs)
        pivot = max(1, min(n - 1, int(round(n * train_ratio))))
        train_docs = self.docs[:pivot]
        val_docs = self.docs[pivot:]
        return (
            DocumentDataset(train_docs, self.tokenizer, self.context, seed=self.rng.randint(0, 2**31 - 1)),
            DocumentDataset(val_docs, self.tokenizer, self.context, seed=self.rng.randint(0, 2**31 - 1)),
        )

# -----------------------------
# Parameter store
# -----------------------------

class ParameterStore:
    def __init__(self) -> None:
        self.specs: Dict[str, TensorSpec] = {}
        self.param = array("f")
        self.grad = array("f")
        self.m1 = array("f")
        self.m2 = array("f")

    def register(self, name: str, shape: Sequence[int], init: str, rng: random.Random, std: float, value: float = 0.0) -> TensorSpec:
        shape = tuple(int(x) for x in shape)
        size = prod(shape)
        offset = len(self.param)
        if name in self.specs:
            raise ValueError(f"duplicate parameter name: {name}")
        if init == "normal":
            data = [rng.gauss(0.0, std) for _ in range(size)]
        elif init == "zeros":
            data = [0.0 for _ in range(size)]
        elif init == "constant":
            data = [float(value) for _ in range(size)]
        else:
            raise ValueError(f"unknown init mode: {init}")
        self.param.extend(data)
        self.grad.extend([0.0] * size)
        self.m1.extend([0.0] * size)
        self.m2.extend([0.0] * size)
        spec = TensorSpec(name=name, shape=shape, offset=offset, size=size)
        self.specs[name] = spec
        return spec

    def spec(self, name: str) -> TensorSpec:
        try:
            return self.specs[name]
        except KeyError as exc:
            raise KeyError(f"parameter not found: {name}") from exc

    def zero_grad(self) -> None:
        for i in range(len(self.grad)):
            self.grad[i] = 0.0

    def state_dict(self) -> Dict[str, object]:
        return {
            "specs": {name: asdict(spec) for name, spec in self.specs.items()},
            "param_bytes": self.param.tobytes(),
            "m1_bytes": self.m1.tobytes(),
            "m2_bytes": self.m2.tobytes(),
        }

    @staticmethod
    def from_state(state: Dict[str, object]) -> "ParameterStore":
        store = ParameterStore()
        specs_raw = state["specs"]
        if not isinstance(specs_raw, dict):
            raise ValueError("invalid checkpoint: missing specs")
        store.specs = {name: TensorSpec(**spec) for name, spec in specs_raw.items()}
        store.param = array("f")
        store.param.frombytes(state["param_bytes"])
        store.grad = array("f", [0.0] * len(store.param))
        store.m1 = array("f")
        store.m1.frombytes(state["m1_bytes"])
        store.m2 = array("f")
        store.m2.frombytes(state["m2_bytes"])
        return store
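
# Layout note (a hedged illustration, not part of the module): all tensors
# live back-to-back in one flat float32 buffer. After
#
#   store = ParameterStore()
#   store.register("a", (2, 3), "zeros", random.Random(0), 0.02)
#   store.register("b", (4,), "zeros", random.Random(0), 0.02)
#
# store.spec("a") covers offsets 0..5 and store.spec("b") offsets 6..9, with
# grad/m1/m2 sharing the same indexing.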

# -----------------------------
# Model
# -----------------------------

class MicroGPT2:
    def __init__(self, config: Config, vocab_size: int):
        config.validate()
        self.cfg = config
        self.vocab_size = int(vocab_size)
        self.d_head = self.cfg.d_model // self.cfg.n_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.rng = random.Random(self.cfg.seed)
        self.store = ParameterStore()
        self._build_parameters()
        self.alibi_slopes = self._make_alibi_slopes(self.cfg.n_head)

    def _build_parameters(self) -> None:
        c = self.cfg
        s = self.store
        s.register("wte", (self.vocab_size, c.d_model), "normal", self.rng, c.init_std)
        if not c.tie_embeddings:
            s.register("lm_head", (self.vocab_size, c.d_model), "normal", self.rng, c.init_std)
        if c.position_mode == "learned":
            s.register("wpe", (c.context, c.d_model), "normal", self.rng, c.init_std)
        for li in range(c.n_layer):
            s.register(f"layer{li}.w_qkv", (3 * c.d_model, c.d_model), "normal", self.rng, c.init_std)
            s.register(f"layer{li}.w_o", (c.d_model, c.d_model), "normal", self.rng, c.init_std)
            s.register(f"layer{li}.w1", (c.d_ff, c.d_model), "normal", self.rng, c.init_std)
            s.register(f"layer{li}.w2", (c.d_model, c.d_ff), "normal", self.rng, c.init_std)
            if c.qk_norm:
                s.register(f"layer{li}.attn_alpha", (c.n_head,), "constant", self.rng, c.init_std, value=1.0)
            if c.residual_gate:
                gate_init = 1.0 / math.sqrt(2.0 * max(1, c.n_layer))
                s.register(f"layer{li}.gamma_attn", (1,), "constant", self.rng, c.init_std, value=gate_init)
                s.register(f"layer{li}.gamma_mlp", (1,), "constant", self.rng, c.init_std, value=gate_init)

    @staticmethod
    def _make_alibi_slopes(n_head: int) -> List[float]:
        if n_head <= 0:
            raise ValueError("n_head must be positive")

        def get_slopes_power_of_2(power: int) -> List[float]:
            start = 2.0 ** (-(2.0 ** -(math.log2(power) - 3.0)))
            ratio = start
            return [start * (ratio ** i) for i in range(power)]

        if (n_head & (n_head - 1)) == 0:
            return get_slopes_power_of_2(n_head)
        lower = 2 ** int(math.floor(math.log2(n_head)))
        slopes = get_slopes_power_of_2(lower)
        extra = get_slopes_power_of_2(2 * lower)[0::2]
        slopes.extend(extra[: n_head - lower])
        return slopes
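
    # For the default n_head=4 this yields the geometric sequence
    # [2**-2, 2**-4, 2**-6, 2**-8] == [0.25, 0.0625, 0.015625, 0.00390625];
    # each head decays attention with distance at a different rate.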

    # ---------- low-level kernels ----------

    def _linear_fwd(self, x: List[float], w_spec: TensorSpec, rows: int, in_dim: int, out_dim: int) -> List[float]:
        p = self.store.param
        y = [0.0] * (rows * out_dim)
        woff = w_spec.offset
        for r in range(rows):
            xoff = r * in_dim
            yoff = r * out_dim
            for o in range(out_dim):
                base = woff + o * in_dim
                acc = 0.0
                for i in range(in_dim):
                    acc += x[xoff + i] * p[base + i]
                y[yoff + o] = acc
        return y

    def _linear_bwd(self, dout: List[float], x: List[float], w_spec: TensorSpec, rows: int, in_dim: int, out_dim: int) -> List[float]:
        p = self.store.param
        g = self.store.grad
        dx = [0.0] * (rows * in_dim)
        woff = w_spec.offset
        for r in range(rows):
            xoff = r * in_dim
            yoff = r * out_dim
            for o in range(out_dim):
                go = dout[yoff + o]
                if go == 0.0:
                    continue
                base = woff + o * in_dim
                for i in range(in_dim):
                    dx[xoff + i] += go * p[base + i]
                    g[base + i] += go * x[xoff + i]
        return dx

    def _rmsnorm_fwd(self, x: List[float], rows: int, dim: int, eps: float = 1e-5) -> Tuple[List[float], List[float]]:
        y = [0.0] * (rows * dim)
        inv = [0.0] * rows
        for r in range(rows):
            off = r * dim
            ms = 0.0
            for i in range(dim):
                xi = x[off + i]
                ms += xi * xi
            ms /= dim
            rinv = 1.0 / math.sqrt(ms + eps)
            inv[r] = rinv
            for i in range(dim):
                y[off + i] = x[off + i] * rinv
        return y, inv

    def _rmsnorm_bwd(self, dout: List[float], x: List[float], inv: List[float], rows: int, dim: int) -> List[float]:
        dx = [0.0] * (rows * dim)
        for r in range(rows):
            off = r * dim
            rinv = inv[r]
            dot = 0.0
            for i in range(dim):
                dot += dout[off + i] * x[off + i]
            coeff = (rinv ** 3) * dot / dim
            for i in range(dim):
                dx[off + i] = dout[off + i] * rinv - x[off + i] * coeff
        return dx
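
    # RMSNorm backward, for reference: with y_i = x_i * rinv and
    # rinv = (mean(x^2) + eps)^(-1/2), the chain rule gives
    #   dx_i = dout_i * rinv - x_i * rinv^3 * (sum_j dout_j * x_j) / dim,
    # which is exactly the `coeff` term above.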

    def _relu2_fwd(self, x: List[float]) -> Tuple[List[float], List[float]]:
        y = [0.0] * len(x)
        positive = [0.0] * len(x)
        for i, xv in enumerate(x):
            p = xv if xv > 0.0 else 0.0
            positive[i] = p
            y[i] = p * p
        return y, positive

    def _relu2_bwd(self, dout: List[float], positive: List[float]) -> List[float]:
        dx = [0.0] * len(dout)
        for i, go in enumerate(dout):
            p = positive[i]
            dx[i] = go * (2.0 * p)
        return dx
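
    # ReLU^2 derivative, for reference: d/dx relu(x)^2 = 2 * relu(x), so the
    # cached positive part is all the backward pass needs.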

    def _embedding_fwd(self, tokens: Sequence[int]) -> List[float]:
        c = self.cfg
        if c.position_mode == "learned" and len(tokens) > c.context:
            raise ValueError("learned position mode cannot process a sequence longer than configured context")
        spec = self.store.spec("wte")
        p = self.store.param
        out = [0.0] * (len(tokens) * c.d_model)
        for t, token in enumerate(tokens):
            if token < 0 or token >= self.vocab_size:
                raise ValueError(f"token {token} out of vocabulary range 0..{self.vocab_size - 1}")
            src = spec.offset + token * c.d_model
            dst = t * c.d_model
            for i in range(c.d_model):
                out[dst + i] = p[src + i]
        if c.position_mode == "learned":
            wpe = self.store.spec("wpe")
            for t in range(len(tokens)):
                src = wpe.offset + t * c.d_model
                dst = t * c.d_model
                for i in range(c.d_model):
                    out[dst + i] += p[src + i]
        return out

    def _embedding_bwd(self, tokens: Sequence[int], d_embed: List[float]) -> None:
        c = self.cfg
        g = self.store.grad
        wte = self.store.spec("wte")
        for t, token in enumerate(tokens):
            dst = wte.offset + token * c.d_model
            src = t * c.d_model
            for i in range(c.d_model):
                g[dst + i] += d_embed[src + i]
        if c.position_mode == "learned":
            wpe = self.store.spec("wpe")
            for t in range(len(tokens)):
                dst = wpe.offset + t * c.d_model
                src = t * c.d_model
                for i in range(c.d_model):
                    g[dst + i] += d_embed[src + i]

    def _qk_norm_fwd(self, x: List[float], rows: int, n_head: int, d_head: int, eps: float = 1e-5) -> Tuple[List[float], List[float]]:
        y = [0.0] * len(x)
        inv = [0.0] * (rows * n_head)
        for r in range(rows):
            for h in range(n_head):
                base = (r * n_head + h) * d_head
                ms = 0.0
                for i in range(d_head):
                    xv = x[base + i]
                    ms += xv * xv
                ms /= d_head
                rinv = 1.0 / math.sqrt(ms + eps)
                inv[r * n_head + h] = rinv
                for i in range(d_head):
                    y[base + i] = x[base + i] * rinv
        return y, inv

    def _qk_norm_bwd(self, dout: List[float], x: List[float], inv: List[float], rows: int, n_head: int, d_head: int) -> List[float]:
        dx = [0.0] * len(dout)
        for r in range(rows):
            for h in range(n_head):
                idx = r * n_head + h
                base = idx * d_head
                rinv = inv[idx]
                dot = 0.0
                for i in range(d_head):
                    dot += dout[base + i] * x[base + i]
                coeff = (rinv ** 3) * dot / d_head
                for i in range(d_head):
                    dx[base + i] = dout[base + i] * rinv - x[base + i] * coeff
        return dx

    def _split_qkv(self, qkv: List[float], rows: int) -> Tuple[List[float], List[float], List[float]]:
        d = self.cfg.d_model
        q = [0.0] * (rows * d)
        k = [0.0] * (rows * d)
        v = [0.0] * (rows * d)
        for r in range(rows):
            base_in = r * (3 * d)
            base_out = r * d
            for i in range(d):
                q[base_out + i] = qkv[base_in + i]
                k[base_out + i] = qkv[base_in + d + i]
                v[base_out + i] = qkv[base_in + 2 * d + i]
        return q, k, v

    def _merge_qkv_grads(self, dq: List[float], dk: List[float], dv: List[float], rows: int) -> List[float]:
        d = self.cfg.d_model
        out = [0.0] * (rows * 3 * d)
        for r in range(rows):
            base_out = r * 3 * d
            base_in = r * d
            for i in range(d):
                out[base_out + i] = dq[base_in + i]
                out[base_out + d + i] = dk[base_in + i]
                out[base_out + 2 * d + i] = dv[base_in + i]
        return out

    def _attention_fwd(self, q: List[float], k: List[float], v: List[float], rows: int, layer_idx: int) -> Tuple[List[float], List[float]]:
        c = self.cfg
        out = [0.0] * (rows * c.d_model)
        probs = [0.0] * (c.n_head * rows * rows)
        alpha = None
        if c.qk_norm:
            alpha = self.store.spec(f"layer{layer_idx}.attn_alpha")
        p = self.store.param
        for h in range(c.n_head):
            slope = self.alibi_slopes[h] if c.position_mode == "alibi" else 0.0
            # alpha_h is the learned per-head gain under QK-Norm, and the
            # usual 1/sqrt(d_head) scale otherwise; either way it is applied
            # exactly once, matching _attention_bwd.
            alpha_h = p[alpha.offset + h] if alpha is not None else self.scale
            for t in range(rows):
                logits: List[float] = []
                max_logit = -1e30
                qbase = (t * c.n_head + h) * self.d_head
                for tau in range(t + 1):
                    kbase = (tau * c.n_head + h) * self.d_head
                    dot = 0.0
                    for i in range(self.d_head):
                        dot += q[qbase + i] * k[kbase + i]
                    logit = alpha_h * dot
                    if c.position_mode == "alibi":
                        logit += slope * (tau - t)
                    logits.append(logit)
                    if logit > max_logit:
                        max_logit = logit
                exps = [math.exp(val - max_logit) for val in logits]
                denom = sum(exps)
                if denom <= 0.0:
                    raise FloatingPointError("attention softmax denominator is non-positive")
                obase = t * c.d_model + h * self.d_head
                pbase = h * rows * rows + t * rows
                for tau, ex in enumerate(exps):
                    prob = ex / denom
                    probs[pbase + tau] = prob
                    vbase = (tau * c.n_head + h) * self.d_head
                    for i in range(self.d_head):
                        out[obase + i] += prob * v[vbase + i]
        return out, probs

    def _attention_bwd(self, dout: List[float], probs: List[float], q: List[float], k: List[float], v: List[float], rows: int, layer_idx: int) -> Tuple[List[float], List[float], List[float]]:
        c = self.cfg
        dq = [0.0] * (rows * c.d_model)
        dk = [0.0] * (rows * c.d_model)
        dv = [0.0] * (rows * c.d_model)
        alpha_spec = self.store.spec(f"layer{layer_idx}.attn_alpha") if c.qk_norm else None
        p = self.store.param
        g = self.store.grad
        for h in range(c.n_head):
            alpha_h = p[alpha_spec.offset + h] if alpha_spec is not None else self.scale
            dprob = [0.0] * (rows * rows)
            for t in range(rows):
                obase = t * c.d_model + h * self.d_head
                for tau in range(t + 1):
                    pidx = h * rows * rows + t * rows + tau
                    prob = probs[pidx]
                    vbase = (tau * c.n_head + h) * self.d_head
                    dot = 0.0
                    for i in range(self.d_head):
                        go = dout[obase + i]
                        dot += go * v[vbase + i]
                        dv[vbase + i] += prob * go
                    dprob[t * rows + tau] = dot
            ds = [0.0] * (rows * rows)
            for t in range(rows):
                row_dot = 0.0
                for tau in range(t + 1):
                    pidx = h * rows * rows + t * rows + tau
                    prob = probs[pidx]
                    row_dot += prob * dprob[t * rows + tau]
                for tau in range(t + 1):
                    pidx = h * rows * rows + t * rows + tau
                    prob = probs[pidx]
                    ds_val = prob * (dprob[t * rows + tau] - row_dot)
                    ds[t * rows + tau] = ds_val
            for t in range(rows):
                qbase = (t * c.n_head + h) * self.d_head
                for tau in range(t + 1):
                    kval = ds[t * rows + tau]
                    if kval == 0.0:
                        continue
                    kbase = (tau * c.n_head + h) * self.d_head
                    dot_qk = 0.0
                    for i in range(self.d_head):
                        qv = q[qbase + i]
                        kv = k[kbase + i]
                        dq[qbase + i] += kval * alpha_h * kv
                        dk[kbase + i] += kval * alpha_h * qv
                        dot_qk += qv * kv
                    if alpha_spec is not None:
                        g[alpha_spec.offset + h] += kval * dot_qk
        return dq, dk, dv
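
    # Softmax backward, for reference: for one attention row with
    # probabilities p and upstream dprob, the Jacobian-vector product is
    #   ds_tau = p_tau * (dprob_tau - sum_j p_j * dprob_j),
    # which is the `row_dot` subtraction above; the ALiBi bias is
    # position-only and contributes no parameter gradient.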

    def _lm_head_spec(self) -> TensorSpec:
        if self.cfg.tie_embeddings:
            return self.store.spec("wte")
        return self.store.spec("lm_head")

    def _lm_head_fwd(self, x: List[float], rows: int) -> List[float]:
        head = self._lm_head_spec()
        p = self.store.param
        d = self.cfg.d_model
        logits = [0.0] * (rows * self.vocab_size)
        for r in range(rows):
            xoff = r * d
            loff = r * self.vocab_size
            for token in range(self.vocab_size):
                woff = head.offset + token * d
                acc = 0.0
                for i in range(d):
                    acc += x[xoff + i] * p[woff + i]
                logits[loff + token] = acc
        return logits

    def _lm_head_bwd(self, dlogits: List[float], x: List[float], rows: int) -> List[float]:
        d = self.cfg.d_model
        head = self._lm_head_spec()
        p = self.store.param
        g = self.store.grad
        dx = [0.0] * (rows * d)
        for r in range(rows):
            xoff = r * d
            loff = r * self.vocab_size
            for token in range(self.vocab_size):
                go = dlogits[loff + token]
                if go == 0.0:
                    continue
                woff = head.offset + token * d
                for i in range(d):
                    dx[xoff + i] += go * p[woff + i]
                    g[woff + i] += go * x[xoff + i]
        return dx

    def _cross_entropy(self, logits: List[float], targets: Sequence[int], rows: int) -> Tuple[float, List[float]]:
        if len(targets) != rows:
            raise ValueError("target length must match number of rows")
        dlogits = [0.0] * len(logits)
        loss = 0.0
        inv_rows = 1.0 / rows
        for r in range(rows):
            base = r * self.vocab_size
            max_logit = max(logits[base:base + self.vocab_size])
            denom = 0.0
            for i in range(self.vocab_size):
                denom += math.exp(logits[base + i] - max_logit)
            log_denom = math.log(denom)
            target = targets[r]
            loss += -(logits[base + target] - max_logit - log_denom)
            for i in range(self.vocab_size):
                prob = math.exp(logits[base + i] - max_logit - log_denom)
                dlogits[base + i] = prob * inv_rows
            dlogits[base + target] -= inv_rows
        return loss * inv_rows, dlogits
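
    # Fused softmax + cross-entropy gradient, for reference:
    #   dlogits = (softmax(logits) - onehot(target)) / rows,
    # so no separate softmax backward pass is needed at the output.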

    # ---------- forward/backward ----------

    def forward_backward(self, tokens: Sequence[int], targets: Sequence[int]) -> float:
        if len(tokens) != len(targets):
            raise ValueError("tokens and targets must have identical lengths")
        if not tokens:
            raise ValueError("tokens must be non-empty")
        T = len(tokens)
        c = self.cfg
        store = self.store
        store.zero_grad()
        h = self._embedding_fwd(tokens)
        caches: List[Dict[str, object]] = []
        for li in range(c.n_layer):
            layer_cache: Dict[str, object] = {"x_in": h[:]}
            x_norm1, inv1 = self._rmsnorm_fwd(h, T, c.d_model)
            layer_cache["x_norm1"] = x_norm1
            layer_cache["inv1"] = inv1
            qkv = self._linear_fwd(x_norm1, store.spec(f"layer{li}.w_qkv"), T, c.d_model, 3 * c.d_model)
            q, k, v = self._split_qkv(qkv, T)
            layer_cache["q_raw"] = q
            layer_cache["k_raw"] = k
            layer_cache["v_raw"] = v
            if c.qk_norm:
                qn, q_inv = self._qk_norm_fwd(q, T, c.n_head, self.d_head)
                kn, k_inv = self._qk_norm_fwd(k, T, c.n_head, self.d_head)
            else:
                qn, kn = q, k
                q_inv, k_inv = [], []
            layer_cache["q"] = qn
            layer_cache["k"] = kn
            layer_cache["q_inv"] = q_inv
            layer_cache["k_inv"] = k_inv
            attn_cat, probs = self._attention_fwd(qn, kn, v, T, li)
            layer_cache["attn_cat"] = attn_cat
            layer_cache["probs"] = probs
            attn_proj = self._linear_fwd(attn_cat, store.spec(f"layer{li}.w_o"), T, c.d_model, c.d_model)
            layer_cache["attn_proj"] = attn_proj
            gamma_attn = store.param[store.spec(f"layer{li}.gamma_attn").offset] if c.residual_gate else 1.0
            h_attn = [h[i] + gamma_attn * attn_proj[i] for i in range(len(h))]
            layer_cache["h_attn"] = h_attn
            layer_cache["gamma_attn"] = gamma_attn
            x_norm2, inv2 = self._rmsnorm_fwd(h_attn, T, c.d_model)
            layer_cache["x_norm2"] = x_norm2
            layer_cache["inv2"] = inv2
            mlp_pre = self._linear_fwd(x_norm2, store.spec(f"layer{li}.w1"), T, c.d_model, c.d_ff)
            layer_cache["mlp_pre"] = mlp_pre
            if c.relu2:
                mlp_act, positive = self._relu2_fwd(mlp_pre)
                layer_cache["mlp_pos"] = positive
            else:
                raise NotImplementedError("only relu2=True is implemented")
            layer_cache["mlp_act"] = mlp_act
            mlp_proj = self._linear_fwd(mlp_act, store.spec(f"layer{li}.w2"), T, c.d_ff, c.d_model)
            layer_cache["mlp_proj"] = mlp_proj
            gamma_mlp = store.param[store.spec(f"layer{li}.gamma_mlp").offset] if c.residual_gate else 1.0
            h = [h_attn[i] + gamma_mlp * mlp_proj[i] for i in range(len(h_attn))]
            layer_cache["gamma_mlp"] = gamma_mlp
            layer_cache["x_out"] = h[:]
            caches.append(layer_cache)
        if c.final_norm:
            hf, final_inv = self._rmsnorm_fwd(h, T, c.d_model)
        else:
            hf, final_inv = h[:], []
        logits = self._lm_head_fwd(hf, T)
        loss, dlogits = self._cross_entropy(logits, targets, T)
        dh = self._lm_head_bwd(dlogits, hf, T)
        if c.final_norm:
            dh = self._rmsnorm_bwd(dh, h, final_inv, T, c.d_model)
        for li in reversed(range(c.n_layer)):
            cache = caches[li]
            gamma_mlp = cache["gamma_mlp"]
            gamma_attn = cache["gamma_attn"]
            d_h_attn = dh[:]
            d_mlp_proj = [val * gamma_mlp for val in dh]
            if c.residual_gate:
                gamma_spec = store.spec(f"layer{li}.gamma_mlp")
                gate_grad = 0.0
                mlp_proj = cache["mlp_proj"]
                for i in range(len(dh)):
                    gate_grad += dh[i] * mlp_proj[i]
                store.grad[gamma_spec.offset] += gate_grad
            d_mlp_act = self._linear_bwd(d_mlp_proj, cache["mlp_act"], store.spec(f"layer{li}.w2"), T, c.d_ff, c.d_model)
            d_mlp_pre = self._relu2_bwd(d_mlp_act, cache["mlp_pos"])
            d_x_norm2 = self._linear_bwd(d_mlp_pre, cache["x_norm2"], store.spec(f"layer{li}.w1"), T, c.d_model, c.d_ff)
            d_h_attn_norm = self._rmsnorm_bwd(d_x_norm2, cache["h_attn"], cache["inv2"], T, c.d_model)
            d_h_attn = [d_h_attn[i] + d_h_attn_norm[i] for i in range(len(d_h_attn))]
            d_h_in = d_h_attn[:]
            d_attn_proj = [val * gamma_attn for val in d_h_attn]
            if c.residual_gate:
                gamma_spec = store.spec(f"layer{li}.gamma_attn")
                gate_grad = 0.0
                attn_proj = cache["attn_proj"]
                for i in range(len(d_h_attn)):
                    gate_grad += d_h_attn[i] * attn_proj[i]
                store.grad[gamma_spec.offset] += gate_grad
            d_attn_cat = self._linear_bwd(d_attn_proj, cache["attn_cat"], store.spec(f"layer{li}.w_o"), T, c.d_model, c.d_model)
            dq, dk, dv = self._attention_bwd(d_attn_cat, cache["probs"], cache["q"], cache["k"], cache["v_raw"], T, li)
            if c.qk_norm:
                dq = self._qk_norm_bwd(dq, cache["q_raw"], cache["q_inv"], T, c.n_head, self.d_head)
                dk = self._qk_norm_bwd(dk, cache["k_raw"], cache["k_inv"], T, c.n_head, self.d_head)
            dqkv = self._merge_qkv_grads(dq, dk, dv, T)
            d_x_norm1 = self._linear_bwd(dqkv, cache["x_norm1"], store.spec(f"layer{li}.w_qkv"), T, c.d_model, 3 * c.d_model)
            d_h_norm = self._rmsnorm_bwd(d_x_norm1, cache["x_in"], cache["inv1"], T, c.d_model)
            dh = [d_h_in[i] + d_h_norm[i] for i in range(len(d_h_in))]
        self._embedding_bwd(tokens, dh)
        return loss

    # ---------- optimizer / training ----------

    def clip_grads(self, clip: float) -> float:
        if clip <= 0.0:
            return 0.0
        total = 0.0
        for val in self.store.grad:
            total += val * val
        total = math.sqrt(total)
        if total > clip:
            scale = clip / (total + 1e-12)
            for i in range(len(self.store.grad)):
                self.store.grad[i] *= scale
        return total

    def adamw_step(self, step_idx: int) -> float:
        c = self.cfg
        if step_idx < c.warmup_steps:
            lr = c.lr * (step_idx + 1) / max(1, c.warmup_steps)
        else:
            denom = max(1, c.steps - c.warmup_steps)
            progress = min(1.0, max(0.0, (step_idx - c.warmup_steps) / denom))
            lr = c.lr + (c.min_lr - c.lr) * progress
        p = self.store.param
        g = self.store.grad
        m1 = self.store.m1
        m2 = self.store.m2
        beta1 = c.beta1
        beta2 = c.beta2
        eps = c.adam_eps
        t = step_idx + 1
        bc1 = 1.0 - beta1 ** t
        bc2 = 1.0 - beta2 ** t
        inv_bc1 = 1.0 / bc1
        inv_bc2 = 1.0 / bc2
        for i in range(len(p)):
            gi = g[i]
            m1[i] = beta1 * m1[i] + (1.0 - beta1) * gi
            m2[i] = beta2 * m2[i] + (1.0 - beta2) * gi * gi
            mhat = m1[i] * inv_bc1
            vhat = m2[i] * inv_bc2
            update = mhat / (math.sqrt(vhat) + eps)
            if c.weight_decay != 0.0:
                update += c.weight_decay * p[i]
            p[i] -= lr * update
            g[i] = 0.0
        return lr
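
    # Schedule, for reference: lr ramps linearly from lr/warmup_steps to lr
    # over the first warmup_steps updates, then decays linearly to min_lr by
    # the final step. With decoupled weight decay this is the standard AdamW
    # update: p -= lr * (m_hat / (sqrt(v_hat) + eps) + wd * p).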

    def evaluate(self, dataset: DocumentDataset, batches: int = 4) -> float:
        losses = []
        for _ in range(batches):
            tokens, targets, _ = dataset.sample(num_docs=max(1, self.cfg.batch_docs))
            loss = self.forward_backward(tokens, targets)
            losses.append(loss)
        self.store.zero_grad()
        return sum(losses) / len(losses)

    # ---------- inference ----------

    def _embed_token_step(self, token: int, pos: int) -> List[float]:
        if token < 0 or token >= self.vocab_size:
            raise ValueError(f"token {token} out of vocabulary range 0..{self.vocab_size - 1}")
        if self.cfg.position_mode == "learned" and pos >= self.cfg.context:
            raise ValueError("learned position mode cannot decode beyond configured context length")
        c = self.cfg
        p = self.store.param
        wte = self.store.spec("wte")
        out = [0.0] * c.d_model
        src = wte.offset + token * c.d_model
        for i in range(c.d_model):
            out[i] = p[src + i]
        if c.position_mode == "learned":
            wpe = self.store.spec("wpe")
            src = wpe.offset + pos * c.d_model
            for i in range(c.d_model):
                out[i] += p[src + i]
        return out

    def init_kv_cache(self) -> Dict[str, object]:
        return {
            "keys": [[] for _ in range(self.cfg.n_layer)],
            "values": [[] for _ in range(self.cfg.n_layer)],
            "positions": [],
            "next_pos": 0,
        }

    def decode_step(self, token: int, cache: Dict[str, object]) -> List[float]:
        c = self.cfg
        pos = int(cache["next_pos"])
        if len(cache["positions"]) >= c.context:
            # sliding window: evict the oldest cached position from every layer
            cache["positions"].pop(0)
            for li in range(c.n_layer):
                cache["keys"][li].pop(0)
                cache["values"][li].pop(0)
        current_positions = list(cache["positions"]) + [pos]
        x = self._embed_token_step(token, pos)
        for li in range(c.n_layer):
            x_norm1, _ = self._rmsnorm_fwd(x, 1, c.d_model)
            qkv = self._linear_fwd(x_norm1, self.store.spec(f"layer{li}.w_qkv"), 1, c.d_model, 3 * c.d_model)
            q, k, v = self._split_qkv(qkv, 1)
            if c.qk_norm:
                q, _ = self._qk_norm_fwd(q, 1, c.n_head, self.d_head)
                k, _ = self._qk_norm_fwd(k, 1, c.n_head, self.d_head)
            cache["keys"][li].append(k[:])
            cache["values"][li].append(v[:])
            attn_cat = [0.0] * c.d_model
            alpha_spec = self.store.spec(f"layer{li}.attn_alpha") if c.qk_norm else None
            for h in range(c.n_head):
                slope = self.alibi_slopes[h] if c.position_mode == "alibi" else 0.0
                # as in _attention_fwd: learned gain under QK-Norm, else the
                # 1/sqrt(d_head) scale, applied exactly once
                alpha_h = self.store.param[alpha_spec.offset + h] if alpha_spec is not None else self.scale
                qbase = h * self.d_head
                logits = []
                max_logit = -1e30
                for tau_idx, (k_tok, _v_tok, tau_pos) in enumerate(zip(cache["keys"][li], cache["values"][li], current_positions)):
                    kbase = h * self.d_head
                    dot = 0.0
                    for i in range(self.d_head):
                        dot += q[qbase + i] * k_tok[kbase + i]
                    logit = alpha_h * dot
                    if c.position_mode == "alibi":
                        logit += slope * (tau_pos - pos)
                    logits.append((tau_idx, logit))
                    if logit > max_logit:
                        max_logit = logit
                exps = [math.exp(logit - max_logit) for _idx, logit in logits]
                denom = sum(exps)
                if denom <= 0.0:
                    raise FloatingPointError("attention denominator became non-positive during decode")
                for ex, (tau_idx, _logit) in zip(exps, logits):
                    prob = ex / denom
                    v_tok = cache["values"][li][tau_idx]
                    obase = h * self.d_head
                    vbase = h * self.d_head
                    for i in range(self.d_head):
                        attn_cat[obase + i] += prob * v_tok[vbase + i]
            attn_proj = self._linear_fwd(attn_cat, self.store.spec(f"layer{li}.w_o"), 1, c.d_model, c.d_model)
            gamma_attn = self.store.param[self.store.spec(f"layer{li}.gamma_attn").offset] if c.residual_gate else 1.0
            x_attn = [x[i] + gamma_attn * attn_proj[i] for i in range(c.d_model)]
            x_norm2, _ = self._rmsnorm_fwd(x_attn, 1, c.d_model)
            mlp_pre = self._linear_fwd(x_norm2, self.store.spec(f"layer{li}.w1"), 1, c.d_model, c.d_ff)
            mlp_act, _ = self._relu2_fwd(mlp_pre)
            mlp_proj = self._linear_fwd(mlp_act, self.store.spec(f"layer{li}.w2"), 1, c.d_ff, c.d_model)
            gamma_mlp = self.store.param[self.store.spec(f"layer{li}.gamma_mlp").offset] if c.residual_gate else 1.0
            x = [x_attn[i] + gamma_mlp * mlp_proj[i] for i in range(c.d_model)]
        if c.final_norm:
            x, _ = self._rmsnorm_fwd(x, 1, c.d_model)
        logits = self._lm_head_fwd(x, 1)
        cache["positions"].append(pos)
        cache["next_pos"] = pos + 1
        return logits

    def prefill_cache(self, tokens: Sequence[int], cache: Optional[Dict[str, object]] = None) -> Tuple[Dict[str, object], List[float]]:
        if cache is None:
            cache = self.init_kv_cache()
        last_logits: Optional[List[float]] = None
        for token in tokens:
            last_logits = self.decode_step(token, cache)
        if last_logits is None:
            raise ValueError("prefill requires at least one token")
        return cache, last_logits

    def _forward_inference(self, tokens: Sequence[int]) -> List[float]:
        cache, logits = self.prefill_cache(tokens, self.init_kv_cache())
        _ = cache
        return logits

    def sample(self, tokenizer, prompt: str, max_new_tokens: int, temperature: float = 1.0, top_k: int = 0) -> str:
        if temperature <= 0.0:
            raise ValueError("temperature must be positive")
        bos = tokenizer.bos()
        prompt_tokens = [bos] + tokenizer.encode(prompt)
        if len(prompt_tokens) > self.cfg.context:
            prompt_tokens = prompt_tokens[-self.cfg.context:]
        rng = random.Random(self.cfg.seed)
        cache, last_logits = self.prefill_cache(prompt_tokens, self.init_kv_cache())
        tokens = list(prompt_tokens)
        for _ in range(max_new_tokens):
            next_token = sample_from_logits(last_logits, rng=rng, temperature=temperature, top_k=top_k)
            tokens.append(next_token)
            last_logits = self.decode_step(next_token, cache)
        return tokenizer.decode(tokens[1:])

    # ---------- checkpointing ----------

    def save(self, path: str, tokenizer_meta: Dict[str, object], extra: Optional[Dict[str, object]] = None) -> None:
        payload = {
            "config": asdict(self.cfg),
            "vocab_size": self.vocab_size,
            "tokenizer": tokenizer_meta,
            "store": self.store.state_dict(),
            "extra": extra or {},
        }
        atomic_write_bytes(path, pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL))

    @staticmethod
    def load(path: str) -> Tuple["MicroGPT2", object, Dict[str, object]]:
        with open(path, "rb") as handle:
            payload = pickle.load(handle)
        cfg = Config(**payload["config"])
        model = MicroGPT2(cfg, int(payload["vocab_size"]))
        model.store = ParameterStore.from_state(payload["store"])
        tokenizer = tokenizer_from_meta(payload["tokenizer"])
        extra = payload.get("extra", {})
        return model, tokenizer, extra

# -----------------------------
# Sampling helper
# -----------------------------

def sample_from_logits(logits: Sequence[float], rng: Optional[random.Random] = None, temperature: float = 1.0, top_k: int = 0) -> int:
    if rng is None:
        rng = random
    if temperature <= 0.0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    if top_k > 0:
        top_k = max(1, min(top_k, len(scaled)))
        top_idx = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
        mask = set(top_idx)
        floor = min(scaled)
        scaled = [val if i in mask else floor - 1e9 for i, val in enumerate(scaled)]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    if total <= 0.0:
        return int(max(range(len(logits)), key=lambda i: logits[i]))
    threshold = rng.random() * total
    running = 0.0
    for i, ex in enumerate(exps):
        running += ex
        if running >= threshold:
            return i
    return len(logits) - 1

# -----------------------------
# CLI helpers
# -----------------------------

def build_model_and_data(args) -> Tuple[MicroGPT2, object, DocumentDataset, DocumentDataset]:
    docs = load_docs(args.input)
    rng = random.Random(args.seed)
    rng.shuffle(docs)
    cfg = Config(
        seed=args.seed,
        tokenizer=args.tokenizer,
        context=args.context,
        n_layer=args.n_layer,
        d_model=args.d_model,
        n_head=args.n_head,
        d_ff=args.d_ff,
        position_mode=args.position_mode,
        qk_norm=bool(args.qk_norm),
        relu2=True,
        tie_embeddings=True,
        final_norm=bool(args.final_norm),
        residual_gate=bool(args.residual_gate),
        init_std=args.init_std,
        lr=args.lr,
        min_lr=args.min_lr,
        weight_decay=args.weight_decay,
        beta1=args.beta1,
        beta2=args.beta2,
        adam_eps=args.adam_eps,
        grad_clip=args.grad_clip,
        batch_docs=args.batch_docs,
        steps=args.steps,
        warmup_steps=args.warmup_steps,
        save_every=args.save_every,
        eval_every=args.eval_every,
        train_ratio=args.train_ratio,
        temperature=args.temperature,
        top_k=args.top_k,
        sample_tokens=args.sample_tokens,
    )
    tokenizer = build_tokenizer(cfg.tokenizer, docs)
    dataset = DocumentDataset(docs, tokenizer, cfg.context, seed=cfg.seed)
    train_dataset, val_dataset = dataset.split(cfg.train_ratio)
    model = MicroGPT2(cfg, tokenizer.vocab_size)
    return model, tokenizer, train_dataset, val_dataset

def train_main(args) -> int:
    model, tokenizer, train_dataset, val_dataset = build_model_and_data(args)
    print(f"train docs: {len(train_dataset.docs)} | val docs: {len(val_dataset.docs)}")
    print(f"vocab size: {tokenizer.vocab_size} | params: {len(model.store.param)}")
    best_val = float("inf")
    history: List[Dict[str, float]] = []
    started = time.time()
    for step in range(model.cfg.steps):
        tokens, targets, tokens_used = train_dataset.sample(num_docs=max(1, model.cfg.batch_docs))
        loss = model.forward_backward(tokens, targets)
        grad_norm = model.clip_grads(model.cfg.grad_clip)
        lr = model.adamw_step(step)
        row = {
            "step": float(step + 1),
            "loss": float(loss),
            "grad_norm": float(grad_norm),
            "lr": float(lr),
            "tokens": float(tokens_used),
        }
        history.append(row)
        if (step + 1) % max(1, model.cfg.eval_every) == 0 or step == 0 or step + 1 == model.cfg.steps:
            val_loss = model.evaluate(val_dataset, batches=min(4, max(1, len(val_dataset.docs))))
            row["val_loss"] = float(val_loss)
            if val_loss < best_val:
                best_val = val_loss
                if args.out:
                    model.save(args.out, tokenizer.to_meta(), extra={"history": history, "best_val": best_val})
            elapsed = time.time() - started
            print(
                f"step {step + 1:5d}/{model.cfg.steps:5d} | train {loss:.4f} | val {val_loss:.4f} | "
                f"gnorm {grad_norm:.4f} | lr {lr:.6f} | tok {tokens_used} | {elapsed:.1f}s"
            )
        else:
            print(
                f"step {step + 1:5d}/{model.cfg.steps:5d} | train {loss:.4f} | "
                f"gnorm {grad_norm:.4f} | lr {lr:.6f} | tok {tokens_used}",
                end="\r",
                flush=True,
            )
        if args.out and ((step + 1) % max(1, model.cfg.save_every) == 0):
            model.save(args.out + ".latest", tokenizer.to_meta(), extra={"history": history, "best_val": best_val})
    print()
    if args.out:
        model.save(args.out, tokenizer.to_meta(), extra={"history": history, "best_val": best_val})
        with open(args.out + ".history.json", "w", encoding="utf-8") as handle:
            json.dump(history, handle, indent=2)
    return 0

def sample_main(args) -> int:
    model, tokenizer, extra = MicroGPT2.load(args.checkpoint)
    _ = extra
    text = model.sample(
        tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )
    print(text)
    return 0

def inspect_main(args) -> int:
    model, tokenizer, extra = MicroGPT2.load(args.checkpoint)
    summary = {
        "config": asdict(model.cfg),
        "vocab_size": model.vocab_size,
        "param_count": len(model.store.param),
        "tokenizer": tokenizer.to_meta(),
        "extra_keys": sorted(extra.keys()),
    }
    print(json.dumps(summary, indent=2))
    return 0

def make_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Dependency-free GPT-style language model")
    sub = parser.add_subparsers(dest="command", required=True)

    def add_common_train_flags(p: argparse.ArgumentParser) -> None:
        p.add_argument("--input", type=str, default="input.txt", help="line-delimited training text file")
        p.add_argument("--out", type=str, default="microgpt2.ckpt", help="checkpoint output path")
        p.add_argument("--seed", type=int, default=1337)
        p.add_argument("--tokenizer", type=str, default="byte", choices=["byte", "char"])
        p.add_argument("--context", type=int, default=128)
        p.add_argument("--n-layer", dest="n_layer", type=int, default=4)
        p.add_argument("--d-model", dest="d_model", type=int, default=128)
        p.add_argument("--n-head", dest="n_head", type=int, default=4)
        p.add_argument("--d-ff", dest="d_ff", type=int, default=512)
        p.add_argument("--position-mode", type=str, default="alibi", choices=["alibi", "learned"])
        p.add_argument("--qk-norm", type=int, default=1, choices=[0, 1])
        p.add_argument("--final-norm", type=int, default=1, choices=[0, 1])
        p.add_argument("--residual-gate", type=int, default=1, choices=[0, 1])
        p.add_argument("--init-std", dest="init_std", type=float, default=0.02)
        p.add_argument("--lr", type=float, default=2e-3)
        p.add_argument("--min-lr", dest="min_lr", type=float, default=2e-4)
        p.add_argument("--weight-decay", dest="weight_decay", type=float, default=0.01)
        p.add_argument("--beta1", type=float, default=0.9)
        p.add_argument("--beta2", type=float, default=0.95)
        p.add_argument("--adam-eps", dest="adam_eps", type=float, default=1e-8)
        p.add_argument("--grad-clip", dest="grad_clip", type=float, default=1.0)
        p.add_argument("--batch-docs", dest="batch_docs", type=int, default=1)
        p.add_argument("--steps", type=int, default=200)
        p.add_argument("--warmup-steps", dest="warmup_steps", type=int, default=20)
        p.add_argument("--save-every", dest="save_every", type=int, default=50)
        p.add_argument("--eval-every", dest="eval_every", type=int, default=25)
        p.add_argument("--train-ratio", dest="train_ratio", type=float, default=0.9)
        p.add_argument("--temperature", type=float, default=0.9)
        p.add_argument("--top-k", dest="top_k", type=int, default=0)
        p.add_argument("--sample-tokens", dest="sample_tokens", type=int, default=128)

    p_train = sub.add_parser("train", help="train a model")
    add_common_train_flags(p_train)
    p_sample = sub.add_parser("sample", help="sample from a checkpoint")
    p_sample.add_argument("--checkpoint", type=str, required=True)
    p_sample.add_argument("--prompt", type=str, default="")
    p_sample.add_argument("--tokens", type=int, default=128)
    p_sample.add_argument("--temperature", type=float, default=0.9)
    p_sample.add_argument("--top-k", type=int, default=0)
    p_inspect = sub.add_parser("inspect", help="inspect checkpoint metadata")
    p_inspect.add_argument("--checkpoint", type=str, required=True)
    return parser

def main(argv: Optional[Sequence[str]] = None) -> int:
    parser = make_parser()
    args = parser.parse_args(argv)
    if args.command == "train":
        return train_main(args)
    if args.command == "sample":
        return sample_main(args)
    if args.command == "inspect":
        return inspect_main(args)
    parser.error(f"unknown command: {args.command}")
    return 2

if __name__ == "__main__":
    raise SystemExit(main())
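
# A finite-difference gradient check (a hedged sketch, kept commented out so
# the module stays import-safe; names below mirror the classes above, and the
# tolerance is loose because the parameter buffer is float32):
#
#   cfg = Config(context=8, n_layer=1, d_model=8, n_head=2, d_ff=16, steps=1)
#   model = MicroGPT2(cfg, vocab_size=257)
#   toks, tgts = [256, 104, 105], [104, 105, 256]
#   model.forward_backward(toks, tgts)
#   g0 = model.store.grad[0]
#   eps = 1e-3
#   model.store.param[0] += eps
#   lp = model.forward_backward(toks, tgts)
#   model.store.param[0] -= 2 * eps
#   lm = model.forward_backward(toks, tgts)
#   # (lp - lm) / (2 * eps) should roughly match g0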
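Below is a second, heavily condensed variant of the same idea: one transformer block, byte-level tokens with a BOS id of 256, an ALiBi-style per-head distance penalty, and plain SGD with weight decay, all in about a hundred lines.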

#!/usr/bin/env python3
import argparse,math,os,pickle,random
R=random.Random(0)
def mat(o,i,s=.02): return [[(R.random()*2-1)*s for _ in range(i)] for _ in range(o)]
def zeros(o,i): return [[0.0]*i for _ in range(o)]
def mm(X,W): return [[sum(x[k]*w[k] for k in range(len(x))) for w in W] for x in X]
def add(A,B): return [[a+b for a,b in zip(x,y)] for x,y in zip(A,B)]
def tr(A): return list(map(list,zip(*A)))
def outer(A,B): return [[sum(a[r]*B[r][c] for r in range(len(a))) for c in range(len(B[0]))] for a in tr(A)]
def dxmm(D,W): return [[sum(d[j]*W[j][i] for j in range(len(W))) for i in range(len(W[0]))] for d in D]
def rms(X,eps=1e-5):
    r=[math.sqrt(sum(v*v for v in x)/len(x)+eps) for x in X]
    return [[v/ri for v in x] for x,ri in zip(X,r)],(X,r)
def drms(D,C):
    X,r=C; n=len(X[0]); O=[]
    for g,x,ri in zip(D,X,r):
        dot=sum(gi*xi for gi,xi in zip(g,x)); O.append([gi/ri-xi*dot/(n*ri**3) for gi,xi in zip(g,x)])
    return O
def relu2(X): return [[max(0.0,v)**2 for v in x] for x in X],X
def drelu2(D,X): return [[d*(2*max(0.0,v)) for d,v in zip(g,x)] for g,x in zip(D,X)]
def softmax(S):
    A=[]
    for s in S:
        m=max(s); e=[math.exp(v-m) for v in s]; z=sum(e); A.append([v/z for v in e])
    return A
def init(d=64,h=4,f=256,v=257,c=64):
    assert d%h==0
    return {'cfg':{'d':d,'h':h,'f':f,'v':v,'c':c},'E':mat(v,d),'B':[{'Wqkv':mat(3*d,d),'Wo':mat(d,d),'W1':mat(f,d),'W2':mat(d,f)}]}
def block(H,P,cfg):
    d,h,f=cfg['d'],cfg['h'],cfg['f']; T=len(H); dh=d//h
    U,c1=rms(H); QKV=mm(U,P['Wqkv'])
    Q=[[r[m*dh:(m+1)*dh] for m in range(h)] for r in QKV]
    K=[[r[d+m*dh:d+(m+1)*dh] for m in range(h)] for r in QKV]
    V=[[r[2*d+m*dh:2*d+(m+1)*dh] for m in range(h)] for r in QKV]
    A=[]; O=[]; s=1/math.sqrt(dh)
    for m in range(h):
        Sm=[]; Am=[]; Om=[]; bias=-(m+1)/h
        for t in range(T):
            row=[sum(Q[t][m][i]*K[j][m][i] for i in range(dh))*s + bias*(t-j) for j in range(t+1)]
            a=softmax([row])[0]; Sm.append(row); Am.append(a); Om.append([sum(a[j]*V[j][m][i] for j in range(t+1)) for i in range(dh)])
        A.append(Am); O.append(Om)
    Oc=[[v for m in range(h) for v in O[m][t]] for t in range(T)]; O2=mm(Oc,P['Wo']); H2=add(H,O2)
    N,c2=rms(H2); M1=mm(N,P['W1']); M,c3=relu2(M1); M2=mm(M,P['W2']); H3=add(H2,M2)
    return H3,(H,U,c1,Q,K,V,A,Oc,O2,H2,N,c2,M1,c3,M,M2)
def dblock(D,P,cfg,C,G):
    H,U,c1,Q,K,V,A,Oc,O2,H2,N,c2,M1,c3,M,M2=C; d,h,f=cfg['d'],cfg['h'],cfg['f']; T=len(H); dh=d//h
    G['W2']=add(G['W2'],outer(D,M)); dM=dxmm(D,P['W2']); dM1=drelu2(dM,c3)
    G['W1']=add(G['W1'],outer(dM1,N)); dN=dxmm(dM1,P['W1']); dH2=add(D,drms(dN,c2))
    G['Wo']=add(G['Wo'],outer(dH2,Oc)); dOc=dxmm(dH2,P['Wo']); dQKV=[[0.0]*(3*d) for _ in range(T)]
    dO=[[dOc[t][m*dh:(m+1)*dh] for m in range(h)] for t in range(T)]
    s=1/math.sqrt(dh)
    for m in range(h):
        dV=[[0.0]*dh for _ in range(T)]; dS=[[0.0]*(t+1) for t in range(T)]
        for t in range(T):
            a=A[m][t]; go=dO[t][m]
            for j in range(t+1):
                dA=sum(go[i]*V[j][m][i] for i in range(dh))
                for i in range(dh): dV[j][i]+=a[j]*go[i]
                dS[t][j]=dA
            z=sum(a[j]*dS[t][j] for j in range(t+1)); dS[t]=[a[j]*(dS[t][j]-z) for j in range(t+1)]
        for t in range(T):
            for j in range(t+1):
                g=dS[t][j]*s
                for i in range(dh):
                    dQKV[t][m*dh+i]+=g*K[j][m][i]; dQKV[j][d+m*dh+i]+=g*Q[t][m][i]; dQKV[j][2*d+m*dh+i]+=dV[j][i]
    G['Wqkv']=add(G['Wqkv'],outer(dQKV,U)); dU=dxmm(dQKV,P['Wqkv']); return add(dH2,drms(dU,c1))
def forward(M,toks):
    cfg=M['cfg']; x=toks[:-1]; y=toks[1:]; H=[M['E'][t][:] for t in x]; C=[]
    for P in M['B']: H,c=block(H,P,cfg); C.append(c)
    Z=mm(H,M['E']); P=[]; loss=0.0
    for z,t in zip(Z,y):
        m=max(z); e=[math.exp(v-m) for v in z]; s=sum(e); p=[v/s for v in e]; P.append(p); loss+=-(z[t]-m-math.log(s))
    return loss/len(y),(x,y,H,P,C)
def backward(M,C):
    x,y,H,P,Cs=C; cfg=M['cfg']
    G={'E':zeros(cfg['v'],cfg['d']),'B':[{'Wqkv':zeros(3*cfg['d'],cfg['d']),'Wo':zeros(cfg['d'],cfg['d']),'W1':zeros(cfg['f'],cfg['d']),'W2':zeros(cfg['d'],cfg['f'])} for _ in M['B']]}
    D=[]
    for h,p,t in zip(H,P,y):
        dz=p[:]; dz[t]-=1; dz=[v/len(y) for v in dz]
        D.append([sum(dz[j]*M['E'][j][i] for j in range(cfg['v'])) for i in range(cfg['d'])])
        for j,g in enumerate(dz):
            if g:
                for i,v in enumerate(h): G['E'][j][i]+=g*v
    for bi in range(len(M['B'])-1,-1,-1): D=dblock(D,M['B'][bi],cfg,Cs[bi],G['B'][bi])
    for t,dx in zip(x,D):
        for i,v in enumerate(dx): G['E'][t][i]+=v
    return G
def step(M,G,lr=.02,wd=.0):
    for k in ('E',):
        for i,row in enumerate(M[k]):
            for j,_ in enumerate(row): M[k][i][j]-=lr*(G[k][i][j]+wd*M[k][i][j])
    for b,g in zip(M['B'],G['B']):
        for k in b:
            for i,row in enumerate(b[k]):
                for j,_ in enumerate(row): b[k][i][j]-=lr*(g[k][i][j]+wd*b[k][i][j])
def sample(M,prompt,n=100,temp=1.0):
    toks=[256]+list(prompt.encode())
    for _ in range(n):
        _,(x,y,H,P,C)=forward(M,(toks+[256])[-M['cfg']['c']-1:])
        z=[math.log(max(1e-9,p)) for p in P[-1]]
        if temp!=1: z=[v/temp for v in z]
        m=max(z); e=[math.exp(v-m) for v in z]; s=sum(e); r=R.random(); a=0
        for i,p in enumerate([v/s for v in e]):
            a+=p
            if a>=r: toks.append(i); break
    return bytes(t for t in toks if t<256).decode('utf-8','replace')
def train(text,steps=200,d=64,h=4,c=64,lr=.03,ckpt='tiny.pkl'):
    M=init(d,h,4*d,257,c); data=[256]+list(text.encode())+[256]
    for n in range(1,steps+1):
        i=R.randrange(0,max(1,len(data)-c-1)); chunk=data[i:i+c+1]
        loss,C=forward(M,chunk); G=backward(M,C); step(M,G,lr,1e-4)
        if n%50==0: print(n,round(loss,4)); pickle.dump(M,open(ckpt,'wb'))
    return M
if __name__=='__main__':
    ap=argparse.ArgumentParser(); ap.add_argument('cmd',choices=['train','sample']); ap.add_argument('--file',default='input.txt'); ap.add_argument('--ckpt',default='microgpt2_tiny.pkl'); ap.add_argument('--steps',type=int,default=200); ap.add_argument('--d',type=int,default=64); ap.add_argument('--h',type=int,default=4); ap.add_argument('--c',type=int,default=64); ap.add_argument('--lr',type=float,default=.03); ap.add_argument('--prompt',default=''); ap.add_argument('--n',type=int,default=200); a=ap.parse_args()
    if a.cmd=='train':
        text=open(a.file,'r',encoding='utf-8',errors='replace').read(); M=train(text,a.steps,a.d,a.h,a.c,a.lr,a.ckpt); print(sample(M,a.prompt,120))
    else: print(sample(pickle.load(open(a.ckpt,'rb')),a.prompt,a.n))
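
# Usage sketch (matches the argparse above; assumes the file is saved as
# microgpt2_tiny.py, a hypothetical name since the gist does not show one):
#
#   python microgpt2_tiny.py train --file input.txt --steps 200
#   python microgpt2_tiny.py sample --ckpt microgpt2_tiny.pkl --prompt an --n 200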