# Source: GitHub gist cosminscn/89c110dbae76ea0c873d67607e466f5b (author @cosminscn).
# Gist last active: February 26, 2026 18:03.
# NOTE: GitHub page navigation text captured with the gist was dropped here;
# it is not part of the script.
"""
Dynamic NanoGPT Adder: 130 params + 0 buffers
transformer.wte.A [10, 1] = 10
transformer.wte.B [1, 4] = 4
transformer.h.0.attn.c_attn.weight [12, 4] = 48
transformer.h.0.attn.c_proj.weight [4, 4] = 16
transformer.h.0.mlp.c_fc.weight [4, 4] = 16
transformer.h.0.mlp.c_fc.bias [4] = 4
transformer.h.0.mlp.c_proj.u [4, 1] = 4
transformer.h.0.mlp.c_proj.v [1, 4] = 4
lm_head.u [10, 1] = 10
lm_head.v [1, 4] = 4
lm_head.bias [10] = 10
TOTAL = 130
============================================================
1. Original hand-picked tests
============================================================
5 + 5 = 10 ✅
555 + 445 = 1,000 ✅
99,999 + 1 = 100,000 ✅
19,492 + 23,919 = 43,411 ✅
9,999,999,999 + 1 = 10,000,000,000 ✅
1,234,567,890 + 987,654,321 = 2,222,222,211 ✅
0 + 0 = 0 ✅
1,111,111,111 + 8,888,888,889 = 10,000,000,000 ✅
8/8 passed
============================================================
2. Boundary / edge cases
============================================================
0 + 0 = 0 ✅
0 + 1 = 1 ✅
1 + 0 = 1 ✅
0 + 9 = 9 ✅
9 + 0 = 9 ✅
9,999,999,999 + 0 = 9,999,999,999 ✅
0 + 9,999,999,999 = 9,999,999,999 ✅
9,999,999,999 + 9,999,999,999 = 19,999,999,998 ✅
5,000,000,000 + 5,000,000,000 = 10,000,000,000 ✅
1 + 9,999,999,999 = 10,000,000,000 ✅
9,999,999,999 + 1 = 10,000,000,000 ✅
11/11 passed
============================================================
3. Carry propagation chains (10^k - 1) + 1
============================================================
9 + 1 = 10 ✅
99 + 1 = 100 ✅
999 + 1 = 1,000 ✅
9,999 + 1 = 10,000 ✅
99,999 + 1 = 100,000 ✅
999,999 + 1 = 1,000,000 ✅
9,999,999 + 1 = 10,000,000 ✅
99,999,999 + 1 = 100,000,000 ✅
999,999,999 + 1 = 1,000,000,000 ✅
9,999,999,999 + 1 = 10,000,000,000 ✅
10/10 passed
============================================================
4. Long carry chains (all 9s + small)
============================================================
9,999,999,999 + 2 = 10,000,000,001 ✅
9,999,999,999 + 9 = 10,000,000,008 ✅
9,999,999,999 + 10 = 10,000,000,009 ✅
9,999,999,999 + 99 = 10,000,000,098 ✅
9,999,999,999 + 100 = 10,000,000,099 ✅
9,999,999,990 + 10 = 10,000,000,000 ✅
9,999,999,900 + 100 = 10,000,000,000 ✅
9,999,999,000 + 1,000 = 10,000,000,000 ✅
9,999,990,000 + 10,000 = 10,000,000,000 ✅
9,999,900,000 + 100,000 = 10,000,000,000 ✅
9,990,000,000 + 10,000,000 = 10,000,000,000 ✅
9,900,000,000 + 100,000,000 = 10,000,000,000 ✅
9,000,000,000 + 1,000,000,000 = 10,000,000,000 ✅
13/13 passed
============================================================
5. No carry cases
============================================================
1,111,111,111 + 2,222,222,222 = 3,333,333,333 ✅
1,234,567,890 + 0 = 1,234,567,890 ✅
1,000,000,000 + 2,000,000,000 = 3,000,000,000 ✅
1,234 + 4,321 = 5,555 ✅
1,010,101,010 + 2,020,202,020 = 3,030,303,030 ✅
4,040,404,040 + 5,050,505,050 = 9,090,909,090 ✅
6/6 passed
============================================================
6. Alternating carry patterns
============================================================
5,050,505,050 + 5,050,505,050 = 10,101,010,100 ✅
9,090,909,090 + 1,010,101,010 = 10,101,010,100 ✅
1,919,191,919 + 8,181,818,181 = 10,101,010,100 ✅
5,959,595,959 + 4,141,414,141 = 10,101,010,100 ✅
8,282,828,282 + 2,828,282,828 = 11,111,111,110 ✅
5/5 passed
============================================================
7. Powers of 10
============================================================
64/64 passed
============================================================
8. Palindromic numbers
============================================================
1,234,554,321 + 1,234,554,321 = 2,469,108,642 ✅
9,876,556,789 + 123,443,210 = 9,999,999,999 ✅
1,111,111,111 + 9,999,999,999 = 11,111,111,110 ✅
5,678,876,543 + 4,321,123,456 = 9,999,999,999 ✅
4/4 passed
============================================================
9. Commutativity: a+b == b+a
============================================================
Commutative: 50/50, Correct: 50/50
============================================================
10. Random small numbers (0-99)
============================================================
225/225 passed
============================================================
11. Random medium (1000-999999)
============================================================
500/500 passed
============================================================
12. Large random stress test (10,000 cases)
============================================================
Progress: 10,000/10,000 — 0 failures so far
10000/10000 passed (0 failures)
============================================================
13. Single-digit-position carry tests
============================================================
20/20 passed
============================================================
14. Previously-failing cases
============================================================
9,999,999,999 + 0 = 9,999,999,999 ✅
0 + 9,999,999,999 = 9,999,999,999 ✅
4,040,404,040 + 5,050,505,050 = 9,090,909,090 ✅
9,876,556,789 + 123,443,210 = 9,999,999,999 ✅
5,678,876,543 + 4,321,123,456 = 9,999,999,999 ✅
863,512 + 986,307 = 1,849,819 ✅
642,029 + 250,362 = 892,391 ✅
210,853 + 732,099 = 942,952 ✅
93,028 + 100,535 = 193,563 ✅
112,648 + 286,132 = 398,780 ✅
10/10 passed
============================================================
GRAND TOTAL: 10926/10926 passed
============================================================
"""
"""130-param nanoGPT that adds any two 10-digit numbers.
Combines hand-coded weights with Rank-1, Factorized Embedding tricks, and dynamic PE.
FIX: Wrap-detection neuron biases changed from -9005/-9015 to -9045/-9055.
Mechanism at output position z_i:
dim0 = z_{i-1} (previous output digit, from token embedding)
dim1 = a_{i-1}+b_{i-1} (previous pair sum, from attention head 1 via o_proj)
dim3 = a_i+b_i (current pair sum, from attention head 0 via o_proj)
diff = dim1 - dim0 = (a_{i-1}+b_{i-1}) - z_{i-1} in {-1, 0, 9, 10}
carry_in = I[diff >= 9]
Wrap signal = 1000*(a_i+b_i) + 10*diff
Max no-wrap: 9000 (sum=9, diff=0)
Min wrap: 9090 (sum=9, diff=9)
Gap = 90
Original -9005: only 5 margin below 9000 -> float noise causes false carries.
Fixed -9045: 45 margin below, 35 above. Robust.
"""
import math
import random
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
# === Trick 1: Factorized Embeddings ===
class FactorizedEmbedding(nn.Module):
    """Embedding table stored as the product of two thin matrices.

    ``A`` has shape ``(vocab_size, rank)`` and ``B`` has shape
    ``(rank, emb_dim)``; both are zero-initialised parameters, so the layer
    holds ``rank * (vocab_size + emb_dim)`` numbers in total.
    """

    def __init__(self, vocab_size, emb_dim, rank=1):
        super().__init__()
        per_token = torch.zeros(vocab_size, rank)
        shared_basis = torch.zeros(rank, emb_dim)
        # Registered in this order so parameter listings read A, then B.
        self.A = nn.Parameter(per_token)
        self.B = nn.Parameter(shared_basis)

    def forward(self, x):
        """Look up the rows of ``A`` selected by ``x`` and project them through ``B``."""
        selected_rows = self.A[x]
        return torch.matmul(selected_rows, self.B)
# === Trick 2: Rank-1 Linear Layers ===
class Rank1Linear(nn.Module):
    """Linear layer whose weight is the rank-1 outer product ``u @ v``.

    ``u`` is ``(out_features, 1)`` and ``v`` is ``(1, in_features)``, both
    zero-initialised. With ``bias=True`` a zero ``(out_features,)`` bias is
    added; otherwise ``self.bias`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.u = nn.Parameter(torch.zeros(out_features, 1))
        self.v = nn.Parameter(torch.zeros(1, in_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Project ``x`` onto ``v`` (one scalar per row), then scale ``u`` by it."""
        coefficient = torch.matmul(x, self.v.t())
        projected = torch.matmul(coefficient, self.u.t())
        if self.bias is None:
            return projected
        return projected + self.bias
# === NanoGPT Core ===
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused q/k/v projection.

    ``config`` must provide ``n_embd``, ``n_head``, ``bias`` and ``dropout``.
    ``dropout`` is stored on the module but is not applied in ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # One matrix produces q, k and v stacked along the output axis.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = width
        self.dropout = config.dropout

    def forward(self, x):
        """Attend over ``x`` (batch, time, channels) with a causal mask; output has the same shape."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        fused = self.c_attn(x)
        q, k, v = (to_heads(part) for part in fused.split(self.n_embd, dim=2))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Put the heads back side by side before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Feed-forward sub-layer: dense expansion, activation, rank-1 projection back.

    The hidden width is ``config.mlp_hidden`` when that is truthy, otherwise
    ``4 * config.n_embd``. The activation is stored under the attribute name
    ``gelu``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_width = config.mlp_hidden if config.mlp_hidden else 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=True)
        self.gelu = nn.GELU()
        self.c_proj = Rank1Linear(hidden_width, config.n_embd, bias=config.bias)

    def forward(self, x):
        """Apply ``c_fc``, then the activation, then ``c_proj``."""
        pre_activation = self.c_fc(x)
        activated = self.gelu(pre_activation)
        return self.c_proj(activated)
class Block(nn.Module):
    """One transformer layer: attention then MLP, each wrapped in a residual add.

    Both normalisation slots (``ln_1``, ``ln_2``) are ``nn.Identity``, so the
    sub-layers see the residual stream unchanged.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.Identity()
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.Identity()
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after adding the attention update and then the MLP update."""
        attn_update = self.attn(self.ln_1(x))
        x = x + attn_update
        mlp_update = self.mlp(self.ln_2(x))
        return x + mlp_update
@dataclass
class GPTConfig:
    """Hyper-parameters consumed by ``GPT`` and its sub-modules."""
    block_size: int = 35  # not read by any code visible in this file
    vocab_size: int = 10  # rows of the token embedding and width of the lm_head output
    n_layer: int = 1  # number of ``Block`` instances stacked in ``GPT``
    n_head: int = 2  # attention heads; must divide ``n_embd`` (asserted in CausalSelfAttention)
    n_embd: int = 4  # width of the residual stream
    dropout: float = 0.0  # stored by CausalSelfAttention but not applied in its forward pass
    bias: bool = False  # bias flag for the attention linears and the MLP output projection
    mlp_hidden: int = 4  # MLP hidden width; a falsy value makes MLP fall back to 4 * n_embd
class GPT(nn.Module):
    """Decoder-only transformer with a factorized embedding and a rank-1 output head.

    Positional encodings are not stored; ``generate_pe`` recomputes them on
    every forward pass for the current sequence length.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        token_embedding = FactorizedEmbedding(config.vocab_size, config.n_embd, rank=1)
        layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.transformer = nn.ModuleDict({"wte": token_embedding, "h": layers})
        self.lm_head = Rank1Linear(config.n_embd, config.vocab_size, bias=True)

    def generate_pe(self, seq_len, device):
        """Return a ``(seq_len, n_embd)`` positional encoding.

        Dimension 1 holds ``amp * sin(pos * 2*pi/11)`` and dimension 2 holds
        ``amp * cos(pos * 2*pi/11)``; every other dimension is zero. ``amp`` is
        100 for positions 0..21 and 1 afterwards.
        """
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angle = positions * (2 * math.pi / 11)
        amp = torch.where(positions <= 21, 100.0, 1.0)
        pe = torch.zeros(seq_len, self.config.n_embd, device=device)
        pe[:, 1] = amp * torch.sin(angle)
        pe[:, 2] = amp * torch.cos(angle)
        return pe

    def forward(self, idx, targets=None):
        """Return ``(logits, None)`` where logits cover only the last position.

        ``targets`` is accepted but ignored; no loss is computed.
        """
        hidden = self.transformer.wte(idx) + self.generate_pe(idx.size(1), idx.device)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        # The list index keeps the time axis, so logits are (batch, 1, vocab).
        last_position = hidden[:, [-1], :]
        return self.lm_head(last_position), None
# === Weight injection ===
def build_adder():
    """Construct a ``GPT`` with the default ``GPTConfig`` and hand-set every weight.

    No training happens: each parameter is zeroed and/or overwritten in place
    under ``torch.no_grad()``, and the MLP activation module is replaced by
    ReLU. The residual-stream dimension roles mentioned in the comments below
    are the ones listed in the design notes at the top of this file.

    Returns:
        The configured ``GPT`` model.
    """
    config = GPTConfig()
    model = GPT(config)
    # th: angular step 2*pi/11 (the same value GPT.generate_pe uses);
    # S: gain applied to the query rows; n: side length of the weight
    # matrices built below (matches the default n_embd of 4).
    th, S, n = 2*math.pi/11, 100.0, 4
    with torch.no_grad():
        # Token embedding A @ B: digit v becomes [v, 0, 0, 0] (value in dim 0).
        model.transformer.wte.A.zero_()
        model.transformer.wte.B.zero_()
        for v in range(10): model.transformer.wte.A[v, 0] = float(v)
        model.transformer.wte.B[0, 0] = 1.0
        # The activation stays under the attribute name ``gelu`` but is ReLU from here on.
        model.transformer.h[0].mlp.gelu = nn.ReLU()
        # Fused q/k/v weight: rows 0-3 are queries, 4-7 keys, 8-11 values; within
        # each group the first two rows feed head 0 and the last two feed head 1
        # (per the split/view in CausalSelfAttention.forward).
        w = torch.zeros(3*n, n)
        # Head-0 query: S times a rotation of the (dim1, dim2) sin/cos pair by 8 steps of th.
        w[0,1] = -math.cos(8*th)*S; w[0,2] = math.sin(8*th)*S
        w[1,1] = math.sin(8*th)*S; w[1,2] = math.cos(8*th)*S
        # Head-1 query: same construction with 9 steps.
        # NOTE(review): expanding q.k gives a score proportional to
        # cos((p + j - 8) * th) for head 0 (and cos((p + j - 9) * th) for head 1)
        # with query position p and key position j -- confirm against the intended
        # digit alignment.
        w[2,1] = -math.cos(9*th)*S; w[2,2] = math.sin(9*th)*S
        w[3,1] = math.sin(9*th)*S; w[3,2] = math.cos(9*th)*S
        # Keys of both heads copy the positional sin/cos pair (dims 1 and 2) unchanged.
        w[4,1] = w[6,1] = 1.0; w[5,2] = w[7,2] = 1.0
        # First value channel of each head copies dim 0 (the token's digit value);
        # the second value channel of each head stays zero.
        w[8,0] = w[10,0] = 1.0
        model.transformer.h[0].attn.c_attn.weight.copy_(w)
        # Attention output projection: residual dim 3 += 2 * (head-0 channel 0),
        # residual dim 1 += 2 * (head-1 channel 0).
        # NOTE(review): the factor 2 presumably turns an average over two attended
        # positions into their sum -- confirm.
        w = torch.zeros(n, n)
        w[3,0] = 2.0; w[1,2] = 2.0
        model.transformer.h[0].attn.c_proj.weight.copy_(w)
        # MLP first layer (4 hidden units). With diff = dim1 - dim0:
        #   unit 0: 100*diff - 50            unit 1: 100*diff - 150
        #   unit 2: 1000*dim3 + 10*diff - 9045
        #   unit 3: 1000*dim3 + 10*diff - 9055
        fw, fb = torch.zeros(4, n), torch.zeros(4)
        fw[0,1]= 100; fw[0,0]=-100; fb[0]= -50
        fw[1,1]= 100; fw[1,0]=-100; fb[1]=-150
        fw[2,3]=1000; fw[2,1]= 10; fw[2,0]=-10; fb[2]=-9045 # FIX: was -9005
        fw[3,3]=1000; fw[3,1]= 10; fw[3,0]=-10; fb[3]=-9055 # FIX: was -9015
        model.transformer.h[0].mlp.c_fc.weight.copy_(fw)
        model.transformer.h[0].mlp.c_fc.bias.copy_(fb)
        # MLP output (rank 1): only residual dim 3 is written, with
        # 0.01*h0 - 0.01*h1 - h2 + h3. When both units of a pair are active the
        # pair contributes a constant: +1 for units 0/1 and -10 for units 2/3.
        model.transformer.h[0].mlp.c_proj.u.zero_()
        model.transformer.h[0].mlp.c_proj.v.zero_()
        model.transformer.h[0].mlp.c_proj.u[3, 0] = 1.0
        model.transformer.h[0].mlp.c_proj.v[0, :] = torch.tensor([0.01, -0.01, -1.0, 1.0])
        # Output head reads dim 3 only: logit[v] = 2*v*x3 - v*v, which equals
        # x3**2 - (v - x3)**2, so the arg-max is the digit closest to x3.
        model.lm_head.u.zero_()
        model.lm_head.v.zero_()
        model.lm_head.bias.zero_()
        for v in range(10):
            model.lm_head.u[v, 0] = 2.0 * v
            model.lm_head.bias[v] = -float(v * v)
        model.lm_head.v[0, 3] = 1.0
    return model
# === Inference ===
def test(model, a, b, verbose=True):
    """Add ``a`` and ``b`` with ``model`` and compare the answer with ``a + b``.

    This is a single-pair wrapper around ``test_batch``, so the prompt
    encoding, the generation loop and the decoding of the generated digits
    are defined in one place only.

    Args:
        model: the adder model, passed straight to ``test_batch``.
        a: first operand; formatted as a zero-padded 10-digit field.
        b: second operand; formatted as a zero-padded 10-digit field.
        verbose: when true, print one formatted result line with a pass/fail mark.

    Returns:
        ``(ok, result)`` where ``result`` is the integer decoded from the
        model's output and ``ok`` is ``result == a + b``.
    """
    result = test_batch(model, [(a, b)])[0]
    ok = result == a + b
    if verbose:
        print(f" {a:>13,d} + {b:>13,d} = {result:>13,d} {'✅' if ok else '❌'}")
    return ok, result
def test_batch(model, pairs):
tok = {'+': 0, '=': 0}
seqs = []
for a, b in pairs:
seqs.append([int(c) if c.isdigit() else tok[c] for c in f"{a:010d}+{b:010d}="])
model.eval()
with torch.no_grad():
for _ in range(11):
logits, _ = model(torch.tensor(seqs))
next_digits = logits[:, -1].argmax(dim=-1).tolist()
for seq, d in zip(seqs, next_digits):
seq.append(d)
results = []
for seq in seqs:
results.append(int("".join(str(t) for t in seq[22:])[::-1]))
return results
# === Comprehensive tests ===
def _print_banner(title):
    """Print ``title`` framed above and below by a 60-character ``=`` rule."""
    rule = "=" * 60
    print(rule)
    print(title)
    print(rule)


def _run_verbose_section(model, title, pairs):
    """Check each pair with ``test`` (one printed line per pair); return ``(passed, total)``."""
    _print_banner(title)
    passed = sum(test(model, a, b)[0] for a, b in pairs)
    print(f" {passed}/{len(pairs)} passed\n")
    return passed, len(pairs)


def _run_batch_section(model, title, pairs, max_shown=5):
    """Check all pairs in one ``test_batch`` call, listing at most ``max_shown`` failures.

    Returns ``(passed, total)``.
    """
    _print_banner(title)
    results = test_batch(model, pairs)
    passed = sum(1 for (a, b), r in zip(pairs, results) if r == a + b)
    failures = [(a, b, r) for (a, b), r in zip(pairs, results) if r != a + b]
    for a, b, r in failures[:max_shown]:
        print(f" FAIL: {a} + {b} = {r} (expected {a+b})")
    print(f" {passed}/{len(pairs)} passed\n")
    return passed, len(pairs)


def run_tests(model):
    """Print a parameter summary for ``model`` and run the 14-section addition suite.

    Sections with hand-picked pairs print one line per case through ``test``;
    the larger sections run through ``test_batch`` and print only failures.
    Random sections use fixed seeds (99, 77 and 42). Everything is reported on
    stdout, ending with a grand total; nothing is returned.

    Args:
        model: the adder model handed to ``test`` / ``test_batch``; it must
            also provide ``parameters()``, ``buffers()`` and
            ``named_parameters()`` for the summary table.
    """
    params = sum(p.numel() for p in model.parameters())
    buf = sum(b.numel() for b in model.buffers())
    print(f"\n Dynamic NanoGPT Adder: {params} params + {buf} buffers\n")
    for name, p in model.named_parameters():
        print(f" {name:40s} {str(list(p.shape)):>10s} = {p.numel():>3d}")
    print(f" {'TOTAL':40s} {'':>10s} = {params:>3d}\n")

    outcomes = []  # one (passed, total) tuple per section

    # --- 1. Original hand-picked tests ---
    outcomes.append(_run_verbose_section(model, "1. Original hand-picked tests", [
        (5, 5), (555, 445), (99999, 1), (19492, 23919),
        (9999999999, 1), (1234567890, 987654321),
        (0, 0), (1111111111, 8888888889),
    ]))

    # --- 2. Edge cases ---
    outcomes.append(_run_verbose_section(model, "2. Boundary / edge cases", [
        (0, 0), (0, 1), (1, 0), (0, 9), (9, 0),
        (9999999999, 0), (0, 9999999999),
        (9999999999, 9999999999),
        (5000000000, 5000000000),
        (1, 9999999999),
        (9999999999, 1),
    ]))

    # --- 3. Carry chains ---
    outcomes.append(_run_verbose_section(
        model,
        "3. Carry propagation chains (10^k - 1) + 1",
        [(10**k - 1, 1) for k in range(1, 11)],
    ))

    # --- 4. Long carry chains ---
    outcomes.append(_run_verbose_section(model, "4. Long carry chains (all 9s + small)", [
        (9999999999, 2), (9999999999, 9), (9999999999, 10),
        (9999999999, 99), (9999999999, 100),
        (9999999990, 10), (9999999900, 100), (9999999000, 1000),
        (9999990000, 10000), (9999900000, 100000),
        (9990000000, 10000000), (9900000000, 100000000),
        (9000000000, 1000000000),
    ]))

    # --- 5. No carry ---
    outcomes.append(_run_verbose_section(model, "5. No carry cases", [
        (1111111111, 2222222222),
        (1234567890, 0),
        (1000000000, 2000000000),
        (1234, 4321),
        (1010101010, 2020202020),
        (4040404040, 5050505050),
    ]))

    # --- 6. Alternating carry ---
    outcomes.append(_run_verbose_section(model, "6. Alternating carry patterns", [
        (5050505050, 5050505050),
        (9090909090, 1010101010),
        (1919191919, 8181818181),
        (5959595959, 4141414141),
        (8282828282, 2828282828),
    ]))

    # --- 7. Powers of 10 ---
    outcomes.append(_run_batch_section(
        model,
        "7. Powers of 10",
        [(10**i, 10**j) for i in range(10) for j in range(10) if i + j <= 10],
    ))

    # --- 8. Palindromes ---
    outcomes.append(_run_verbose_section(model, "8. Palindromic numbers", [
        (1234554321, 1234554321),
        (9876556789, 123443210),
        (1111111111, 9999999999),
        (5678876543, 4321123456),
    ]))

    # --- 9. Commutativity ---
    _print_banner("9. Commutativity: a+b == b+a")
    rng = random.Random(99)
    n_comm = 50
    comm_pairs = [(rng.randint(0, 10**10-1), rng.randint(0, 10**10-1)) for _ in range(n_comm)]
    forward_res = test_batch(model, comm_pairs)
    backward_res = test_batch(model, [(b, a) for a, b in comm_pairs])
    comm_ok = sum(1 for f, b in zip(forward_res, backward_res) if f == b)
    correct = sum(1 for (a, b), r in zip(comm_pairs, forward_res) if r == a + b)
    print(f" Commutative: {comm_ok}/{n_comm}, Correct: {correct}/{n_comm}\n")
    # Only correctness (not the commutativity count) feeds the grand total.
    outcomes.append((correct, n_comm))

    # --- 10. Small numbers ---
    outcomes.append(_run_batch_section(
        model,
        "10. Random small numbers (0-99)",
        [(a, b) for a in range(0, 100, 7) for b in range(0, 100, 7)],
    ))

    # --- 11. Medium numbers ---
    rng = random.Random(77)
    med_pairs = [(rng.randint(1000, 999999), rng.randint(1000, 999999)) for _ in range(500)]
    outcomes.append(_run_batch_section(model, "11. Random medium (1000-999999)", med_pairs))

    # --- 12. Large stress test ---
    _print_banner("12. Large random stress test (10,000 cases)")
    rng = random.Random(42)
    batch_size = 512
    total, total_passed, total_failed = 10000, 0, 0
    first_failures = []
    for start in range(0, total, batch_size):
        cur = min(batch_size, total - start)
        pairs = [(rng.randint(0, 10**10-1), rng.randint(0, 10**10-1)) for _ in range(cur)]
        results = test_batch(model, pairs)
        for (a, b), r in zip(pairs, results):
            if r == a + b:
                total_passed += 1
                continue
            total_failed += 1
            if len(first_failures) < 10:
                first_failures.append((a, b, r, a + b))
        done = start + cur
        if done % 2000 == 0 or done == total:
            print(f" Progress: {done:,d}/{total:,d} — {total_failed} failures so far")
    if first_failures:
        print("\n First failures:")
        for a, b, got, exp in first_failures:
            print(f" {a:010d} + {b:010d} = {got} (expected {exp})")
    print(f"\n {total_passed}/{total} passed ({total_failed} failures)\n")
    outcomes.append((total_passed, total))

    # --- 13. Digit-position carry ---
    digit_carry = []
    for pos in range(10):
        # A single 9 + 1 at this position, then a run of (pos + 1) nines + 1.
        digit_carry.append((9 * (10 ** pos), 1 * (10 ** pos)))
        digit_carry.append((int("9" * (pos + 1)), 1))
    outcomes.append(_run_batch_section(
        model, "13. Single-digit-position carry tests", digit_carry, max_shown=10,
    ))

    # --- 14. Previously-failing cases ---
    outcomes.append(_run_verbose_section(model, "14. Previously-failing cases", [
        (9999999999, 0),
        (0, 9999999999),
        (4040404040, 5050505050),
        (9876556789, 123443210),
        (5678876543, 4321123456),
        (863512, 986307),
        (642029, 250362),
        (210853, 732099),
        (93028, 100535),
        (112648, 286132),
    ]))

    # --- Summary ---
    all_passed = sum(passed for passed, _ in outcomes)
    all_total = sum(count for _, count in outcomes)
    _print_banner(f"GRAND TOTAL: {all_passed}/{all_total} passed")
if __name__ == "__main__":
    # Script entry point: build the hand-weighted model, then run the whole suite.
    model = build_adder()
    run_tests(model)
# GitHub page footer (not part of the script): Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment