import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# -------------------------------------------------------------------
# Config and device
# -------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

VOCAB_SIZE = 64
D_MODEL = 32
SEQ_LEN = 32  # implies first half = 16, second half = 16
NUM_EPOCHS = 20
BATCH_SIZE = 128
STEPS_PER_EPOCH = 50
LR = 1e-3

assert D_MODEL % 2 == 0, "Model dimension must be even for RoPE/GRAPE."

# -------------------------------------------------------------------
# Synthetic "copy / associative recall" dataset
#   - First half: random tokens
#   - Second half: exact copy of first half
#   - Objective: next-token prediction
# -------------------------------------------------------------------
IGNORE_INDEX = -100


def generate_batch(batch_size, seq_len, vocab_size, device):
    half = seq_len // 2
    # First half random
    first_half = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, half),
        device=device,
    )
    # Second half copies first half
    second_half = first_half.clone()
    x = torch.cat([first_half, second_half], dim=1)  # (B, seq_len)
    # Next-token prediction targets
    y = x.roll(shifts=-1, dims=1)
    y[:, -1] = IGNORE_INDEX  # last position has no next token
    return x, y
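
# Sanity check (not part of the original gist; a minimal sketch): confirm the
# copy structure of the dataset and that targets are inputs shifted left by one.
_x, _y = generate_batch(2, SEQ_LEN, VOCAB_SIZE, device)
_half = SEQ_LEN // 2
assert torch.equal(_x[:, :_half], _x[:, _half:]), "second half must copy the first"
assert torch.equal(_y[:, :-1], _x[:, 1:]), "targets are next tokens"
assert (_y[:, -1] == IGNORE_INDEX).all(), "last target carries no next token"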
# -------------------------------------------------------------------
# RoPE positional encoding for a single head
# -------------------------------------------------------------------
class RopePositional(nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.base = base

    def forward(self, q, k):
        """
        q, k: (B, T, D)
        returns q_rot, k_rot: (B, T, D) with RoPE applied
        """
        b, t, d = q.shape
        device = q.device
        half = d // 2

        # Frequency spectrum
        idx = torch.arange(half, device=device, dtype=q.dtype)  # [0..half-1]
        # Classic RoPE frequency scaling
        freq = self.base ** (-2 * idx / d)  # shape (half,)

        # Positions
        positions = torch.arange(t, device=device, dtype=q.dtype)  # (T,)

        # Angles: (T, half)
        angles = torch.einsum("t,f->tf", positions, freq)
        cos = angles.cos()[None, :, :]  # (1, T, half)
        sin = angles.sin()[None, :, :]  # (1, T, half)

        def apply_rope(x):
            x1, x2 = x[..., :half], x[..., half:]
            x_rot1 = x1 * cos - x2 * sin
            x_rot2 = x1 * sin + x2 * cos
            return torch.cat([x_rot1, x_rot2], dim=-1)

        return apply_rope(q), apply_rope(k)
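
# Sanity check (not part of the original gist): the defining RoPE property is
# that the score <q_m, k_n> depends only on the offset m - n. Shifting both
# positions by the same amount should leave the score unchanged.
with torch.no_grad():
    _rope = RopePositional(D_MODEL)
    _q = torch.randn(1, SEQ_LEN, D_MODEL)
    _k = torch.randn(1, SEQ_LEN, D_MODEL)
    _qr, _kr = _rope(_q, _k)
    _s1 = (_qr[0, 5] * _kr[0, 3]).sum()  # offset 2, positions (5, 3)
    # Roll both sequences forward by 7, so the same vectors sit at (12, 10)
    _qs, _ks = _rope(torch.roll(_q, 7, dims=1), torch.roll(_k, 7, dims=1))
    _s2 = (_qs[0, 12] * _ks[0, 10]).sum()  # same offset 2, positions (12, 10)
    assert torch.allclose(_s1, _s2, atol=1e-4), "RoPE scores must be relative"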
# -------------------------------------------------------------------
# Multiplicative GRAPE positional encoding (rank-2 generator)
#
#   G(n) = exp(n * ω * L)
#   L = a b^T - b a^T is skew-symmetric, so exp(n ω L) is a rotation in R^d.
#
# Implemented using torch.matrix_exp for clarity (d=32, seq_len small).
# -------------------------------------------------------------------
class GrapePositional(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Rank-2 generator parameters
        self.a = nn.Parameter(torch.randn(dim))
        self.b = nn.Parameter(torch.randn(dim))
        self.log_omega = nn.Parameter(torch.zeros(()))

    def forward(self, q, k):
        """
        q, k: (B, T, D)
        returns q_rot, k_rot with GRAPE applied
        """
        bsz, T, d = q.shape
        device, dtype = q.device, q.dtype

        # Normalize a and b to keep the rotation magnitude under control
        a = self.a.to(device=device, dtype=dtype)
        b = self.b.to(device=device, dtype=dtype)
        a = a / (a.norm() + 1e-6)
        b = b / (b.norm() + 1e-6)

        # Skew-symmetric generator L = a b^T - b a^T
        L = torch.outer(a, b) - torch.outer(b, a)
        omega = torch.exp(self.log_omega)

        # Build G(n) = exp(n * ω * L) for n = 0 .. T - 1
        positions = torch.arange(T, device=device, dtype=dtype)
        G_list = []
        for n in positions:
            G_n = torch.matrix_exp(n * omega * L)  # (D, D)
            G_list.append(G_n)
        G = torch.stack(G_list, dim=0)  # (T, D, D)

        # Apply rotations: x_t -> x_t @ G_t
        # q: (B, T, D), G: (T, D, D) -> (B, T, D)
        q_rot = torch.einsum("btd,tdm->btm", q, G)
        k_rot = torch.einsum("btd,tdm->btm", k, G)
        return q_rot, k_rot
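
# Sanity check (not part of the original gist): because L is skew-symmetric,
# each G(n) should be orthogonal (a rotation), and the matrices form a
# one-parameter group, G(m) @ G(n) = G(m + n), which is what makes the
# resulting attention scores depend only on relative position.
with torch.no_grad():
    _a = torch.randn(D_MODEL); _a = _a / _a.norm()
    _b = torch.randn(D_MODEL); _b = _b / _b.norm()
    _L = torch.outer(_a, _b) - torch.outer(_b, _a)
    _G1 = torch.matrix_exp(1.0 * _L)
    _G2 = torch.matrix_exp(2.0 * _L)
    _G3 = torch.matrix_exp(3.0 * _L)
    _eye = torch.eye(D_MODEL)
    assert torch.allclose(_G1 @ _G1.T, _eye, atol=1e-5), "G(n) is orthogonal"
    assert torch.allclose(_G1 @ _G2, _G3, atol=1e-5), "G(m) G(n) = G(m+n)"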
# -------------------------------------------------------------------
# Simple 1-head causal self-attention with pluggable positional encoding
# -------------------------------------------------------------------
class SimpleCausalSelfAttention(nn.Module):
    def __init__(self, dim, mechanism="rope"):
        super().__init__()
        self.dim = dim
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.w_v = nn.Linear(dim, dim, bias=False)
        self.w_o = nn.Linear(dim, dim, bias=False)
        if mechanism == "rope":
            self.pos_enc = RopePositional(dim)
        elif mechanism == "grape":
            self.pos_enc = GrapePositional(dim)
        else:
            raise ValueError(f"Unknown positional mechanism: {mechanism}")

    def forward(self, x):
        """
        x: (B, T, D)
        returns: (B, T, D)
        """
        B, T, D = x.shape
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # Apply position encoding in Q/K space
        q, k = self.pos_enc(q, k)

        # Causal attention mask
        att_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att_scores = att_scores.masked_fill(mask, float("-inf"))
        att_weights = F.softmax(att_scores, dim=-1)

        y = torch.matmul(att_weights, v)
        y = self.w_o(y)
        return y
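
# Sanity check (not part of the original gist): causal masking means the
# output at position t must not change when tokens after t are perturbed.
with torch.no_grad():
    _attn = SimpleCausalSelfAttention(D_MODEL, mechanism="rope")
    _xa = torch.randn(1, SEQ_LEN, D_MODEL)
    _xb = _xa.clone()
    _xb[:, 10:] = torch.randn_like(_xb[:, 10:])  # change the future only
    _ya, _yb = _attn(_xa), _attn(_xb)
    assert torch.allclose(_ya[:, :10], _yb[:, :10], atol=1e-5), "no future leakage"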
# -------------------------------------------------------------------
# Tiny LM: embed → attention → unembed
# -------------------------------------------------------------------
class SimpleAssociativeLM(nn.Module):
    def __init__(self, vocab_size, dim, mechanism="rope"):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, dim)
        self.attn = SimpleCausalSelfAttention(dim, mechanism=mechanism)
        # Unembedding; tie weights with the embedding for simplicity
        self.unembed = nn.Linear(dim, vocab_size, bias=False)
        self.unembed.weight = self.token_embed.weight

    def forward(self, input_ids):
        """
        input_ids: (B, T) int64
        returns logits: (B, T, vocab_size)
        """
        x = self.token_embed(input_ids)  # (B, T, D)
        x = self.attn(x)                 # (B, T, D)
        logits = self.unembed(x)         # (B, T, V)
        return logits
# -------------------------------------------------------------------
# Evaluation: accuracy on first and second halves
# -------------------------------------------------------------------
@torch.no_grad()
def evaluate_halves(model, batch_size=256, num_batches=20):
    model.eval()
    half = SEQ_LEN // 2
    total_first_correct = 0
    total_first_tokens = 0
    total_second_correct = 0
    total_second_tokens = 0
    for _ in range(num_batches):
        x, y = generate_batch(batch_size, SEQ_LEN, VOCAB_SIZE, device)
        logits = model(x)
        preds = logits.argmax(dim=-1)   # (B, T)
        valid_mask = y != IGNORE_INDEX  # (B, T)

        # First half: positions [0 .. half-1]
        first_mask = torch.zeros_like(valid_mask, dtype=torch.bool)
        first_mask[:, :half] = True
        first_mask &= valid_mask

        # Second half: positions [half .. T-2] (last token ignored)
        second_mask = torch.zeros_like(valid_mask, dtype=torch.bool)
        second_mask[:, half:] = True
        second_mask &= valid_mask

        correct = (preds == y)
        if first_mask.any():
            total_first_correct += correct[first_mask].sum().item()
            total_first_tokens += first_mask.sum().item()
        if second_mask.any():
            total_second_correct += correct[second_mask].sum().item()
            total_second_tokens += second_mask.sum().item()

    first_acc = total_first_correct / max(1, total_first_tokens)
    second_acc = total_second_correct / max(1, total_second_tokens)
    return first_acc, second_acc
# -------------------------------------------------------------------
# Instantiate models and optimizers
# -------------------------------------------------------------------
rope_model = SimpleAssociativeLM(VOCAB_SIZE, D_MODEL, mechanism="rope").to(device)
grape_model = SimpleAssociativeLM(VOCAB_SIZE, D_MODEL, mechanism="grape").to(device)

rope_optimizer = torch.optim.AdamW(rope_model.parameters(), lr=LR)
grape_optimizer = torch.optim.AdamW(grape_model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

# History containers
history = {
    "rope_first": [],
    "rope_second": [],
    "grape_first": [],
    "grape_second": [],
}

# -------------------------------------------------------------------
# Training loop
# -------------------------------------------------------------------
for epoch in range(1, NUM_EPOCHS + 1):
    rope_model.train()
    grape_model.train()
    for step in range(STEPS_PER_EPOCH):
        x, y = generate_batch(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE, device)

        # RoPE model step
        rope_optimizer.zero_grad(set_to_none=True)
        rope_logits = rope_model(x)
        rope_loss = criterion(rope_logits.view(-1, VOCAB_SIZE), y.view(-1))
        rope_loss.backward()
        rope_optimizer.step()

        # GRAPE model step
        grape_optimizer.zero_grad(set_to_none=True)
        grape_logits = grape_model(x)
        grape_loss = criterion(grape_logits.view(-1, VOCAB_SIZE), y.view(-1))
        grape_loss.backward()
        grape_optimizer.step()

    # Evaluate after each epoch
    rope_first, rope_second = evaluate_halves(rope_model)
    grape_first, grape_second = evaluate_halves(grape_model)
    history["rope_first"].append(rope_first)
    history["rope_second"].append(rope_second)
    history["grape_first"].append(grape_first)
    history["grape_second"].append(grape_second)
    print(
        f"Epoch {epoch:02d} | "
        f"RoPE first={rope_first:.3f}, second={rope_second:.3f} | "
        f"GRAPE first={grape_first:.3f}, second={grape_second:.3f}"
    )

# -------------------------------------------------------------------
# Plot both: accuracy over epochs for first vs second half
# -------------------------------------------------------------------
epochs = range(1, NUM_EPOCHS + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, history["rope_first"], label="RoPE - first half")
plt.plot(epochs, history["rope_second"], label="RoPE - second half")
plt.plot(epochs, history["grape_first"], label="GRAPE - first half")
plt.plot(epochs, history["grape_second"], label="GRAPE - second half")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Associative recall proxy: first vs second half token prediction")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
Epoch 01 | RoPE first=0.017, second=0.032 | GRAPE first=0.017, second=0.032
Epoch 02 | RoPE first=0.020, second=0.047 | GRAPE first=0.019, second=0.043
Epoch 03 | RoPE first=0.020, second=0.056 | GRAPE first=0.020, second=0.050
Epoch 04 | RoPE first=0.023, second=0.088 | GRAPE first=0.020, second=0.053
Epoch 05 | RoPE first=0.032, second=0.227 | GRAPE first=0.020, second=0.055
Epoch 06 | RoPE first=0.047, second=0.454 | GRAPE first=0.021, second=0.057
Epoch 07 | RoPE first=0.060, second=0.694 | GRAPE first=0.022, second=0.058
Epoch 08 | RoPE first=0.068, second=0.826 | GRAPE first=0.022, second=0.065
Epoch 09 | RoPE first=0.073, second=0.932 | GRAPE first=0.025, second=0.093
Epoch 10 | RoPE first=0.076, second=0.969 | GRAPE first=0.030, second=0.183
Epoch 11 | RoPE first=0.077, second=0.983 | GRAPE first=0.035, second=0.272
Epoch 12 | RoPE first=0.077, second=0.998 | GRAPE first=0.040, second=0.334
Epoch 13 | RoPE first=0.078, second=1.000 | GRAPE first=0.042, second=0.379
Epoch 14 | RoPE first=0.077, second=1.000 | GRAPE first=0.043, second=0.411
Epoch 15 | RoPE first=0.077, second=1.000 | GRAPE first=0.044, second=0.420
Epoch 16 | RoPE first=0.077, second=1.000 | GRAPE first=0.046, second=0.442
Epoch 17 | RoPE first=0.078, second=1.000 | GRAPE first=0.063, second=0.742
Epoch 18 | RoPE first=0.077, second=1.000 | GRAPE first=0.063, second=0.745
Epoch 19 | RoPE first=0.077, second=1.000 | GRAPE first=0.062, second=0.747
Epoch 20 | RoPE first=0.077, second=1.000 | GRAPE first=0.063, second=0.744