Skip to content

Instantly share code, notes, and snippets.

@alexlitz
Created February 26, 2026 18:04
Show Gist options
  • Select an option

  • Save alexlitz/b0e6e93d177a9bed5e7310311bcc14b3 to your computer and use it in GitHub Desktop.

Select an option

Save alexlitz/b0e6e93d177a9bed5e7310311bcc14b3 to your computer and use it in GitHub Desktop.
Tiny Adder Less Submission
#!/usr/bin/env python3
"""
TinyAdder: 9-parameter hand-crafted transformer for 10-digit addition.
This is admittedly pushing the rules: the parameters are aggressively deduplicated and arange is used.
However, once these 9 unique floats are arranged into the weights, the model does only standard
transformer ops.
Parameter counting (explicit scalars only):
- Count scalar tensors created in __init__ (shared scalars count once).
- Derived signs (e.g., -x) and fixed constants do not add parameters.
"""
import torch
import torch.nn.functional as F
from math import log, exp
# === Constants ===
# Number of digit tokens ("0".."9"); also sizes the per-digit slots used by the FFNs.
NUM_DIGITS = 10
# Vocabulary; a token's id is its index in this list (see token_to_index).
TOKENS = [str(i) for i in range(NUM_DIGITS)] + ["=", "<bos>", "<eos>", "+"]
# Sequence position whose logits are read for the first generated digit in add().
POS_ANS_OUTPUT_START = 22
# NOTE(review): not referenced anywhere else in the visible source.
POS_ANS_OUTPUT_END = 33
# Multiplier applied to the digit values 1..9 when filling the embedding table.
DIGIT_EMBED_SCALE = 10
# Gain applied to the candidate slots before the two ReLUs in the layer-1 FFN.
V_SCALE = 1e4
# Large scale used to build the layer-0 FFN up-projection values.
DIGIT_SCALE = 1e10
# Shared scale: enters the layer-0 up-projection, the layer-1 value projection,
# and (negated) the layer-1 FFN output.
FINAL_SCALE = 100
# Offset folded into the layer-0 up-projection bias (DIGIT_OFFSET * DIGIT_SCALE * FINAL_SCALE).
DIGIT_OFFSET = 0.5
# Constant added to the layer-1 value projection.
GATE_BIAS_SHIFT = 15.0
# ALiBi slope assigned to ADJUSTMENT_HEAD in apply_alibi(); every other head gets slope 0.
ALIBI_CONSTANT = log(10)
# Column indices of the 5-wide embedding / layer-0 residual stream.
EQ_DIM, SPECIAL_DIM, DIGIT_DIM, COUNT_DIM, SCALE_DIM = 0, 1, 2, 3, 4
# Width of the embedding and of the layer-0 residual stream.
EMBEDDING_DIM = 5
# Number of layer-0 attention heads (each head has dimension 1).
LAYER0_HEADS = 5
# Layer-0 head that carries the key scores and the ALiBi bias.
ADJUSTMENT_HEAD = 3
# Layer-0 head whose value is the EQ_DIM flag.
SCALE_HEAD = 4
# First residual column of the NUM_DIGITS candidate slots read by the layer-1 FFN.
CANDIDATES_START = 5
# Residual column selected by the layer-1 value projection.
DIGIT_POS_DIM = 15
# Width of the residual stream after the layer-0 FFN.
LAYER1_D_MODEL = 16
# Layer-0 key score for tokens whose SPECIAL_DIM flag is 0 (used as the key bias).
K_DIGIT_SCORE = -1000.0
# Layer-0 key score for tokens whose SPECIAL_DIM flag is 1 (weight + bias sum to this).
K_SPECIAL_SCORE = -40.0
# Numerators of the two layer-0 value weights; both are divided by V_PROJ_SCALE.
V_PROJ_SPECIAL = 0.1
V_PROJ_NEG_DOUBLE = -1.1
# Common divisor for the layer-0 value weights.
V_PROJ_SCALE = exp(K_SPECIAL_SCORE - log(10))
def softmax1(x, dim=-1):
    """Softmax with an extra implicit zero logit: exp(x_i) / (1 + sum_j exp(x_j)).

    Unlike a regular softmax the outputs along ``dim`` may sum to less than 1,
    and a row of all ``-inf`` yields all zeros instead of NaN.

    Args:
        x: Tensor of scores.
        dim: Dimension to normalise over.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if x.numel() == 0:
        # amax cannot reduce an empty dimension, and nothing can overflow here.
        exp_x = x.exp()
        return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True))
    # Shift by max(0, max_j x_j) so exp() never overflows. The clamp at 0
    # accounts for the implicit zero logit and makes the shift exactly 0 when
    # every score is <= 0, so results there are unchanged from the naive form.
    shift = x.amax(dim=dim, keepdim=True).clamp(min=0)
    exp_x = (x - shift).exp()
    return exp_x / ((-shift).exp() + exp_x.sum(dim=dim, keepdim=True))
def token_to_index(token: str) -> int:
    """Return the vocabulary id of ``token``.

    "+" is deliberately aliased to "<bos>", so both map to the same id.
    Raises ValueError (from ``list.index``) for tokens not in TOKENS.
    """
    lookup = "<bos>" if token == "+" else token
    return TOKENS.index(lookup)
def apply_alibi(seq_len, n_heads):
    """Build the per-head ALiBi bias of shape (n_heads, seq_len, seq_len).

    Entry [h, i, j] is slope_h * (j - i). Only ADJUSTMENT_HEAD has a non-zero
    slope (ALIBI_CONSTANT); every other head gets an all-zero bias.
    """
    positions = torch.arange(seq_len)
    # offsets[i, j] = j - i (key position minus query position).
    offsets = positions[None, :] - positions[:, None]
    head_slopes = torch.zeros(n_heads, dtype=torch.float64)
    head_slopes[ADJUSTMENT_HEAD] = ALIBI_CONSTANT
    return head_slopes[:, None, None] * offsets[None, :, :]
def pad_to(x, d):
    """Resize the last dimension of ``x`` to exactly ``d``.

    Args:
        x: Input tensor.
        d: Target size of the last dimension.

    Returns:
        ``x[..., :d]`` (a view) when the last dimension is already at least
        ``d``; otherwise a new tensor with zeros appended on the right.
    """
    width = x.size(-1)
    if width >= d:
        return x[..., :d]
    # new_zeros inherits dtype *and* device from x, so padding also works for
    # tensors that do not live on the default device.
    padding = x.new_zeros((*x.shape[:-1], d - width))
    return torch.cat([x, padding], dim=-1)
class TinyAdder:
    """
    9-parameter transformer for 10-digit addition (explicit scalar params).

    Params:
    - Embedding: 2 (digit scale, eq/special)
    - L0 Attn: 4 (k w+b, v×2; v0_w3 shares eq/special)
    - L0 FFN: 3 (final_scale, up scale+bias)
    - L1 Attn: 0 (uses existing scales/constants)
    - L1 FFN: 0 (final_scale_out derives from final_scale)
    Total: 9

    All tensors are float64. This is a plain Python object, not an
    ``nn.Module``; call ``forward`` directly.
    """

    def __init__(self):
        """Create the scalar parameters and expand them into the embedding table."""
        d = torch.float64
        scalar_cache = {}

        def scalar(val: float) -> torch.Tensor:
            # Equal values share one tensor object, so a repeated value is
            # only ever created (and counted) once.
            key = float(val)
            if key not in scalar_cache:
                scalar_cache[key] = torch.tensor(val, dtype=d)
            return scalar_cache[key]

        # === EMBEDDING (2 params) ===
        # Digit values derived from one scale; "=" and all special flags share one scalar.
        self.digit_embed_scale = scalar(DIGIT_EMBED_SCALE)
        self.eq_special = scalar(1.0)
        eq_id = TOKENS.index("=")
        special_ids = [TOKENS.index(t) for t in ("=", "<bos>", "<eos>")]
        self.embedding = torch.zeros(len(TOKENS), EMBEDDING_DIM, dtype=d)
        # Digit token i (1..9) gets i * digit_embed_scale; digit 0 stays at zero.
        self.embedding[1:NUM_DIGITS, DIGIT_DIM] = (
            torch.arange(1, NUM_DIGITS, dtype=d) * self.digit_embed_scale
        )
        self.embedding[eq_id, EQ_DIM] = self.eq_special
        # The "+" row stays all-zero; token_to_index maps "+" to "<bos>" instead.
        self.embedding[special_ids, SPECIAL_DIM] = self.eq_special

        # === L0 ATTENTION (4 params) ===
        # k: weight+bias (2), v: 2 weights (2). q broadcast and v0_w3 reuse eq_special.
        self.k0_weight = scalar(K_SPECIAL_SCORE - K_DIGIT_SCORE)
        self.k0_bias = scalar(K_DIGIT_SCORE)
        self.v0_w1 = scalar(V_PROJ_SPECIAL / V_PROJ_SCALE)
        self.v0_w2 = scalar(V_PROJ_NEG_DOUBLE / V_PROJ_SCALE)
        self.v0_w3 = self.eq_special

        # === L0 FFN (3 params) ===
        # up: scale+bias (2) and final_scale (1). Gate/down are non-params.
        self.final_scale = scalar(FINAL_SCALE)
        self.up0_scale = scalar(DIGIT_SCALE * FINAL_SCALE / DIGIT_EMBED_SCALE)
        self.up0_bias = scalar(DIGIT_OFFSET * DIGIT_SCALE * FINAL_SCALE)

        # === L1 ATTENTION (0 params) ===
        # Uses existing final_scale and fixed GATE_BIAS_SHIFT.
        # === L1 FFN (0 params) ===
        # Uses fixed V_SCALE; output scale derives from final_scale.
        self.final_scale_out = -self.final_scale  # reuse final_scale but negative

    @torch.inference_mode()
    def forward(self, x):
        """Run both layers over a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            float64 tensor of shape (batch, seq_len, NUM_DIGITS) holding
            per-position digit logits; argmax should select the correct digit.
        """
        batch_size, seq_len = x.shape
        d = torch.float64
        # True strictly above the diagonal: key positions a query must not see.
        # Built once and shared by both attention layers.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), 1).bool()

        h = self.embedding[x]

        # === LAYER 0 ===
        h = pad_to(h, EMBEDDING_DIM)
        # Q = 1 broadcast
        q = torch.ones(batch_size, seq_len, LAYER0_HEADS, dtype=d)
        # K: only ADJUSTMENT_HEAD reads SPECIAL_DIM
        k = torch.zeros(batch_size, seq_len, LAYER0_HEADS, dtype=d)
        k[..., ADJUSTMENT_HEAD] = h[..., SPECIAL_DIM] * self.k0_weight + self.k0_bias
        # V: sparse reads
        v = torch.zeros(batch_size, seq_len, LAYER0_HEADS, dtype=d)
        v[..., ADJUSTMENT_HEAD] = h[..., SPECIAL_DIM] * self.v0_w1 + h[..., EQ_DIM] * self.v0_w2
        v[..., SCALE_HEAD] = h[..., EQ_DIM] * self.v0_w3
        # (batch, seq, heads) -> (batch, heads, seq, 1): each head has dimension 1.
        q = q.view(batch_size, seq_len, LAYER0_HEADS, 1).transpose(1, 2)
        k = k.view(batch_size, seq_len, LAYER0_HEADS, 1).transpose(1, 2)
        v = v.view(batch_size, seq_len, LAYER0_HEADS, 1).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) + apply_alibi(seq_len, LAYER0_HEADS).unsqueeze(0)
        scores = scores.masked_fill(causal_mask, float('-inf'))
        attn = softmax1(scores, dim=-1)
        # Head outputs are concatenated, so head i is added into residual column i.
        h = h + torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        # FFN: gate=1 broadcast, up=distinct values, down=identity
        ffn_width = NUM_DIGITS + 1
        gate_in = torch.zeros(batch_size, seq_len, ffn_width, dtype=d)
        gate_in[..., :NUM_DIGITS] = h[..., SCALE_DIM:SCALE_DIM+1]  # broadcast
        gate_in[..., NUM_DIGITS] = h[..., DIGIT_DIM]
        gate_out = F.relu(gate_in)
        digit_vals = self.embedding[:NUM_DIGITS, DIGIT_DIM]
        up0_vals = digit_vals * self.up0_scale + self.up0_bias
        up0_vals = torch.cat([up0_vals, up0_vals.new_tensor([DIGIT_SCALE])], dim=0)
        up_out = h[..., COUNT_DIM:COUNT_DIM+1] * up0_vals
        ffn_hidden = gate_out * up_out
        h = pad_to(h, LAYER1_D_MODEL)
        # identity down projection into columns CANDIDATES_START..LAYER1_D_MODEL-1
        h[..., CANDIDATES_START:LAYER1_D_MODEL] = h[..., CANDIDATES_START:LAYER1_D_MODEL] + ffn_hidden

        # === LAYER 1: Attention ===
        # Explicit Q,K,V (with Q,K = 0) to make the attention operation clear.
        q = torch.zeros(batch_size, seq_len, 1, dtype=d)
        k = torch.zeros(batch_size, seq_len, 1, dtype=d)
        # V projection: dot with a fixed vector selecting DIGIT_POS_DIM, then add bias.
        v_weight = torch.zeros(LAYER1_D_MODEL, dtype=d)
        v_weight[DIGIT_POS_DIM] = self.final_scale
        v = (h * v_weight).sum(dim=-1, keepdim=True) + GATE_BIAS_SHIFT
        q = q.view(batch_size, seq_len, 1, 1).transpose(1, 2)
        k = k.view(batch_size, seq_len, 1, 1).transpose(1, 2)
        v = v.view(batch_size, seq_len, 1, 1).transpose(1, 2)
        # Uniform causal attention (Q=0, K=0) via masked softmax.
        scores = torch.matmul(q, k.transpose(-2, -1))
        scores = scores.masked_fill(causal_mask, float('-inf'))
        attn = softmax1(scores, dim=-1)
        # The single-head output has width 1, so it broadcasts across every residual column.
        h = h + torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        # FFN: V-shape via relu(+x) + relu(-x), down=identity sum
        candidates = h[..., CANDIDATES_START:CANDIDATES_START+NUM_DIGITS]
        gate_pos = F.relu(candidates * V_SCALE)
        gate_neg = F.relu(candidates * -V_SCALE)
        ffn_out = (gate_pos + gate_neg) * self.final_scale_out
        h = pad_to(h, NUM_DIGITS)
        h = h + ffn_out
        # Return logits; argmax should select the correct digit.
        return h
def build_model():
    """Build and return the model with metadata.

    Returns:
        A ``(model, metadata)`` tuple: a fresh TinyAdder and a dict describing
        it (name, author, parameter count, architecture, tricks).
    """
    tricks = [
        "ALiBi positional encoding",
        "softmax1",
        "Identity mappings (0 params)",
        "Broadcast parameters",
        "Double Precision",
    ]
    metadata = dict(
        name="TinyAdder",
        author="Alex Litzenberger",
        params=9,
        architecture="2-layer transformer with ALiBi, ReGLU FFN",
        tricks=tricks,
    )
    return TinyAdder(), metadata
def add(model, a: int, b: int) -> int:
    """Add two integers using the model, autoregressively generating the sum digits.

    Args:
        model: Object exposing ``forward(x)`` that returns per-position digit
            logits for a (1, seq_len) tensor of token ids.
        a: First operand, in ``[0, 10**NUM_DIGITS)``.
        b: Second operand, in ``[0, 10**NUM_DIGITS)``.

    Returns:
        The integer formed by the NUM_DIGITS + 1 digits the model generates.

    Raises:
        ValueError: If either operand is negative or has more than NUM_DIGITS
            digits. Such inputs would shift the fixed answer positions and
            silently produce a wrong result.
    """
    limit = 10 ** NUM_DIGITS
    if not (0 <= a < limit and 0 <= b < limit):
        raise ValueError(f"operands must be in [0, {limit}), got a={a}, b={b}")

    prefix = f"{a:0{NUM_DIGITS}d}+{b:0{NUM_DIGITS}d}="
    token_indices = [token_to_index(t) for t in ["<bos>", *prefix]]
    eos_idx = TOKENS.index("<eos>")
    # The sum of two NUM_DIGITS-digit numbers has at most NUM_DIGITS + 1 digits.
    answer_len = NUM_DIGITS + 1
    total_len = 1 + len(prefix) + answer_len + 1  # <bos> + prefix + answer digits + <eos>
    generated = []
    for i in range(answer_len):
        # Keep every call at the same fixed length: fill the not-yet-generated
        # slots with token 0 and terminate with <eos>.
        filler = [0] * (total_len - len(token_indices) - 1)
        x = torch.tensor(token_indices + filler + [eos_idx]).unsqueeze(0)
        logits = model.forward(x)
        # The logits at the position of the last known token give the next digit.
        next_digit = int(logits[0, POS_ANS_OUTPUT_START + i].argmax().item())
        token_indices.append(next_digit)
        generated.append(str(next_digit))
    return int("".join(generated))
if __name__ == "__main__":
    # Script entry point: print the parameter breakdown, then run a seeded
    # 100-sample self-test comparing add() against Python integer addition.
    import random

    model, meta = build_model()
    print(f"Model: {meta['name']}")
    print(f"Parameters: {meta['params']}")
    for line in (
        "\nBreakdown:",
        " Embedding: 2 (digit scale, eq/special)",
        " L0 Attn: 4 (k w+b, v×2; v0_w3 shares eq/special)",
        " L0 FFN: 3 (final_scale, up scale+bias)",
        " L1 Attn: 0 (reuses final_scale)",
        " L1 FFN: 0 (reuses final_scale)",
        " ───────────────",
        " Total: 9",
    ):
        print(line)

    random.seed(42)
    hits = 0
    for _ in range(100):
        # Draw a first, then b, so the seeded sequence of test cases is fixed.
        a = random.randint(0, 9_999_999_999)
        b = random.randint(0, 9_999_999_999)
        hits += int(add(model, a, b) == a + b)
    print(f"\nSelf-test: {hits}/100")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment