@falseywinchnet
Created December 9, 2025 07:41
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# -------------------------------------------------------------------
# Config and device
# -------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VOCAB_SIZE = 64
D_MODEL = 32
SEQ_LEN = 32 # implies first half = 16, second half = 16
NUM_EPOCHS = 20
BATCH_SIZE = 128
STEPS_PER_EPOCH = 50
LR = 1e-3
assert D_MODEL % 2 == 0, "Model dimension must be even for RoPE/GRAPE."
# -------------------------------------------------------------------
# Synthetic "copy / associative recall" dataset
# - First half: random tokens
# - Second half: exact copy of first half
# - Objective: next-token prediction
# -------------------------------------------------------------------
IGNORE_INDEX = -100
def generate_batch(batch_size, seq_len, vocab_size, device):
    half = seq_len // 2
    # First half random
    first_half = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, half),
        device=device,
    )
    # Second half copies first half
    second_half = first_half.clone()
    x = torch.cat([first_half, second_half], dim=1)  # (B, seq_len)
    # Next-token prediction targets
    y = x.roll(shifts=-1, dims=1)
    y[:, -1] = IGNORE_INDEX  # last position has no next token
    return x, y
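# -------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not used by training): the
# second half of each sequence should exactly copy the first half, and
# targets should be the inputs shifted left by one. The _chk names are
# illustrative additions, not part of the original gist.
# -------------------------------------------------------------------
_x_chk, _y_chk = generate_batch(2, SEQ_LEN, VOCAB_SIZE, device)
assert torch.equal(_x_chk[:, :SEQ_LEN // 2], _x_chk[:, SEQ_LEN // 2:])
assert torch.equal(_y_chk[:, :-1], _x_chk[:, 1:])
assert (_y_chk[:, -1] == IGNORE_INDEX).all()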
# -------------------------------------------------------------------
# RoPE positional encoding for a single head
# -------------------------------------------------------------------
class RopePositional(nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.base = base

    def forward(self, q, k):
        """
        q, k: (B, T, D)
        returns q_rot, k_rot: (B, T, D) with RoPE applied
        """
        b, t, d = q.shape
        device = q.device
        half = d // 2
        # Frequency spectrum
        idx = torch.arange(half, device=device, dtype=q.dtype)  # [0..half-1]
        # Classic RoPE frequency scaling
        freq = self.base ** (-2 * idx / d)  # shape (half,)
        # Positions
        positions = torch.arange(t, device=device, dtype=q.dtype)  # (T,)
        # Angles: (T, half)
        angles = torch.einsum("t,f->tf", positions, freq)
        cos = angles.cos()[None, :, :]  # (1, T, half)
        sin = angles.sin()[None, :, :]  # (1, T, half)

        def apply_rope(x):
            x1, x2 = x[..., :half], x[..., half:]
            x_rot1 = x1 * cos - x2 * sin
            x_rot2 = x1 * sin + x2 * cos
            return torch.cat([x_rot1, x_rot2], dim=-1)

        return apply_rope(q), apply_rope(k)
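# -------------------------------------------------------------------
# Optional sanity check for RoPE (a minimal sketch with made-up
# tensors; the underscore names are illustrative): the rotation
# preserves norms, and when q and k repeat the same vector at every
# position, the score q_i · k_j depends only on the offset i - j.
# -------------------------------------------------------------------
_rope = RopePositional(D_MODEL)
_q = torch.randn(1, 1, D_MODEL).expand(1, SEQ_LEN, D_MODEL).contiguous()
_k = torch.randn(1, 1, D_MODEL).expand(1, SEQ_LEN, D_MODEL).contiguous()
_qr, _kr = _rope(_q, _k)
assert torch.allclose(_qr.norm(dim=-1), _q.norm(dim=-1), atol=1e-4)
_scores = torch.einsum("btd,bsd->bts", _qr, _kr)
assert torch.allclose(_scores[0, 2, 0], _scores[0, 5, 3], atol=1e-4)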
# -------------------------------------------------------------------
# Multiplicative GRAPE positional encoding (rank-2 generator)
#
# G(n) = exp(n * ω * L)
# L = a b^T - b a^T is skew-symmetric, so exp(n ω L) is a rotation in R^d.
#
# Implemented using torch.matrix_exp for clarity (d=32, seq_len small).
# -------------------------------------------------------------------
class GrapePositional(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Rank-2 generator parameters
        self.a = nn.Parameter(torch.randn(dim))
        self.b = nn.Parameter(torch.randn(dim))
        self.log_omega = nn.Parameter(torch.zeros(()))

    def forward(self, q, k):
        """
        q, k: (B, T, D)
        returns q_rot, k_rot with GRAPE applied
        """
        bsz, T, d = q.shape
        device, dtype = q.device, q.dtype
        # Normalize a and b to keep rotation magnitude under control
        a = self.a.to(device=device, dtype=dtype)
        b = self.b.to(device=device, dtype=dtype)
        a = a / (a.norm() + 1e-6)
        b = b / (b.norm() + 1e-6)
        # Skew-symmetric generator L = a b^T - b a^T
        L = torch.outer(a, b) - torch.outer(b, a)
        omega = torch.exp(self.log_omega)
        # Build G(n) = exp(n * ω * L) for n = 0 .. T - 1
        positions = torch.arange(T, device=device, dtype=dtype)
        G_list = []
        for n in positions:
            G_n = torch.matrix_exp(n * omega * L)  # (D, D)
            G_list.append(G_n)
        G = torch.stack(G_list, dim=0)  # (T, D, D)
        # Apply rotations: x_t -> x_t @ G_t
        # q: (B, T, D), G: (T, D, D) -> (B, T, D)
        q_rot = torch.einsum("btd,tdm->btm", q, G)
        k_rot = torch.einsum("btd,tdm->btm", k, G)
        return q_rot, k_rot
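# -------------------------------------------------------------------
# Optional sanity check for GRAPE (a minimal sketch mirroring the
# class internals above; the underscore names are illustrative): with
# L skew-symmetric, exp(n ω L) should be orthogonal (a rotation), and
# the matrices form a one-parameter group, G(m) @ G(n) = G(m + n).
# -------------------------------------------------------------------
with torch.no_grad():
    _g = GrapePositional(D_MODEL)
    _a = _g.a / (_g.a.norm() + 1e-6)
    _b = _g.b / (_g.b.norm() + 1e-6)
    _L = torch.outer(_a, _b) - torch.outer(_b, _a)
    _G1 = torch.matrix_exp(_L)
    _G2 = torch.matrix_exp(2.0 * _L)
    _I = torch.eye(D_MODEL)
    assert torch.allclose(_G1 @ _G1.T, _I, atol=1e-4)  # orthogonality
    assert torch.allclose(_G1 @ _G1, _G2, atol=1e-4)   # G(1) @ G(1) = G(2)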
# -------------------------------------------------------------------
# Simple 1-head causal self-attention with pluggable positional encoding
# -------------------------------------------------------------------
class SimpleCausalSelfAttention(nn.Module):
    def __init__(self, dim, mechanism="rope"):
        super().__init__()
        self.dim = dim
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.w_v = nn.Linear(dim, dim, bias=False)
        self.w_o = nn.Linear(dim, dim, bias=False)
        if mechanism == "rope":
            self.pos_enc = RopePositional(dim)
        elif mechanism == "grape":
            self.pos_enc = GrapePositional(dim)
        else:
            raise ValueError(f"Unknown positional mechanism: {mechanism}")

    def forward(self, x):
        """
        x: (B, T, D)
        returns: (B, T, D)
        """
        B, T, D = x.shape
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # Apply position encoding in Q/K space
        q, k = self.pos_enc(q, k)
        # Causal attention mask
        att_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att_scores = att_scores.masked_fill(mask, float("-inf"))
        att_weights = F.softmax(att_scores, dim=-1)
        y = torch.matmul(att_weights, v)
        y = self.w_o(y)
        return y
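# -------------------------------------------------------------------
# Optional causality check (a minimal sketch with a throwaway module;
# the underscore names are illustrative): perturbing the last token
# must not change outputs at earlier positions, confirming the triu
# mask is applied correctly.
# -------------------------------------------------------------------
with torch.no_grad():
    _attn = SimpleCausalSelfAttention(D_MODEL, mechanism="rope")
    _x1 = torch.randn(1, SEQ_LEN, D_MODEL)
    _x2 = _x1.clone()
    _x2[0, -1] += 1.0
    _y1, _y2 = _attn(_x1), _attn(_x2)
    assert torch.allclose(_y1[:, :-1], _y2[:, :-1], atol=1e-5)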
# -------------------------------------------------------------------
# Tiny LM: embed → attention → unembed
# -------------------------------------------------------------------
class SimpleAssociativeLM(nn.Module):
    def __init__(self, vocab_size, dim, mechanism="rope"):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, dim)
        self.attn = SimpleCausalSelfAttention(dim, mechanism=mechanism)
        # Unembedding; tie weights with embedding for simplicity
        self.unembed = nn.Linear(dim, vocab_size, bias=False)
        self.unembed.weight = self.token_embed.weight

    def forward(self, input_ids):
        """
        input_ids: (B, T) int64
        returns logits: (B, T, vocab_size)
        """
        x = self.token_embed(input_ids)  # (B, T, D)
        x = self.attn(x)                 # (B, T, D)
        logits = self.unembed(x)         # (B, T, V)
        return logits
# -------------------------------------------------------------------
# Evaluation: accuracy on first and second halves
# -------------------------------------------------------------------
@torch.no_grad()
def evaluate_halves(model, batch_size=256, num_batches=20):
    model.eval()
    half = SEQ_LEN // 2
    total_first_correct = 0
    total_first_tokens = 0
    total_second_correct = 0
    total_second_tokens = 0
    for _ in range(num_batches):
        x, y = generate_batch(batch_size, SEQ_LEN, VOCAB_SIZE, device)
        logits = model(x)
        preds = logits.argmax(dim=-1)   # (B, T)
        valid_mask = y != IGNORE_INDEX  # (B, T)
        # First half: positions [0..half-1]. These targets are i.i.d. random,
        # so accuracy should sit near chance (1 / VOCAB_SIZE), except at
        # position half-1, whose target is the first copied token.
        first_mask = torch.zeros_like(valid_mask, dtype=torch.bool)
        first_mask[:, :half] = True
        first_mask &= valid_mask
        # Second half: positions [half..T-2] (last token ignored). Accuracy
        # here measures how well the model recalls the first half.
        second_mask = torch.zeros_like(valid_mask, dtype=torch.bool)
        second_mask[:, half:] = True
        second_mask &= valid_mask
        correct = preds == y
        if first_mask.any():
            total_first_correct += correct[first_mask].sum().item()
            total_first_tokens += first_mask.sum().item()
        if second_mask.any():
            total_second_correct += correct[second_mask].sum().item()
            total_second_tokens += second_mask.sum().item()
    first_acc = total_first_correct / max(1, total_first_tokens)
    second_acc = total_second_correct / max(1, total_second_tokens)
    return first_acc, second_acc
# -------------------------------------------------------------------
# Instantiate models and optimizers
# -------------------------------------------------------------------
rope_model = SimpleAssociativeLM(VOCAB_SIZE, D_MODEL, mechanism="rope").to(device)
grape_model = SimpleAssociativeLM(VOCAB_SIZE, D_MODEL, mechanism="grape").to(device)
rope_optimizer = torch.optim.AdamW(rope_model.parameters(), lr=LR)
grape_optimizer = torch.optim.AdamW(grape_model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
# History containers
history = {
    "rope_first": [],
    "rope_second": [],
    "grape_first": [],
    "grape_second": [],
}
# -------------------------------------------------------------------
# Training loop
# -------------------------------------------------------------------
for epoch in range(1, NUM_EPOCHS + 1):
    rope_model.train()
    grape_model.train()
    for step in range(STEPS_PER_EPOCH):
        x, y = generate_batch(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE, device)
        # RoPE model step
        rope_optimizer.zero_grad(set_to_none=True)
        rope_logits = rope_model(x)
        rope_loss = criterion(rope_logits.view(-1, VOCAB_SIZE), y.view(-1))
        rope_loss.backward()
        rope_optimizer.step()
        # GRAPE model step
        grape_optimizer.zero_grad(set_to_none=True)
        grape_logits = grape_model(x)
        grape_loss = criterion(grape_logits.view(-1, VOCAB_SIZE), y.view(-1))
        grape_loss.backward()
        grape_optimizer.step()
    # Evaluate after each epoch
    rope_first, rope_second = evaluate_halves(rope_model)
    grape_first, grape_second = evaluate_halves(grape_model)
    history["rope_first"].append(rope_first)
    history["rope_second"].append(rope_second)
    history["grape_first"].append(grape_first)
    history["grape_second"].append(grape_second)
    print(
        f"Epoch {epoch:02d} | "
        f"RoPE first={rope_first:.3f}, second={rope_second:.3f} | "
        f"GRAPE first={grape_first:.3f}, second={grape_second:.3f}"
    )
# -------------------------------------------------------------------
# Plot both: accuracy over epochs for first vs second half
# -------------------------------------------------------------------
epochs = range(1, NUM_EPOCHS + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, history["rope_first"], label="RoPE - first half")
plt.plot(epochs, history["rope_second"], label="RoPE - second half")
plt.plot(epochs, history["grape_first"], label="GRAPE - first half")
plt.plot(epochs, history["grape_second"], label="GRAPE - second half")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Associative recall proxy: first vs second half token prediction")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
@falseywinchnet (Author) commented:
Epoch 01 | RoPE first=0.017, second=0.032 | GRAPE first=0.017, second=0.032
Epoch 02 | RoPE first=0.020, second=0.047 | GRAPE first=0.019, second=0.043
Epoch 03 | RoPE first=0.020, second=0.056 | GRAPE first=0.020, second=0.050
Epoch 04 | RoPE first=0.023, second=0.088 | GRAPE first=0.020, second=0.053
Epoch 05 | RoPE first=0.032, second=0.227 | GRAPE first=0.020, second=0.055
Epoch 06 | RoPE first=0.047, second=0.454 | GRAPE first=0.021, second=0.057
Epoch 07 | RoPE first=0.060, second=0.694 | GRAPE first=0.022, second=0.058
Epoch 08 | RoPE first=0.068, second=0.826 | GRAPE first=0.022, second=0.065
Epoch 09 | RoPE first=0.073, second=0.932 | GRAPE first=0.025, second=0.093
Epoch 10 | RoPE first=0.076, second=0.969 | GRAPE first=0.030, second=0.183
Epoch 11 | RoPE first=0.077, second=0.983 | GRAPE first=0.035, second=0.272
Epoch 12 | RoPE first=0.077, second=0.998 | GRAPE first=0.040, second=0.334
Epoch 13 | RoPE first=0.078, second=1.000 | GRAPE first=0.042, second=0.379
Epoch 14 | RoPE first=0.077, second=1.000 | GRAPE first=0.043, second=0.411
Epoch 15 | RoPE first=0.077, second=1.000 | GRAPE first=0.044, second=0.420
Epoch 16 | RoPE first=0.077, second=1.000 | GRAPE first=0.046, second=0.442
Epoch 17 | RoPE first=0.078, second=1.000 | GRAPE first=0.063, second=0.742
Epoch 18 | RoPE first=0.077, second=1.000 | GRAPE first=0.063, second=0.745
Epoch 19 | RoPE first=0.077, second=1.000 | GRAPE first=0.062, second=0.747
Epoch 20 | RoPE first=0.077, second=1.000 | GRAPE first=0.063, second=0.744

[Plot: accuracy over epochs, first vs second half, RoPE vs GRAPE]
