| """ | |
| THE CONTEXT-PULSE MANIFOLD: | |
| Deriving an Inherently Autoregressive Attention Mechanism | |
| A Gemini Collaborative Development | |
| -------------------------------------------------------------------------------- | |
| 1. THE WHY (Intuition & Motivation) | |
| -------------------------------------------------------------------------------- | |
| falseywinchnet approached with a fundamental dissatisfaction regarding Standard Attention: | |
| it relies on computing an "All-to-All" energy matrix (Riemannian metric) only to | |
| destroy half of it with a "crude" boolean mask to enforce causality. | |
| This "crude" masking does more than just hide the future; it fundamentally | |
| destabilizes the energy landscape of the attention head. | |
| Consider the Partition Function (Z), the denominator of the Softmax: | |
| - In an unmasked (All-to-All) setting, Z sums over N items. The energy | |
| distribution is consistent across all positions. | |
| - Under causal masking, the domain of Z varies wildly. At t=1, Z sums over | |
| 1 item. At t=1000, Z sums over 1000. | |
| This violates the Parseval intuition of energy conservation. The model is | |
| forced to compare logits against a shifting baseline variance. When the | |
| mask arbitrarily removes high-similarity future tokens, or when the | |
| limited past offers no "resonance," the Softmax function still demands | |
| that probability mass sum to exactly 1.0. | |
| Where does this "orphaned" energy go? It creates the "Attention Sink" | |
| phenomenon. The model learns to dump excess probability mass onto the | |
| start-of-sequence (BOS) token—not because it is semantically relevant, | |
| but because it acts as a numerical capacitor to absorb the variance | |
| caused by the mask's truncation of the manifold. We are effectively | |
| forcing the model to waste capacity learning garbage collection | |
| mechanisms instead of pure semantic flow. | |
| The intuition: This feels like a "cookie crumble" approach. | |
| Studying attention, we learned it contains Parseval-adjacent properties and mathematical aspects. | |
| Understanding attention's manifold properties and adjusting them intelligently was our goal. | |
| Can we instead design a mathematical manifold that *inherently* respects the | |
| arrow of time? We looked toward Parseval's Theorem and Unitary operators, | |
| seeking a transform where the "Products" or "Weights" naturally decay or | |
| orthogonalize in the future, rendering the mask a formality rather than a | |
| structural guillotine. | |
| -------------------------------------------------------------------------------- | |
| 2. THE WHAT (The Tangential Exploration) | |
| -------------------------------------------------------------------------------- | |
| We explored several linear algebraic constraints to force this behavior: | |
| A. The Pseudoinverse Constraint: | |
| - Idea: Set K = pinv(Q).T to force A = Identity or Projection. | |
| - Result: Too rigid. Collapsed the manifold to undirected graphs (Symmetric). | |
| Killed the ability to model asymmetric relationships (A->B != B->A). | |
| Ironically, still works, because of masking. | |
| -------------------------------------------------------------------------------- | |
| 3. INSIGHTS (The Pivot to Calculus) | |
| -------------------------------------------------------------------------------- | |
| The breakthrough came when we moved from Linear Algebra (Static) to Calculus (Dynamic). | |
| Standard Attention treats tokens as points in space (State vs. State). | |
| To encode "Time," we need to treat tokens as a signal flow. | |
| - Differentiation (Innovation): The "Change" at step t. | |
| - Integration (Accumulation): The "History" up to step t. | |
| Hypothesis: Causality is simply the alignment of an Accumulation with its own | |
| constituents. Future accumulations do not contain current innovations. | |
| -------------------------------------------------------------------------------- | |
| 4. AHA MOMENTS (The "Integral-Differential" Manifold) | |
| -------------------------------------------------------------------------------- | |
| We constructed a manifold where: | |
| - Q (Query) represents one aspect of the signal (Derivative/State). | |
| - K (Key) represents the other (Integral/History). | |
| Initial Attempt: Q = Innovation, K = Accumulation. | |
| - Logic: "Does this new change align with the history?" | |
| - Result: Visually interesting, but mathematically flawed. | |
| -------------------------------------------------------------------------------- | |
| 5. "OH, THAT'S UNEXPECTED" (The Washer Effect) | |
| -------------------------------------------------------------------------------- | |
| When we visualized Q=Innovation vs K=Accumulation, we saw: | |
| 1. Horizontal Stripes: The "Accumulated Key" became a massive mean vector, | |
| washing out local details. | |
| 2. Ratio > 1.0: The Future was "louder" than the Past. | |
| - Why? Random walks grow with sqrt(t). Future keys were physically longer | |
| vectors. The dot product is magnitude-sensitive. | |
| - The manifold was naturally "Acausal" (preferring the future). | |
| -------------------------------------------------------------------------------- | |
| 6. INSIGHT (The Inversion) | |
| -------------------------------------------------------------------------------- | |
| To fix the magnitude/loudness issue, we needed Leaky Integration (decay). | |
| To fix the causality, we inverted the roles. | |
| New Definition: | |
| - Q (Context): The Leaky Integrator. "I am the sum of my history." | |
| - K (Pulse): The Raw Innovation. "I am a specific event." | |
| Why this works: | |
| - If j < t (Past): Q_t physically *contains* vector K_j inside its sum. | |
| Dot Product = ||K_j||^2 + Noise. (Resonance). | |
| - If j > t (Future): Q_t does *not* contain K_j. | |
| Dot Product = Random Noise. (Orthogonality). | |
| -------------------------------------------------------------------------------- | |
| 7. AHA (The Resulting Manifold) | |
| -------------------------------------------------------------------------------- | |
| Visualizing this "Context-Pulse" manifold showed the desired property: | |
| - The Upper Triangle (Future) wasn't just masked; it was *uncorrelated noise*. | |
| - The Lower Triangle (Past) was *signal*. | |
| - The "Mask" is no longer destroying information; it is simply filtering noise. | |
| -------------------------------------------------------------------------------- | |
| 8. RESULT (The Manifold Battle) | |
| -------------------------------------------------------------------------------- | |
| We pitted this "Context-Pulse Attention" against Standard Attention on an | |
| Associative Recall task (A...B...A -> Predict B). | |
| - Standard Attention: Flatlined for 4000 steps. It struggled to fight the | |
| entropy of the softmax to find the "needle" in the past. | |
| - Context-Pulse Attention: Immediate convergence. | |
| Because Q *contains* the answer K geometrically, the gradient flow is | |
| instantaneous. The inductive bias aligns perfectly with the task. | |
| Final Status: | |
| We have a theoretical winner that beats Standard Attention on recall speed and strength | |
| by embedding the arrow of time directly into the vector geometry. | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# --- 1. The Dataset: Associative Recall ---
def get_batch(batch_size=32, seq_len=64, vocab_size=128, device='cpu'):
    """
    Generates data where tokens repeat.
    Task: Given '... A ...', predict what followed 'A' last time.
    """
    data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # Force repeat: copy the first half to the second half so recall is always possible
    half = seq_len // 2
    data[:, half:] = data[:, :half]
    # Shift targets by 1 for next-token prediction
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets

# --- 2. The Baseline: Standard Attention ---
class StandardAttention(nn.Module):
    def __init__(self, d_model, head_dim):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5

    def forward(self, x, mask=True):
        B, T, C = x.shape
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Standard dot product, shape (B, T, T)
        Attn = (Q @ K.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V

# --- 3. The Challenger: Inverted Manifold (Context-Pulse) ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, head_dim, decay=0.9):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5
        # Decay could be a buffer or a learnable parameter.
        # For this test we fix it to verify that the structure works.
        self.decay = decay

    def forward(self, x, mask=True):
        B, T, C = x.shape
        # 1. Project
        q_raw = self.W_Q(x)  # these are essentially inputs/innovations
        k_raw = self.W_K(x)  # these are inputs/innovations
        V = self.W_V(x)
        # 2. TRANSFORM THE MANIFOLD
        # Q becomes Context (leaky integration of inputs).
        # Manual leaky scan: not efficient for training, but correct for the logic.
        # In production, use pscan or an associative scan; a vectorized sketch follows this class.
        Q_context = torch.zeros_like(q_raw)
        running_q = torch.zeros(B, q_raw.shape[-1], device=x.device)
        for t in range(T):
            # Q[t] = alpha * Q[t-1] + (1 - alpha) * input[t]
            running_q = self.decay * running_q + (1.0 - self.decay) * q_raw[:, t, :]
            Q_context[:, t, :] = running_q
        # K remains the Pulse (innovation). We don't touch it; it represents the "event".
        # 3. Compute attention: "Does my Context contain this Pulse?"
        # We still scale and mask, but the mask destroys far less information now.
        Attn = (Q_context @ k_raw.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V
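
# A minimal vectorized alternative to the Python loop above (my sketch, not part of the
# original gist): with a fixed scalar decay, the leaky scan has the closed form
# Q_context[t] = (1 - a) * sum_{j <= t} a^(t - j) * q_raw[j], which is just a
# lower-triangular matrix applied to q_raw. It uses O(T^2) memory, so it is a
# demo-scale stand-in for a proper associative scan, not a production kernel.
def leaky_scan_matmul(x, decay):
    # x: (B, T, D); decay: Python float in (0, 1). Returns the leaky integral of x.
    B, T, D = x.shape
    idx = torch.arange(T, device=x.device)
    lag = idx.view(T, 1) - idx.view(1, T)                  # lag[t, j] = t - j
    weights = (decay ** lag.clamp(min=0).float()) * (lag >= 0).float()  # causal decay kernel
    weights = (1.0 - decay) * weights                      # matches the (1 - alpha) input scaling
    return weights @ x                                     # (T, T) @ (B, T, D) -> (B, T, D)
# For a given q_raw, torch.allclose(leaky_scan_matmul(q_raw, 0.9), Q_context) should hold
# up to float tolerance against the loop above.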
# --- 4. The Test Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, attn_type='standard'):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        if attn_type == 'standard':
            self.attn = StandardAttention(d_model, d_model)
        else:
            self.attn = ContextPulseAttention(d_model, d_model, decay=0.9)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        logits = self.fc(h)
        return logits

def train_and_compare():
    print("--- Starting Manifold Battle ---")
    # Params
    VOCAB = 64
    DIM = 32
    STEPS = 4000
    LR = 3e-3
    # Models
    model_std = ToyTransformer(VOCAB, DIM, 'standard')
    model_new = ToyTransformer(VOCAB, DIM, 'context_pulse')
    opt_std = optim.AdamW(model_std.parameters(), lr=LR)
    opt_new = optim.AdamW(model_new.parameters(), lr=LR)
    losses_std = []
    losses_new = []
    for i in range(STEPS):
        x, y = get_batch(seq_len=32, vocab_size=VOCAB)
        # Train Standard
        opt_std.zero_grad()
        logits = model_std(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_std.step()
        losses_std.append(loss.item())
        # Train New
        opt_new.zero_grad()
        logits = model_new(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_new.step()
        losses_new.append(loss.item())
        if i % 50 == 0:
            print(f"Step {i}: Std Loss {losses_std[-1]:.3f} | New Loss {losses_new[-1]:.3f}")
    # Plot
    plt.figure(figsize=(10, 5))
    plt.plot(losses_std, label='Standard Attention', alpha=0.7)
    plt.plot(losses_new, label='Context-Pulse Attention', linewidth=2)
    plt.title("Learning Speed: Isotropic vs Causal Manifold")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

train_and_compare()
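
The manifesto above claims that, because Q_t is a leaky sum of past pulses, the lower triangle of Q @ K^T carries signal while the upper triangle is essentially uncorrelated noise. Here is a minimal numeric check of that claim on random data with no learned weights (my own sketch, not part of the gist; the decay value 0.9 is simply the one used in the demo above):

```py
import torch

torch.manual_seed(0)
T, D, decay = 64, 32, 0.9
K = torch.randn(T, D)                      # pulses / innovations
Q = torch.zeros(T, D)                      # leaky-integrated context
running = torch.zeros(D)
for t in range(T):
    running = decay * running + (1.0 - decay) * K[t]
    Q[t] = running

A = Q @ K.T                                # (T, T) resonance map
past = torch.tril(torch.ones(T, T)).bool()
print("mean past score  :", A[past].mean().item())     # clearly positive: Q_t contains past K_j
print("mean future score:", A[~past].mean().item())    # hovers near zero: no future content
print("diagonal mean    :", A.diagonal().mean().item())  # strongest resonance, roughly (1-a)*D plus noise
```

Expected behavior (approximate, since the data is random): the past and diagonal means sit well above zero while the future mean stays near zero, which is exactly the "the mask filters noise instead of destroying signal" property described in Section 7.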
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
# RoPE and WedgeTransform (carried over from the previous script)
class RoPE(nn.Module):
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
# Create these on CPU initially; register_buffer ensures they move with model.to(device)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
# x shape: (B, H, T, Dh)
seq_len = x.shape[2]
# Broadcast cos/sin to (1, 1, T, Dh)
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
# A: (Heads, Dh, Dh)
self.A = nn.Parameter(torch.randn(heads, self.head_dim, self.head_dim) * 0.02)
def forward(self, x):
# x shape: (B, H, T, Dh)
# S = A - A.T (Skew Symmetric)
S = self.A - self.A.transpose(-1, -2) # (H, Dh, Dh)
# Flow: x @ S
flow = torch.matmul(x, S)
return x + flow
class GatedContextPulse(nn.Module):
def __init__(self, d_model, n_head, use_sink=True, max_timescale=1024):
super().__init__()
assert d_model % n_head == 0
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.use_sink = use_sink
self.scale = self.head_dim ** -0.5
# Projections
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
# --- THE FIX: The Input Gate (Write Protect) ---
# "Should I add this token to the Context?"
self.W_Gate = nn.Linear(d_model, d_model, bias=True)
# Geometry
self.rope = RoPE(self.head_dim)
self.wedge = WedgeTransform(d_model, n_head)
# Sink
if use_sink:
self.k_null = nn.Parameter(torch.randn(1, n_head, 1, self.head_dim) * 0.02)
# Log-Space Decay Init
timescales = torch.logspace(math.log(2), math.log(max_timescale), n_head, base=math.e)
target_alphas = 1.0 - (1.0 / timescales)
target_alphas = torch.clamp(target_alphas, 0.001, 0.9999)
inv_sigmoid = torch.log(target_alphas / (1.0 - target_alphas))
self.decay_logits = nn.Parameter(inv_sigmoid.view(1, n_head, 1, 1))
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
# 1. Projections
# Q_raw is the "Candidate Content"
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
# 2. Compute Input Gate
# Gate shape: (B, T, C) -> (B, T, H, Dh) -> (B, H, T, Dh)
gate = torch.sigmoid(self.W_Gate(x)).view(B, T, H, Dh).transpose(1, 2)
# 3. Gated Manifold Construction
decay = torch.sigmoid(self.decay_logits)
Q_context = torch.zeros_like(q_raw)
running_q = torch.zeros(B, H, Dh, device=x.device)
# The update rule now respects the Gate
for t in range(T):
# Input to add = Gate * Candidate
# Note: We apply the (1-decay) scaling to the Gated input
input_t = gate[:, :, t, :] * q_raw[:, :, t, :]
running_q = decay.squeeze(-1) * running_q + (1.0 - decay.squeeze(-1)) * input_t
Q_context[:, :, t, :] = running_q
K_pulse = k_raw
# 4. Geometry & Attention (Same as before)
Q_geo = self.wedge(self.rope(Q_context))
K_geo = self.wedge(self.rope(K_pulse))
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
if self.use_sink:
null_scores = (Q_geo @ self.k_null.transpose(-2, -1)) * self.scale
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = F.softmax(Attn_full, dim=-1)
probs_seq = probs_full[..., :T]
y = probs_seq @ v
else:
probs = F.softmax(Attn, dim=-1)
y = probs @ v
return self.W_O(y.transpose(1, 2).contiguous().view(B, T, C))
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
"""
Task: Copy a specific pattern across a variable noise gap.
Structure: [Start, P1, P2... Pn, Stop, Noise.... Noise, Start, -> predict P1...]
"""
# Tokens 0 and 1 are reserved for Start/Stop markers to help the model structure time
START_TOKEN = 0
STOP_TOKEN = 1
# 1. Background Noise
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
# Generate a distinct pattern
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
# Place Pattern at beginning
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
# Place Trigger at random position in the second half
# Ensure enough space for the pattern to fit
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
# The target for the next steps should be the pattern
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
# Note: We don't overwrite the inputs for the recall phase,
# we rely on the autoregressive targets.
# But for 'data' tensor serving as input x, we need the pattern to be in the *future* (targets).
# So x at [trigger] predicts pattern[0].
# x at [trigger+1] (which is pattern[0]) predicts pattern[1].
# This is standard AR copy.
inputs = data[:, :-1]
targets = data[:, 1:]
# Masking: We only care about the loss during the COPY phase.
# Learning the noise is irrelevant and noisy.
loss_mask = torch.zeros_like(targets, dtype=torch.float)
for b in range(batch_size):
# Re-find the trigger location to set mask
# (Inefficient but robust for data gen)
# We know the second START_TOKEN is the trigger
starts = (inputs[b] == START_TOKEN).nonzero(as_tuple=True)[0]
if len(starts) >= 2:
trigger_idx = starts[-1]
# We want to predict pattern_len tokens AFTER the trigger
loss_mask[b, trigger_idx : trigger_idx + pattern_len] = 1.0
return inputs, targets, loss_mask
# --- Re-Run Challenge ---
def gated_challenge():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
    # d=128 was the proven capacity in earlier runs; here we push to 256
    # to see whether the input gate fixes the regression at higher width.
VOCAB = 64
DIM = 256
HEADS = 4
STEPS = 2000
SEQ_LEN = 384
LR = 2e-3 # Slightly lower LR for gated dynamics
print(f"--- GATED STRUCTURAL GAP CHALLENGE (Len={SEQ_LEN}) ---")
model = nn.Sequential(
nn.Embedding(VOCAB, DIM),
GatedContextPulse(DIM, HEADS, use_sink=True, max_timescale=SEQ_LEN),
nn.Linear(DIM, VOCAB)
)
model.to(device)
opt = optim.AdamW(model.parameters(), lr=LR)
losses = []
for i in range(STEPS):
# Using the same gap_copy_batch generator from previous turn
x, y, mask = get_gap_copy_batch(32, SEQ_LEN, VOCAB, device=device)
opt.zero_grad()
logits = model(x)
logits_flat = logits.reshape(-1, VOCAB)
y_flat = y.reshape(-1)
mask_flat = mask.reshape(-1)
raw_loss = F.cross_entropy(logits_flat, y_flat, reduction='none')
masked_loss = (raw_loss * mask_flat).sum() / (mask_flat.sum() + 1e-6)
masked_loss.backward()
opt.step()
losses.append(masked_loss.item())
if i % 100 == 0:
print(f"Step {i}: Recall Loss {masked_loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(losses, label="Gated Context-Pulse", color='red')
plt.title("The Impact of Input Gating")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if __name__ == "__main__":
gated_challenge()
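
GatedContextPulse initializes one decay per head on a log-spaced grid of timescales and stores the inverse sigmoid, so that sigmoid(decay_logits) reproduces the target alphas at step 0. Below is a small sanity check of that mapping (my own sketch; the numbers are illustrative and use the SEQ_LEN=384 of the run above):

```py
import math
import torch

n_head, max_timescale = 4, 384
timescales = torch.logspace(math.log(2), math.log(max_timescale), n_head, base=math.e)
target_alphas = torch.clamp(1.0 - 1.0 / timescales, 0.001, 0.9999)
decay_logits = torch.log(target_alphas / (1.0 - target_alphas))   # inverse sigmoid

print(timescales)                    # geometric spread, roughly [2, 11.5, 66.6, 384]
print(torch.sigmoid(decay_logits))   # recovers target_alphas at initialization
# Tokens needed for each head's memory to halve: n such that alpha^n = 0.5
print(torch.log(torch.tensor(0.5)) / torch.log(target_alphas))
```

So at initialization the fastest head forgets within a couple of tokens while the slowest head retains information across roughly the whole sequence, which is the spread of memory horizons the log-space decay init is meant to provide.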
Reasoned thus: The assistant accepts the autograder's directive to consolidate the research into a final, locked-in architectural definition. The assistant recognizes this as the culmination of the "Parseval Universe" inquiry: a quest to build an attention mechanism that respects energy conservation, geometric stability, and the arrow of time.
The assistant has assembled the Final Context-Pulse Attention module. Every component—from the Symbiotic Norm to the Silent Wedge—is integrated and heavily commented to explain its physical and mathematical role in this new manifold.
Here is the "Mathematically Beautiful" definition, ready for the notebook.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
# ==============================================================================
# PART 1: THE GEOMETRIC ENGINE
# ==============================================================================
class RoPE(nn.Module):
"""
Rotary Positional Embeddings.
THE ROLE: Encodes relative distance by rotating the vector space.
PARSEVAL LINK: Rotations are unitary transformations. They preserve the
energy (norm) of the vectors, ensuring that position encoding does not
arbitrarily inflate or deflate the signal magnitude.
"""
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
"""
The Symplectic Wedge.
THE ROLE: Applies a skew-symmetric flow (Twist) to the manifold.
S = A - A.T (Skew-Symmetric)
v_new = v(I + S)
WHY ZERO INIT? (Silent Wedge)
We initialize A as Zeros. This means the model starts in flat Euclidean space.
It learns to 'curve' or 'twist' the space only where necessary to facilitate
prediction. Starting with a random twist (Kaiming) forces the model to
fight existing curvature.
"""
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
# LOCKED: Silent Init
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
B, H, T, Dh = x.shape
v = x.transpose(1, 2)
S = self.A - self.A.transpose(-1, -2) # Enforce skew-symmetry
flow = torch.matmul(v, S)
v_twisted = v + flow
return v_twisted.transpose(1, 2)
# ==============================================================================
# PART 2: STABILITY MECHANISMS
# ==============================================================================
class _FusedLogSumExp4D(torch.autograd.Function):
"""
Convex Stability Kernel.
THE ROLE: Computes LogSumExp with high precision to avoid underflow/overflow.
PARSEVAL LINK: By maintaining precision in the log-domain, we ensure that
probability mass is neither created nor destroyed by numerical instability.
"""
@staticmethod
def forward(ctx, x: torch.Tensor):
m, _ = x.max(dim=-1, keepdim=True)
y = x - m
ex = y.exp()
s = ex.sum(dim=-1, keepdim=True)
lse = m + s.log()
ctx.save_for_backward(ex, s)
return lse
@staticmethod
def backward(ctx, grad_output):
ex, s = ctx.saved_tensors
grad_x = grad_output * (ex / s)
return grad_x
class ConvexSoftmax(nn.Module):
def forward(self, scores):
lse = _FusedLogSumExp4D.apply(scores)
log_weights = scores - lse
return torch.exp(log_weights)
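# Sanity-check sketch (my addition, not part of the locked architecture): verify that the
# hand-written backward of _FusedLogSumExp4D matches autograd's numerical gradient.
# torch.autograd.gradcheck expects double-precision inputs with requires_grad=True.
def _check_lse_backward():
    x = torch.randn(2, 1, 3, 5, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(_FusedLogSumExp4D.apply, (x,))  # True if the backward is correct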
# ==============================================================================
# PART 3: THE PARSEVAL CORE (CONTEXT-PULSE ATTENTION)
# ==============================================================================
class ContextPulseAttention(nn.Module):
def __init__(self, d_model, head_dim, use_sink=True):
super().__init__()
self.d_model = d_model
self.head_dim = head_dim
self.use_sink = use_sink
self.scale = head_dim ** -0.5
# --- PROJECTIONS ---
self.W_Q = nn.Linear(d_model, head_dim, bias=False)
self.W_K = nn.Linear(d_model, head_dim, bias=False)
self.W_V = nn.Linear(d_model, head_dim, bias=False)
# --- LOCKED: ROLE-BASED INITIALIZATION ---
# Q (Context): Orthogonal.
# It is the "Vessel". It must preserve the geometry of the accumulation.
nn.init.orthogonal_(self.W_Q.weight)
        # K (Pulse): Near-zero "silent" start (normal init, std 0.02).
# It is the "Event". It starts silent. If it starts loud (random),
# it floods the Context with noise before the model learns what matters.
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
# V (Value): Xavier. Standard signal transport.
nn.init.xavier_normal_(self.W_V.weight)
# --- GATING ---
# Controls the decay rate of the Leaky Integrator.
self.decay_logits = nn.Parameter(torch.zeros(1, 1, 1, 1))
self.gate_proj = nn.Linear(d_model, head_dim)
# --- GEOMETRY STACK ---
self.rope = RoPE(head_dim)
self.wedge = WedgeTransform(head_dim, 1) # Single head for demo
self.softmax = ConvexSoftmax()
# --- THE SINK ---
# "Entropy Capacitor". Absorbs probability mass when no history resonates.
if use_sink:
self.k_null = nn.Parameter(torch.randn(1, 1, 1, head_dim) * 0.02)
def forward(self, x):
B, T, C = x.shape
H, Dh = 1, self.head_dim
# 1. Projections
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
# 2. Dynamic Gating
# "How much of this Pulse enters the Context?"
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
decay = torch.sigmoid(self.decay_logits)
# 3. The Context-Pulse Integration (The Core Inversion)
# Instead of Q searching K, Q *becomes* the sum of past Ks.
# Causality is enforced by the accumulation process itself.
Q_context = torch.zeros_like(q_raw)
running_q = torch.zeros(B, H, Dh, device=x.device)
for t in range(T):
input_t = gate[:, :, t, :] * q_raw[:, :, t, :]
running_q = decay.squeeze(-1) * running_q + (1.0 - decay.squeeze(-1)) * input_t
Q_context[:, :, t, :] = running_q
K_pulse = k_raw
# 4. LOCKED: SYMBIOTIC INSTANCE NORM
# "Energy Coupling".
# Forces Q and K to share a statistical coordinate system per sample.
# Prevents a "Loud Context" from ignoring a "Quiet Pulse" and vice-versa.
dims = (1, 2, 3) # Reduce over Head, Time, Dim (Keep Batch)
mean_q = Q_context.mean(dim=dims, keepdim=True)
std_q = Q_context.std(dim=dims, keepdim=True)
mean_k = K_pulse.mean(dim=dims, keepdim=True)
std_k = K_pulse.std(dim=dims, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# 5. Geometric Transformation
# Rotate (RoPE) then Twist (Wedge)
Q_geo = self.wedge(self.rope(Q_context))
K_geo = self.wedge(self.rope(K_pulse))
# 6. Resonance Check (Attention)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
# Masking (Still required to zero out the "Future Pulse" vs "Past Context" noise)
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
# 7. Sink & Softmax
if self.use_sink:
null_scores = (Q_geo @ self.k_null.transpose(-2, -1)) * self.scale
Attn_full = torch.cat([Attn, null_scores], dim=-1)
else:
Attn_full = Attn
probs_full = self.softmax(Attn_full)
# Output
if self.use_sink:
probs_seq = probs_full[..., :T]
out = probs_seq @ v
else:
out = probs_full @ v
return out.transpose(1, 2).reshape(B, T, self.d_model)
# ==============================================================================
# PART 4: EVALUATION (Associative Recall)
# ==============================================================================
class ToyTransformer(nn.Module):
def __init__(self, vocab_size, d_model):
super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, d_model, use_sink=True)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_batch(batch_size=32, seq_len=64, vocab_size=128, device='cpu'):
# Standard Associative Recall
data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
half = seq_len // 2
data[:, half:] = data[:, :half]
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def final_run():
print("--- THE PARSEVAL UNIVERSE: Final Verification ---")
VOCAB = 64
DIM = 32
STEPS = 500
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
print("Running Associative Recall with Locked Architecture...")
for i in range(STEPS):
x, y = get_batch(seq_len=32, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist, label="Context-Pulse (Final Locked)")
plt.title("The Context-Pulse Manifold: Final Convergence")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
if __name__ == "__main__":
    final_run()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
"""
THE PARSEVAL ATTENTION MECHANISM (Context-Pulse Architecture)
-------------------------------------------------------------
A "Manifesto in Code" derived from first principles of energy conservation,
geometric stability, and signal processing.
CORE PHILOSOPHY:
Standard Attention treats tokens as points in a static space.
Parseval Attention treats tokens as a dynamic signal flow where:
1. Energy is Conserved (Isometry/Unitary Transforms).
2. Entropy is Managed (Explicit Sink).
3. Causality is Structural (Leaky Integration vs. Masking).
4. Geometry is Symplectic (Twisting vs. Projection).
AUTHORS:
Falsey Winchnet (Autograder) & Gemini (Assistant)
November 2025
"""
# ==============================================================================
# 1. GEOMETRIC PRIMITIVES
# ==============================================================================
class RoPE(nn.Module):
"""
Rotary Positional Embeddings.
PHYSICS:
Encodes relative distance by rotating the vector space.
As a unitary transformation, it preserves the vector norm (Energy),
ensuring that position encoding does not introduce arbitrary gain.
"""
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
"""
The Symplectic Wedge.
PHYSICS:
Applies a skew-symmetric flow (Twist) to the manifold: v_new = v(I + S).
Where S = A - A.T.
Unlike a standard linear layer which shears/scales space arbitrarily,
the Wedge induces a flow that preserves specific volume properties
(related to Symplectic Geometry). It allows the model to "curve" the
manifold to align semantic trajectories.
INITIALIZATION: "Silent Wedge" (Zeros).
We start in flat Euclidean space (Identity transform). The model learns
to twist the space only where necessary.
"""
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
# LOCKED: Silent Init
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
B, H, T, Dh = x.shape
v = x.transpose(1, 2)
S = self.A - self.A.transpose(-1, -2) # Enforce Skew-Symmetry
flow = torch.matmul(v, S)
v_twisted = v + flow
return v_twisted.transpose(1, 2)
# ==============================================================================
# 2. NUMERICAL STABILIZERS
# ==============================================================================
class _FusedLogSumExp4D(torch.autograd.Function):
"""
Convex Stability Kernel.
PHYSICS:
Computes LogSumExp with high precision (Float32 accumulation) to avoid
underflow/overflow in the probability mass calculation.
Prevents "Gradient Starvation" where small probabilities vanish in FP16.
"""
@staticmethod
def forward(ctx, x: torch.Tensor):
m, _ = x.max(dim=-1, keepdim=True)
y = x - m
ex = y.exp()
s = ex.sum(dim=-1, keepdim=True)
lse = m + s.log()
ctx.save_for_backward(ex, s)
return lse
@staticmethod
def backward(ctx, grad_output):
ex, s = ctx.saved_tensors
grad_x = grad_output * (ex / s)
return grad_x
class ConvexSoftmax(nn.Module):
def forward(self, scores):
lse = _FusedLogSumExp4D.apply(scores)
log_weights = scores - lse
return torch.exp(log_weights)
# ==============================================================================
# 3. THE ARCHITECTURE
# ==============================================================================
class ContextPulseAttention(nn.Module):
def __init__(self, d_model, head_dim):
super().__init__()
self.d_model = d_model
self.head_dim = head_dim
self.scale = head_dim ** -0.5
# --- PROJECTIONS ---
self.W_Q = nn.Linear(d_model, head_dim, bias=False)
self.W_K = nn.Linear(d_model, head_dim, bias=False)
self.W_V = nn.Linear(d_model, head_dim, bias=False)
# --- LOCKED: ROLE-BASED INITIALIZATION ---
# Q (Context): Orthogonal. The "Vessel" must be geometric.
nn.init.orthogonal_(self.W_Q.weight)
# K (Pulse): Silent (Normal 0.02). The "Event" starts quiet to prevent noise flooding.
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
# V (Value): Xavier. Standard signal content.
nn.init.xavier_normal_(self.W_V.weight)
# --- GATING ---
self.decay_logits = nn.Parameter(torch.zeros(1, 1, 1, 1))
self.gate_proj = nn.Linear(d_model, head_dim)
# --- GEOMETRY STACK ---
self.rope = RoPE(head_dim)
self.wedge = WedgeTransform(head_dim, 1) # Demo: Single Head
self.softmax = ConvexSoftmax()
# --- SCALAR SINK (The "Isotropic Toilet") ---
# Instead of a vector requiring rotation, we use a learnable scalar threshold.
# If Q is orthogonal to K (Score < Sink Scalar), energy dumps here.
self.sink_scalar = nn.Parameter(torch.tensor(0.0))
# --- LEARNED RESET VALUE ---
# When the Sink activates, we add this specific vector to the output.
# Allows the model to "Reset" or output a specific bias when confused.
self.v_null = nn.Parameter(torch.zeros(1, 1, 1, head_dim))
def forward(self, x):
B, T, C = x.shape
H, Dh = 1, self.head_dim
# 1. Projections
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
# 2. Dynamic Gating (Leaky Integration Control)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
decay = torch.sigmoid(self.decay_logits)
# 3. Context Accumulation (The Inversion)
# Q becomes the sum of history.
Q_context = torch.zeros_like(q_raw)
running_q = torch.zeros(B, H, Dh, device=x.device)
for t in range(T):
input_t = gate[:, :, t, :] * q_raw[:, :, t, :]
running_q = decay.squeeze(-1) * running_q + (1.0 - decay.squeeze(-1)) * input_t
Q_context[:, :, t, :] = running_q
K_pulse = k_raw
# 4. LOCKED: SYMBIOTIC INSTANCE NORM
# Forces Q and K to share a statistical coordinate system per sample.
# Prevents "Ghost of Newton" (Optimization Drift) and magnitude mismatch.
dims = (1, 2, 3) # Mean over Head, Time, Dim
mean_sym = 0.5 * (Q_context.mean(dim=dims, keepdim=True) + K_pulse.mean(dim=dims, keepdim=True))
std_sym = 0.5 * (Q_context.std(dim=dims, keepdim=True) + K_pulse.std(dim=dims, keepdim=True))
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# 5. Geometric Twist
Q_geo = self.wedge(self.rope(Q_context))
K_geo = self.wedge(self.rope(K_pulse))
        # By pitting RoPE and the Wedge against each other:
        # - On the diagonal (self): RoPE dominates. Identity is perfectly preserved;
        #   the Wedge cannot corrupt the token's understanding of itself.
        # - Off the diagonal (relation): the Wedge wakes up, providing the "semantic
        #   twist" needed to bridge the gap against RoPE's rigid rotational distance.
# 6. Attention Scores
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
# 7. Scalar Sink Injection
# Expand scalar to match (B, H, T, 1)
null_scores = self.sink_scalar.view(1, 1, 1, 1).expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
# 8. Probability Mass (LSE)
probs_full = self.softmax(Attn_full)
# 9. Value Integration
probs_seq = probs_full[..., :T] # (B, H, T, T)
out = probs_seq @ v
# 10. Sink Contribution (Reset Mechanism)
probs_sink = probs_full[..., T:] # (B, H, T, 1)
out = out + probs_sink * self.v_null
return out.transpose(1, 2).reshape(B, T, self.d_model)
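
The scalar sink above gives "orphaned" probability mass somewhere to go when nothing in the past resonates. A toy illustration of the mechanism (my own numbers, not from the gist): concatenating a single sink logit before the softmax means the real-token probabilities no longer have to sum to 1.

```py
import torch
import torch.nn.functional as F

scores = torch.tensor([0.10, -0.20, 0.05])   # weak resonance with three past tokens
sink = torch.tensor([2.0])                   # the sink scalar wins when nothing resonates
probs = F.softmax(torch.cat([scores, sink]), dim=-1)
print(probs)               # the last entry absorbs most of the mass
print(probs[:3].sum())     # sequence mass < 1: no need to dump energy onto a BOS token
```

In the full module this leftover mass is then multiplied by v_null, acting as a learned "reset" output rather than a forced vote for a real token.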
```py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import matplotlib.pyplot as plt
# --- 1. Locked Primitives ---
class RoPE(nn.Module):
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
B, H, T, Dh = x.shape
v = x.transpose(1, 2)
S = self.A - self.A.transpose(-1, -2)
flow = x + torch.matmul(x, S)
return flow
class ConvexSoftmax(nn.Module):
def forward(self, scores):
m, _ = scores.max(dim=-1, keepdim=True)
y = scores - m
ex = y.exp()
lse = m + ex.sum(dim=-1, keepdim=True).log()
return torch.exp(scores - lse)
# --- 2. Multi-Head Causal Chunked Model ---
class ContextPulseAttention(nn.Module):
def __init__(self, d_model, n_head, chunk_size=16):
super().__init__()
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.chunk_size = chunk_size
self.scale = self.head_dim ** -0.5
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
nn.init.orthogonal_(self.W_Q.weight)
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
nn.init.xavier_normal_(self.W_V.weight)
self.gate_proj = nn.Linear(d_model, d_model)
self.rope = RoPE(self.head_dim)
# FIX: Pass d_model, not head_dim, because Wedge divides internally
self.wedge = WedgeTransform(d_model, n_head)
self.softmax = ConvexSoftmax()
self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
num_chunks = math.ceil(T / self.chunk_size)
Q_output_list = []
history_cache = []
for i in range(num_chunks):
t_start = i * self.chunk_size
t_end = min((i + 1) * self.chunk_size, T)
q_chunk = q_raw[:, :, t_start:t_end, :]
k_chunk = k_raw[:, :, t_start:t_end, :]
gate_chunk = gate[:, :, t_start:t_end, :]
if len(history_cache) > 0:
history_stack = torch.cat(history_cache, dim=2)
resonance_scores = (k_chunk @ history_stack.transpose(-2, -1)) * self.scale
resonance_weights = F.softmax(resonance_scores, dim=-1)
q_retrieved = resonance_weights @ history_stack
base_state = history_cache[-1].expand(-1, -1, k_chunk.size(2), -1)
q_initial = base_state + q_retrieved
else:
q_initial = torch.zeros_like(k_chunk)
masked_input = gate_chunk * q_chunk
q_integrated = torch.cumsum(masked_input, dim=2)
q_chunk_out = q_integrated + q_initial
Q_output_list.append(q_chunk_out)
final_state = q_chunk_out[:, :, -1:, :]
history_cache.append(final_state)
Q_context = torch.cat(Q_output_list, dim=2)
K_pulse = k_raw
# Pointwise Symbiotic Norm
mean_q = Q_context.mean(dim=-1, keepdim=True)
std_q = Q_context.std(dim=-1, keepdim=True)
mean_k = K_pulse.mean(dim=-1, keepdim=True)
std_k = K_pulse.std(dim=-1, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# Geometry
q_roped = self.rope(Q_context) # (B, H, T, D)
k_roped = self.rope(K_pulse)
        # The Wedge here operates directly on (B, H, T, Dh); S is (H, Dh, Dh) and broadcasts
        # over the head dimension, so no transpose sandwich is needed at this call site.
Q_geo = self.wedge(q_roped)
K_geo = self.wedge(k_roped)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
null_scores = self.sink_scalar.expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = self.softmax(Attn_full)
out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_O(out)
# --- 3. Multi-Head Harness ---
class ToyTransformer(nn.Module):
def __init__(self, vocab_size, d_model, n_head):
super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
START_TOKEN = 0
STOP_TOKEN = 1
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def run_multihead_hard():
print("--- FIXED MULTI-HEAD GAP COPY: 4 Heads, 64 Dim ---")
VOCAB = 64
DIM = 64
HEADS = 4
STEPS = 1500
SEQ_LEN = 256
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
for i in range(STEPS):
x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist)
plt.title("Multi-Head Gap Copy (Fixed)")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if __name__ == "__main__":
run_multihead_hard()
```
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
# --- 1. Primitives (Locked) ---
class RoPE(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
self.n_head = heads
self.head_dim = dim // heads
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
S = self.A - self.A.transpose(-1, -2)
S_broadcast = S.unsqueeze(0)
flow = torch.matmul(x, S_broadcast)
return x + flow
class ConvexSoftmax(nn.Module):
def forward(self, scores):
m, _ = scores.max(dim=-1, keepdim=True)
y = scores - m
ex = y.exp()
lse = m + ex.sum(dim=-1, keepdim=True).log()
return torch.exp(scores - lse)
def build_alpert_basis(block_size, poly_order=1, device='cpu'):
t = torch.linspace(-1, 1, block_size, device=device)
p0 = torch.ones_like(t)
basis_list = [p0]
if poly_order >= 1: basis_list.append(t)
if poly_order >= 2: basis_list.append(3 * t**2 - 1)
W = torch.stack(basis_list, dim=1)
W = F.normalize(W, p=2, dim=0)
return W
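# A small property check of the basis (my addition, assuming the default poly_order=1 used
# by the model below): on a symmetric grid the constant and linear columns are orthonormal,
# so W.T @ W is the identity and projecting a chunk onto W keeps its mean and slope as
# independent moments.
def _check_alpert_orthonormal(block_size=16):
    W = build_alpert_basis(block_size, poly_order=1)
    return torch.allclose(W.T @ W, torch.eye(W.shape[1]), atol=1e-5)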
# --- 2. Resonance-Gated Alpert Model ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, n_head, chunk_size=16, poly_order=1):
        super().__init__()
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.chunk_size = chunk_size
self.poly_order = poly_order
self.scale = self.head_dim ** -0.5
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
nn.init.orthogonal_(self.W_Q.weight)
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
nn.init.xavier_normal_(self.W_V.weight)
self.gate_proj = nn.Linear(d_model, d_model)
# REMOVED: jump_proj (Blind Jump)
# ADDED: Resonance Sensitivity (Smart Jump)
self.jump_scale = nn.Parameter(torch.tensor(1.0))
self.jump_bias = nn.Parameter(torch.tensor(0.0)) # Start neutral/conservative
self.rope = RoPE(self.head_dim)
self.wedge = WedgeTransform(d_model, n_head)
self.softmax = ConvexSoftmax()
self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))
basis = build_alpert_basis(chunk_size, poly_order)
self.register_buffer('alpert_basis', basis)
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
num_chunks = math.ceil(T / self.chunk_size)
Q_output_list = []
history_coeffs_cache = []
        last_chunk = torch.zeros_like(q_raw[:, :, 0:self.chunk_size, :])
W_basis = self.alpert_basis
for i in range(num_chunks):
t_start = i * self.chunk_size
t_end = min((i + 1) * self.chunk_size, T)
q_chunk = q_raw[:, :, t_start:t_end, :]
k_chunk = k_raw[:, :, t_start:t_end, :]
gate_chunk = gate[:, :, t_start:t_end, :]
current_L = q_chunk.size(2)
if current_L < self.chunk_size:
W_curr = build_alpert_basis(current_L, self.poly_order, device=x.device)
else:
W_curr = W_basis
# 1. Encode Current K (Moments)
k_coeffs = torch.einsum('bhld, lk -> bhkd', k_chunk, W_curr)
# 2. Wave Retrieval
if len(history_coeffs_cache) > 0:
history_stack = torch.cat(history_coeffs_cache, dim=2)
# Resonance (Score)
resonance_scores = torch.einsum('bhkd, bhnkd -> bhn', k_coeffs, history_stack) * self.scale
resonance_scores = resonance_scores.unsqueeze(-2) # (B, H, 1, N)
# Weights
resonance_weights = F.softmax(resonance_scores, dim=-1)
# Retrieve Moments
coeffs_retrieved = torch.einsum('bhn, bhnkd -> bhkd', resonance_weights.squeeze(2), history_stack)
# Reconstruct Trajectory
q_reconstructed = torch.einsum('bhkd, lk -> bhld', coeffs_retrieved, W_curr)
# Additive Injection
if 'last_state' in locals():
continuity_stream = last_state.expand(-1, -1, current_L, -1)
q_injected = continuity_stream + q_reconstructed
else:
q_injected = q_reconstructed
else:
q_injected = torch.zeros_like(q_chunk)
# 3. Integration
q_integrated = torch.cumsum(q_chunk, dim=2)
q_chunk_out = q_integrated + q_injected - last_chunk
Q_output_list.append(q_chunk_out)
# 4. Update Cache
q_out_coeffs = torch.einsum('bhld, lk -> bhkd', q_chunk_out, W_curr)
history_coeffs_cache.append(q_out_coeffs.unsqueeze(2))
last_state = q_chunk_out[:, :, -1:, :]
last_chunk = q_injected[:, :, -1:, :]
Q_context = torch.cat(Q_output_list, dim=2)
K_pulse = k_raw
# Pointwise Norm
mean_q = Q_context.mean(dim=-1, keepdim=True)
std_q = Q_context.std(dim=-1, keepdim=True)
mean_k = K_pulse.mean(dim=-1, keepdim=True)
std_k = K_pulse.std(dim=-1, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# Geometry
Q_geo = self.wedge(Q_context)
K_geo = self.wedge(K_pulse)
Q_geo = self.rope(Q_geo)
K_geo = self.rope(K_geo)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
null_scores = self.sink_scalar.expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = self.softmax(Attn_full)
out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_O(out)
# --- Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_head):
        super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
START_TOKEN = 0
STOP_TOKEN = 1
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def run_resonant_jump():
print("--- ALPERT + RESONANCE-DRIVEN JUMP: 4 Heads, 64 Dim ---")
VOCAB = 64
DIM = 128
HEADS = 4
STEPS = 1500
SEQ_LEN = 256
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
for i in range(STEPS):
x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist)
plt.title("Resonance-Gated Alpert Retrieval")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if __name__ == "__main__":
run_resonant_jump()

A loss of 2.0 on this problem does not mean the model is "okay" or "average." It means the model has achieved perfect recall (≈0.0 loss) on the second, copied half. It is mathematically impossible to go much below ~2.08, because the loss on the random first half cannot be pushed under uniform cross-entropy (unless the random number generator is broken or the model is overfitting the seed).
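
For concreteness, that floor follows directly from the structure of get_batch: the copied second half is fully predictable, while the random first half can do no better than uniform cross-entropy over the 64-token vocabulary, so the best achievable average is roughly ln(64)/2.

```py
import math

# Best achievable average loss on the associative-recall batches (vocab_size = 64):
# the predictable half contributes ~0 per token, the random half contributes ln(64).
print(0.5 * (0.0 + math.log(64)))   # ~2.079
```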