import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

# ==========================================
# 1. Dataset: The "Copy" Task (Induction)
# ==========================================
def generate_copy_batch(batch_size, seq_len, vocab_size):
    """
    Generates a sequence where the second half is a copy of the first half.
    Task: Predict the next token.
    This strictly tests 'Induction Heads' (finding the previous occurrence of
    the current token and copying the token that followed it).
    """
    # Ensure even sequence length
    if seq_len % 2 != 0:
        seq_len += 1
    half = seq_len // 2
    # Generate random data [B, Seq]
    data = torch.randint(0, vocab_size, (batch_size, seq_len))
    # Force copy: second half = first half
    data[:, half:] = data[:, :half]
    # Input is sequence, Target is shifted by 1
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets
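
# Illustrative sanity check (a minimal sketch, not wired into the experiments):
# verifies the batch shapes, the shift-by-one targets, and the copy structure.
def _demo_copy_batch():
    inputs, targets = generate_copy_batch(batch_size=2, seq_len=16, vocab_size=10)
    # Dropping one token for the shift gives [B, seq_len - 1]
    assert inputs.shape == (2, 15) and targets.shape == (2, 15)
    # Targets are the inputs shifted left by one position
    assert torch.equal(inputs[:, 1:], targets[:, :-1])
    # The second half of the underlying sequence repeats the first half
    assert torch.equal(inputs[:, 8:], inputs[:, :7])
    print("copy batch OK:", inputs.shape, targets.shape)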

# ==========================================
# 2. Autograder's RoPE Implementation
# ==========================================
class RoPE(nn.Module):
    def __init__(self, dim, max_len=4096):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def forward(self, x):
        # x shape: [B, H, T, D]
        # shape[2] is T
        seq_len = x.shape[2]
        cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
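
# Illustrative check (a minimal sketch, not called by the experiments): after
# RoPE, the dot product between a query at position i and a key at position j
# should depend only on the offset j - i, so shifting both positions by the
# same amount must leave the score essentially unchanged.
def _demo_rope_relative():
    torch.manual_seed(0)
    d_head, T = 16, 32
    rope = RoPE(d_head, max_len=T)
    u = torch.randn(d_head)
    v = torch.randn(d_head)
    # Place the same vector at every position, then read off single positions.
    q_rot = rope(u.view(1, 1, 1, d_head).expand(1, 1, T, d_head))
    k_rot = rope(v.view(1, 1, 1, d_head).expand(1, 1, T, d_head))
    s_a = torch.dot(q_rot[0, 0, 3], k_rot[0, 0, 7])    # offset 4
    s_b = torch.dot(q_rot[0, 0, 10], k_rot[0, 0, 14])  # same offset 4
    assert torch.allclose(s_a, s_b, atol=1e-5)
    print("RoPE relative-offset check OK:", s_a.item(), s_b.item())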

# ==========================================
# 3. Triple Product Attention (Exterior Dot Product)
# ==========================================
class TripleProductAttention(nn.Module):
    def __init__(self, d_model, n_head, max_len=128):
        super().__init__()
        assert d_model % n_head == 0
        self.d_head = d_model // n_head
        self.n_head = n_head
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_r = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        # RoPE is mandatory and is applied to everything (Q, K, R)
        self.rope = RoPE(self.d_head, max_len)
        # Scale for volume-based attention.
        # The Gram determinant is degree 6 in the vector entries (each term is a
        # product of three dot products, each of degree 2), while standard
        # attention is degree 2 (Q*K), so we need more aggressive scaling.
        self.scale = 1.0 / (self.d_head ** 1.5)

    def forward(self, x):
        # x: [B, T, C]
        B, T, C = x.shape
        # 1. Projections -> [B, T, H, D] -> Transpose to [B, H, T, D]
        q = self.W_q(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        r = self.W_r(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
        # 2. Apply RoPE (On Everything: Q, K, R)
        q = self.rope(q)
        k = self.rope(k)
        r = self.rope(r)
        # 3. Exterior Dot Product Logic (Q^K^R)
        # We calculate the Squared Volume of the parallelepiped (Q, K, R).
        # This is the Determinant of the Gram Matrix G:
        # G = [[q.q, q.k, q.r], [k.q, k.k, k.r], [r.q, r.k, r.r]]
        # We want to detect dependency (Zero Volume) -> High Score.
        # Score = -Det(G)
        # Compute dot product terms
        # Q aligned at index i, K aligned at index j, R aligned at index i (reference)
        # Norms (Diagonal terms) [B, H, T, 1]
        d_qq = (q * q).sum(dim=-1, keepdim=True)
        d_rr = (r * r).sum(dim=-1, keepdim=True)
        d_rq = (r * q).sum(dim=-1, keepdim=True)  # Also a scalar per query
        # Key terms need transpose to broadcast over j
        d_kk = (k * k).sum(dim=-1, keepdim=True).transpose(-2, -1)  # [B, H, 1, T]
        # Interaction terms [B, H, T, T]
        d_qk = q @ k.transpose(-2, -1)  # Q_i . K_j
        d_rk = r @ k.transpose(-2, -1)  # R_i . K_j
        # Determinant of 3x3 Gram Matrix
        # det = A(EI - FH) - B(DI - FG) + C(DH - EG) ... simplified Sarrus:
        # det = qq*kk*rr + 2*qk*kr*rq - qq*kr^2 - kk*rq^2 - rr*qk^2
        # Broadcast all to [B, H, T, T]
        term1 = d_qq * d_kk * d_rr
        term2 = 2.0 * d_qk * d_rk * d_rq
        term3 = - (d_qq * (d_rk ** 2))
        term4 = - (d_kk * (d_rq ** 2))
        term5 = - (d_rr * (d_qk ** 2))
        gram_det = term1 + term2 + term3 + term4 + term5
        # Score is negative volume (Dependency Search)
        scores = -gram_det * self.scale
        # Causal Mask
        mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        # 4. Aggregation
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(y)
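
# Illustrative cross-check (a minimal sketch, not called by the experiments):
# the cofactor expansion used above should agree with torch.linalg.det applied
# to the explicit 3x3 Gram matrix of arbitrary vectors (q, k, r).
def _demo_gram_det_matches_torch():
    torch.manual_seed(0)
    q, k, r = torch.randn(3, 8).unbind(0)
    G = torch.stack([torch.stack([torch.dot(a, b) for b in (q, k, r)]) for a in (q, k, r)])
    det_ref = torch.linalg.det(G)
    det_manual = (torch.dot(q, q) * torch.dot(k, k) * torch.dot(r, r)
                  + 2 * torch.dot(q, k) * torch.dot(k, r) * torch.dot(q, r)
                  - torch.dot(q, q) * torch.dot(k, r) ** 2
                  - torch.dot(k, k) * torch.dot(q, r) ** 2
                  - torch.dot(r, r) * torch.dot(q, k) ** 2)
    assert torch.allclose(det_ref, det_manual, rtol=1e-4, atol=1e-5)
    print("Gram determinant expansion OK:", det_ref.item(), det_manual.item())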

# ==========================================
# 4. Standard Attention (Baseline)
# ==========================================
class StandardAttention(nn.Module):
    def __init__(self, d_model, n_head, max_len=128):
        super().__init__()
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o = nn.Linear(d_model, d_model, bias=False)
        # RoPE on Baseline too (since no pos embeddings)
        self.rope = RoPE(self.d_head, max_len)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        # Apply RoPE to Q and K
        q = self.rope(q)
        k = self.rope(k)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.o(y)

# ==========================================
# 5. Model Wrapper
# ==========================================
class TestModel(nn.Module):
    def __init__(self, vocab, d_model, n_head, max_len, attn_type='standard'):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        # NO learned positional embeddings
        if attn_type == 'standard':
            self.attn = StandardAttention(d_model, n_head, max_len)
        elif attn_type == 'triple':
            self.attn = TripleProductAttention(d_model, n_head, max_len)
        else:
            raise ValueError(f"Unknown attn_type: {attn_type}")
        self.unembed = nn.Linear(d_model, vocab, bias=False)

    def forward(self, x):
        h = self.embed(x)
        # h = h + self.pos_emb... (Removed)
        h = self.attn(h)
        return self.unembed(h)

# ==========================================
# 6. Training Loop
# ==========================================
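# The call in __main__ expects a `train_comparison()` function, which is not
# defined above; the version below is a minimal sketch with illustrative
# (untuned) hyperparameters. It trains one model per attention type on the
# copy task and plots the loss curves.
def train_comparison(steps=500, batch_size=32, seq_len=64, vocab=32,
                     d_model=64, n_head=4, lr=3e-4):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    histories = {}
    for attn_type in ('standard', 'triple'):
        # Reset the seed so both variants see the same init order and data stream
        torch.manual_seed(0)
        model = TestModel(vocab, d_model, n_head, max_len=seq_len,
                          attn_type=attn_type).to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=lr)
        losses = []
        for step in range(steps):
            inputs, targets = generate_copy_batch(batch_size, seq_len, vocab)
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            # Only score the second half, where the copy makes targets predictable
            half = inputs.shape[1] // 2
            loss = F.cross_entropy(logits[:, half:].reshape(-1, vocab),
                                   targets[:, half:].reshape(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        histories[attn_type] = losses
        print(f"{attn_type}: final loss {losses[-1]:.4f}")
    # Compare convergence of the two attention mechanisms
    plt.figure()
    for name, curve in histories.items():
        plt.plot(curve, label=name)
    plt.xlabel('step')
    plt.ylabel('cross-entropy loss (copy region)')
    plt.legend()
    plt.title('Copy task: standard vs triple-product attention')
    plt.show()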

# ==========================================
# 7. Analysis: Resilience of Gram Determinant
# ==========================================
def analyze_gram_resilience():
    """
    Analyzes how the Gram Determinant score resists 'False Matches' compared to the Dot Product.
    False Match Scenario: a vector that is misaligned (orthogonal) but has a large magnitude.
    """
    print("\n=== Resilience Analysis: Gram Det vs Dot Product ===")
    # 1. Setup Reference Plane (Q, R)
    # We define a simple 2D plane inside a 4D space as a mental model for visualization.
    # Q and R are unit vectors defining the 'Concept Subspace'
    d_dim = 4
    q = torch.tensor([1., 0., 0., 0.])
    r = torch.tensor([0., 1., 0., 0.])
    # 2. Define Scenarios for K
    # Scenario A: True Match (unit vector inside the plane spanned by Q and R,
    # at 45 degrees to both)
    k_match = (q + r) / torch.norm(q + r)
    # Scenario B: False Match (orthogonal vector, but HUGE magnitude)
    # "Loud Noise" - e.g. a frequent token like 'the' or a positional hub
    k_noise_dir = torch.tensor([0., 0., 1., 0.])  # Orthogonal to Q and R
    magnitudes = [1.0, 5.0, 10.0, 50.0, 100.0]
    print(f"{'Mag(K)':<10} | {'Dot(Q,K)':<12} | {'Gram(Q,K,R)':<15} | {'Result'}")
    print("-" * 60)
    # Baseline: True Match Score
    # Dot Product
    score_dot_match = torch.dot(q, k_match).item()
    # Gram Det
    # Construct Gram Matrix Manually for scalar check
    # G = [[q.q, q.k, q.r], [k.q, k.k, k.r], [r.q, r.k, r.r]]
    # We implement the function specifically for single vectors
    def compute_gram_score(q, k, r):
        dq_q = torch.dot(q, q)
        dk_k = torch.dot(k, k)
        dr_r = torch.dot(r, r)
        dq_k = torch.dot(q, k)
        dq_r = torch.dot(q, r)
        dk_r = torch.dot(k, r)
        # Determinant Formula
        det = (dq_q * dk_k * dr_r) + \
              (2 * dq_k * dk_r * dq_r) - \
              (dq_q * dk_r**2) - \
              (dk_k * dq_r**2) - \
              (dr_r * dq_k**2)
        return -det  # We maximize negative volume

    score_gram_match = compute_gram_score(q, k_match, r).item()
    print(f"{'1.0 (Ref)':<10} | {score_dot_match:<12.4f} | {score_gram_match:<15.4f} | True Match")
    print("-" * 60)
    # Test Noise Vectors
    for mag in magnitudes:
        k_noise = k_noise_dir * mag
        # Dot Product Score (Q . K)
        # Note: Since K_noise is strictly orthogonal here, dot is 0.
        # But in reality, high dim vectors are "nearly" orthogonal.
        # Let's add slight leak (0.1 alignment) to simulate "hubness" or slight correlation
        k_leak = k_noise + 0.1 * q
        s_dot = torch.dot(q, k_leak).item()
        # Gram Score
        s_gram = compute_gram_score(q, k_leak, r).item()
        print(f"{mag:<10.1f} | {s_dot:<12.4f} | {s_gram:<15.4f} | {'Loud Mismatch'}")
    print("\nInterpretation:")
    print("1. Standard Dot Product: As Magnitude increases, score INCREASES linearly.")
    print("   This causes 'False Matches' if a token is just very frequent (large norm).")
    print("2. Gram Determinant: As Magnitude increases for misaligned vectors, score DECREASES (more negative).")
    print("   The volume of the parallelepiped grows with the length of the orthogonal component.")
    print("   This creates a 'geometric gate': You must be in the plane to play.")
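
# Worked check of the table above (a short sketch of the arithmetic): with
# q = e1, r = e2 and k = mag*e3 + 0.1*e1, the Gram matrix is
#   [[1,    0.1,           0],
#    [0.1,  mag^2 + 0.01,  0],
#    [0,    0,             1]],
# whose determinant is (mag^2 + 0.01) - 0.01 = mag^2, so the printed Gram score
# is exactly -mag^2 (-1, -25, -100, ...). The true match k = (e1 + e2)/sqrt(2)
# lies in the plane of q and r, so its determinant, and hence its penalty, is 0.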

if __name__ == "__main__":
    train_comparison()
    analyze_gram_resilience()

'''
=== Resilience Analysis: Gram Det vs Dot Product ===
Mag(K)     | Dot(Q,K)     | Gram(Q,K,R)     | Result
------------------------------------------------------------
1.0 (Ref)  | 0.7071       | -0.0000         | True Match
------------------------------------------------------------
1.0        | 0.1000       | -1.0000         | Loud Mismatch
5.0        | 0.1000       | -25.0000        | Loud Mismatch
10.0       | 0.1000       | -100.0000       | Loud Mismatch
50.0       | 0.1000       | -2500.0000      | Loud Mismatch
100.0      | 0.1000       | -10000.0000     | Loud Mismatch

Interpretation:
1. Standard Dot Product: As Magnitude increases, score INCREASES linearly.
   This causes 'False Matches' if a token is just very frequent (large norm).
2. Gram Determinant: As Magnitude increases for misaligned vectors, score DECREASES (more negative).
   The volume of the parallelepiped grows with the length of the orthogonal component.
   This creates a 'geometric gate': You must be in the plane to play.
'''