@falseywinchnet
Created December 3, 2025 03:44
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
# ==========================================
# 1. Dataset: The "Copy" Task (Induction)
# ==========================================
def generate_copy_batch(batch_size, seq_len, vocab_size):
    """
    Generates a sequence where the second half is a copy of the first half.
    Task: predict the next token.
    This strictly tests 'induction heads' (finding the previous instance of the current token).
    """
    # Ensure even sequence length
    if seq_len % 2 != 0:
        seq_len += 1
    half = seq_len // 2
    # Generate random data [B, Seq]
    data = torch.randint(0, vocab_size, (batch_size, seq_len))
    # Force copy: second half = first half
    data[:, half:] = data[:, :half]
    # Input is the sequence, target is shifted by 1
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets
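
# Quick sanity check (added, illustrative only; not part of the original gist):
# the second half of each generated row mirrors the first half, and targets are
# simply the inputs shifted left by one position.
def _check_copy_batch():
    xs, ys = generate_copy_batch(batch_size=2, seq_len=8, vocab_size=10)
    assert xs.shape == (2, 7) and ys.shape == (2, 7)
    assert torch.equal(xs[:, 1:], ys[:, :-1])   # next-token targets
    assert torch.equal(xs[:, 4:7], xs[:, :3])   # copied second half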
# ==========================================
# 2. Autograder's RoPE Implementation
# ==========================================
class RoPE(nn.Module):
    def __init__(self, dim, max_len=4096):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def forward(self, x):
        # x shape: [B, H, T, D]; shape[2] is T
        seq_len = x.shape[2]
        cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
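
# Note (added): this RoPE variant rotates each (even, odd) feature pair but writes the
# rotated components into the two halves of the output instead of re-interleaving them.
# That is a fixed permutation of dimensions; since the same permutation is applied to
# Q, K and R before any dot products are taken, the attention scores are unaffected.
# Minimal sanity check (an assumption added here, not from the original gist):
def _check_rope_preserves_norms():
    rope = RoPE(dim=16, max_len=32)
    x = torch.randn(1, 2, 8, 16)
    assert torch.allclose(rope(x).norm(dim=-1), x.norm(dim=-1), atol=1e-5)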
# ==========================================
# 3. Triple Product Attention (Exterior Dot Product)
# ==========================================
class TripleProductAttention(nn.Module):
    def __init__(self, d_model, n_head, max_len=128):
        super().__init__()
        assert d_model % n_head == 0
        self.d_head = d_model // n_head
        self.n_head = n_head
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_r = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        # RoPE is mandatory and applied to everything (Q, K, R)
        self.rope = RoPE(self.d_head, max_len)
        # Scale for volume-based attention: the Gram determinant is degree 6 in the
        # projections (essentially a product of three squared terms), versus degree 2
        # for standard Q.K attention, so it needs more aggressive scaling.
        self.scale = 1.0 / (self.d_head ** 1.5)
    def forward(self, x):
        # x: [B, T, C]
        B, T, C = x.shape
        # 1. Projections -> [B, T, H, D] -> transpose to [B, H, T, D]
        q = self.W_q(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        r = self.W_r(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        # 2. Apply RoPE (to everything: Q, K, R)
        q = self.rope(q)
        k = self.rope(k)
        r = self.rope(r)
        # 3. Exterior dot product logic (Q ^ K ^ R).
        # We compute the squared volume of the parallelepiped spanned by (Q, K, R),
        # i.e. the determinant of the Gram matrix
        #   G = [[q.q, q.k, q.r], [k.q, k.k, k.r], [r.q, r.k, r.r]]
        # Linear dependency (zero volume) should give a HIGH score, so score = -det(G).
        # (A numeric check of this expansion against torch.linalg.det follows after this class.)
        # Q is aligned at query index i, K at key index j, R at index i (reference).
        # Norms (diagonal terms) [B, H, T, 1]
        d_qq = (q * q).sum(dim=-1, keepdim=True)
        d_rr = (r * r).sum(dim=-1, keepdim=True)
        d_rq = (r * q).sum(dim=-1, keepdim=True)  # also a scalar per query
        # Key norms need a transpose to broadcast over j
        d_kk = (k * k).sum(dim=-1, keepdim=True).transpose(-2, -1)  # [B, H, 1, T]
        # Interaction terms [B, H, T, T]
        d_qk = q @ k.transpose(-2, -1)  # Q_i . K_j
        d_rk = r @ k.transpose(-2, -1)  # R_i . K_j
        # Determinant of the 3x3 Gram matrix via cofactor expansion, simplified:
        # det = qq*kk*rr + 2*qk*rk*rq - qq*rk^2 - kk*rq^2 - rr*qk^2
        # Broadcast everything to [B, H, T, T]
        term1 = d_qq * d_kk * d_rr
        term2 = 2.0 * d_qk * d_rk * d_rq
        term3 = -(d_qq * (d_rk ** 2))
        term4 = -(d_kk * (d_rq ** 2))
        term5 = -(d_rr * (d_qk ** 2))
        gram_det = term1 + term2 + term3 + term4 + term5
        # Score is negative volume (dependency search)
        scores = -gram_det * self.scale
        # Causal mask
        mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        # 4. Aggregation
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(y)
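
# Numeric check (added, illustrative): the closed-form expansion used above should agree
# with torch.linalg.det of the explicit 3x3 Gram matrix for arbitrary vectors.
def _check_gram_det(d_head=8):
    q, k, r = torch.randn(3, d_head).unbind(0)
    closed = (q @ q) * (k @ k) * (r @ r) + 2 * (q @ k) * (r @ k) * (r @ q) \
        - (q @ q) * (r @ k) ** 2 - (k @ k) * (r @ q) ** 2 - (r @ r) * (q @ k) ** 2
    V = torch.stack([q, k, r])   # rows are the three vectors
    G = V @ V.T                  # 3x3 Gram matrix
    assert torch.allclose(closed, torch.linalg.det(G), atol=1e-3)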
# ==========================================
# 4. Standard Attention (Baseline)
# ==========================================
class StandardAttention(nn.Module):
    def __init__(self, d_model, n_head, max_len=128):
        super().__init__()
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o = nn.Linear(d_model, d_model, bias=False)
        # RoPE on the baseline too (there are no learned positional embeddings)
        self.rope = RoPE(self.d_head, max_len)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        # Apply RoPE to Q and K
        q = self.rope(q)
        k = self.rope(k)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.o(y)
# ==========================================
# 5. Model Wrapper
# ==========================================
class TestModel(nn.Module):
    def __init__(self, vocab, d_model, n_head, max_len, attn_type='standard'):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        # NO learned positional embeddings (position comes from RoPE inside the attention)
        if attn_type == 'standard':
            self.attn = StandardAttention(d_model, n_head, max_len)
        elif attn_type == 'triple':
            self.attn = TripleProductAttention(d_model, n_head, max_len)
        self.unembed = nn.Linear(d_model, vocab, bias=False)

    def forward(self, x):
        h = self.embed(x)
        # h = h + self.pos_emb ... (removed)
        h = self.attn(h)
        return self.unembed(h)
# ==========================================
# 6. Training Loop
# ==========================================
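# The original gist calls train_comparison() in the __main__ block below, but its body is
# missing from this capture. The following is a minimal reconstruction under assumptions:
# the hyperparameters (vocab=32, d_model=64, 4 heads, Adam at 3e-4, 500 steps) are
# illustrative, not the author's. It trains both attention variants on the copy task and
# plots the loss curves, which is presumably what the matplotlib import was for.
def train_comparison(vocab=32, d_model=64, n_head=4, seq_len=64, batch_size=64,
                     steps=500, lr=3e-4, device='cpu'):
    results = {}
    for attn_type in ['standard', 'triple']:
        torch.manual_seed(0)
        model = TestModel(vocab, d_model, n_head, max_len=seq_len,
                          attn_type=attn_type).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr)
        losses = []
        for step in range(steps):
            inputs, targets = generate_copy_batch(batch_size, seq_len, vocab)
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)  # [B, T-1, vocab]
            loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
            if step % 100 == 0:
                print(f"[{attn_type}] step {step:4d} | loss {loss.item():.4f}")
        results[attn_type] = losses
    for name, curve in results.items():
        plt.plot(curve, label=name)
    plt.xlabel('step')
    plt.ylabel('cross-entropy loss')
    plt.title('Copy task: standard vs triple-product attention')
    plt.legend()
    plt.show()
    return results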
# ==========================================
# 7. Analysis: Resilience of Gram Determinant
# ==========================================
def analyze_gram_resilience():
    """
    Analyzes how the Gram determinant score resists 'false matches' compared to the dot product.
    False-match scenario: a vector that is misaligned (orthogonal) but has a large magnitude.
    """
    print("\n=== Resilience Analysis: Gram Det vs Dot Product ===")
    # 1. Set up the reference plane (Q, R).
    # The mental model is a 2D plane plus an orthogonal noise direction (the vectors
    # themselves live in d_dim = 4). Q and R are unit vectors defining the 'concept subspace'.
    d_dim = 4
    q = torch.tensor([1., 0., 0., 0.])
    r = torch.tensor([0., 1., 0., 0.])
    # 2. Define scenarios for K.
    # Scenario A: true match (unit vector inside the Q-R plane, at 45 degrees between them).
    k_match = (q + r) / torch.norm(q + r)
    # Scenario B: false match (orthogonal vector, but HUGE magnitude).
    # "Loud noise", e.g. a frequent token like 'the' or a positional hub.
    k_noise_dir = torch.tensor([0., 0., 1., 0.])  # orthogonal to Q and R
    magnitudes = [1.0, 5.0, 10.0, 50.0, 100.0]
    print(f"{'Mag(K)':<10} | {'Dot(Q,K)':<12} | {'Gram(Q,K,R)':<15} | {'Result'}")
    print("-" * 60)
    # Baseline: true-match scores.
    # Dot product
    score_dot_match = torch.dot(q, k_match).item()
    # Gram determinant, constructed manually for a scalar check:
    # G = [[q.q, q.k, q.r], [k.q, k.k, k.r], [r.q, r.k, r.r]]
    def compute_gram_score(q, k, r):
        dq_q = torch.dot(q, q)
        dk_k = torch.dot(k, k)
        dr_r = torch.dot(r, r)
        dq_k = torch.dot(q, k)
        dq_r = torch.dot(q, r)
        dk_r = torch.dot(k, r)
        # Determinant formula (same expansion as in TripleProductAttention)
        det = (dq_q * dk_k * dr_r) + \
              (2 * dq_k * dk_r * dq_r) - \
              (dq_q * dk_r**2) - \
              (dk_k * dq_r**2) - \
              (dr_r * dq_k**2)
        return -det  # we maximize negative volume
    score_gram_match = compute_gram_score(q, k_match, r).item()
    print(f"{'1.0 (Ref)':<10} | {score_dot_match:<12.4f} | {score_gram_match:<15.4f} | True Match")
    print("-" * 60)
    # Test noise vectors
    for mag in magnitudes:
        k_noise = k_noise_dir * mag
        # Dot-product score (Q . K).
        # k_noise is strictly orthogonal here, so the dot would be exactly 0; in reality,
        # high-dimensional vectors are only 'nearly' orthogonal, so add a slight leak
        # (0.1 alignment with Q) to simulate hubness or mild correlation.
        k_leak = k_noise + 0.1 * q
        s_dot = torch.dot(q, k_leak).item()
        # Gram score
        s_gram = compute_gram_score(q, k_leak, r).item()
        print(f"{mag:<10.1f} | {s_dot:<12.4f} | {s_gram:<15.4f} | {'Loud Mismatch'}")
print("\nInterpretation:")
print("1. Standard Dot Product: As Magnitude increases, score INCREASES linearly.")
print(" This causes 'False Matches' if a token is just very frequent (large norm).")
print("2. Gram Determinant: As Magnitude increases for misaligned vectors, score DECREASES (more negative).")
print(" The volume of the parallelepiped grows with the length of the orthogonal component.")
print(" This creates a 'geometric gate': You must be in the plane to play.")
if __name__ == "__main__":
    train_comparison()
    analyze_gram_resilience()
'''
=== Resilience Analysis: Gram Det vs Dot Product ===
Mag(K)     | Dot(Q,K)     | Gram(Q,K,R)     | Result
------------------------------------------------------------
1.0 (Ref)  | 0.7071       | -0.0000         | True Match
------------------------------------------------------------
1.0        | 0.1000       | -1.0000         | Loud Mismatch
5.0        | 0.1000       | -25.0000        | Loud Mismatch
10.0       | 0.1000       | -100.0000       | Loud Mismatch
50.0       | 0.1000       | -2500.0000      | Loud Mismatch
100.0      | 0.1000       | -10000.0000     | Loud Mismatch

Interpretation:
1. Standard Dot Product: blind to the loud orthogonal component; the score only
   reflects the small leaked alignment (0.1 here). When that leak scales with a
   token's norm, as it does for frequent 'hub' tokens, false-match scores grow
   linearly with magnitude.
2. Gram Determinant: as magnitude increases for a misaligned vector, the score
   DECREASES (more negative): the volume of the parallelepiped grows with the
   length of the orthogonal component.
   This creates a 'geometric gate': you must be in the Q-R plane to play.
'''