-
-
Save falseywinchnet/26cb6597cdab205d1e88702faf8f49aa to your computer and use it in GitHub Desktop.
| """ | |
| THE CONTEXT-PULSE MANIFOLD: | |
| Deriving an Inherently Autoregressive Attention Mechanism | |
| A Gemini Collaborative Development | |
| -------------------------------------------------------------------------------- | |
| 1. THE WHY (Intuition & Motivation) | |
| -------------------------------------------------------------------------------- | |
| falseywinchnet approached with a fundamental dissatisfaction regarding Standard Attention: | |
| it relies on computing an "All-to-All" energy matrix (Riemannian metric) only to | |
| destroy half of it with a "crude" boolean mask to enforce causality. | |
| This "crude" masking does more than just hide the future; it fundamentally | |
| destabilizes the energy landscape of the attention head. | |
| Consider the Partition Function (Z), the denominator of the Softmax: | |
| - In an unmasked (All-to-All) setting, Z sums over N items. The energy | |
| distribution is consistent across all positions. | |
| - Under causal masking, the domain of Z varies wildly. At t=1, Z sums over | |
| 1 item. At t=1000, Z sums over 1000. | |
| This violates the Parseval intuition of energy conservation. The model is | |
| forced to compare logits against a shifting baseline variance. When the | |
| mask arbitrarily removes high-similarity future tokens, or when the | |
| limited past offers no "resonance," the Softmax function still demands | |
| that probability mass sum to exactly 1.0. | |
| Where does this "orphaned" energy go? It creates the "Attention Sink" | |
| phenomenon. The model learns to dump excess probability mass onto the | |
| start-of-sequence (BOS) token—not because it is semantically relevant, | |
| but because it acts as a numerical capacitor to absorb the variance | |
| caused by the mask's truncation of the manifold. We are effectively | |
| forcing the model to waste capacity learning garbage collection | |
| mechanisms instead of pure semantic flow. | |
| The intuition: This feels like a "cookie crumble" approach. | |
| Studying attention, we learned it contains Parseval-adjacent properties and mathematical aspects. | |
| Understanding attention's manifold properties and adjusting them intelligently was our goal. | |
| Can we instead design a mathematical manifold that *inherently* respects the | |
| arrow of time? We looked toward Parseval's Theorem and Unitary operators, | |
| seeking a transform where the "Products" or "Weights" naturally decay or | |
| orthogonalize in the future, rendering the mask a formality rather than a | |
| structural guillotine. | |
| -------------------------------------------------------------------------------- | |
| 2. THE WHAT (The Tangential Exploration) | |
| -------------------------------------------------------------------------------- | |
| We explored several linear algebraic constraints to force this behavior: | |
| A. The Pseudoinverse Constraint: | |
| - Idea: Set K = pinv(Q).T to force A = Identity or Projection. | |
| - Result: Too rigid. Collapsed the manifold to undirected graphs (Symmetric). | |
| Killed the ability to model asymmetric relationships (A->B != B->A). | |
| Ironically, still works, because of masking. | |
| -------------------------------------------------------------------------------- | |
| 3. INSIGHTS (The Pivot to Calculus) | |
| -------------------------------------------------------------------------------- | |
| The breakthrough came when we moved from Linear Algebra (Static) to Calculus (Dynamic). | |
| Standard Attention treats tokens as points in space (State vs. State). | |
| To encode "Time," we need to treat tokens as a signal flow. | |
| - Differentiation (Innovation): The "Change" at step t. | |
| - Integration (Accumulation): The "History" up to step t. | |
| Hypothesis: Causality is simply the alignment of an Accumulation with its own | |
| constituents. Future accumulations do not contain current innovations. | |
| -------------------------------------------------------------------------------- | |
| 4. AHA MOMENTS (The "Integral-Differential" Manifold) | |
| -------------------------------------------------------------------------------- | |
| We constructed a manifold where: | |
| - Q (Query) represents one aspect of the signal (Derivative/State). | |
| - K (Key) represents the other (Integral/History). | |
| Initial Attempt: Q = Innovation, K = Accumulation. | |
| - Logic: "Does this new change align with the history?" | |
| - Result: Visually interesting, but mathematically flawed. | |
| -------------------------------------------------------------------------------- | |
| 5. "OH, THAT'S UNEXPECTED" (The Washer Effect) | |
| -------------------------------------------------------------------------------- | |
| When we visualized Q=Innovation vs K=Accumulation, we saw: | |
| 1. Horizontal Stripes: The "Accumulated Key" became a massive mean vector, | |
| washing out local details. | |
| 2. Ratio > 1.0: The Future was "louder" than the Past. | |
| - Why? Random walks grow with sqrt(t). Future keys were physically longer | |
| vectors. The dot product is magnitude-sensitive. | |
| - The manifold was naturally "Acausal" (preferring the future). | |
| -------------------------------------------------------------------------------- | |
| 6. INSIGHT (The Inversion) | |
| -------------------------------------------------------------------------------- | |
| To fix the magnitude/loudness issue, we needed Leaky Integration (decay). | |
| To fix the causality, we inverted the roles. | |
| New Definition: | |
| - Q (Context): The Leaky Integrator. "I am the sum of my history." | |
| - K (Pulse): The Raw Innovation. "I am a specific event." | |
| Why this works: | |
| - If j < t (Past): Q_t physically *contains* vector K_j inside its sum. | |
| Dot Product = ||K_j||^2 + Noise. (Resonance). | |
| - If j > t (Future): Q_t does *not* contain K_j. | |
| Dot Product = Random Noise. (Orthogonality). | |
| -------------------------------------------------------------------------------- | |
| 7. AHA (The Resulting Manifold) | |
| -------------------------------------------------------------------------------- | |
| Visualizing this "Context-Pulse" manifold showed the desired property: | |
| - The Upper Triangle (Future) wasn't just masked; it was *uncorrelated noise*. | |
| - The Lower Triangle (Past) was *signal*. | |
| - The "Mask" is no longer destroying information; it is simply filtering noise. | |
| -------------------------------------------------------------------------------- | |
| 8. RESULT (The Manifold Battle) | |
| -------------------------------------------------------------------------------- | |
| We pitted this "Context-Pulse Attention" against Standard Attention on an | |
| Associative Recall task (A...B...A -> Predict B). | |
| - Standard Attention: Flatlined for 4000 steps. It struggled to fight the | |
| entropy of the softmax to find the "needle" in the past. | |
| - Context-Pulse Attention: Immediate convergence. | |
| Because Q *contains* the answer K geometrically, the gradient flow is | |
| instantaneous. The inductive bias aligns perfectly with the task. | |
| Final Status: | |
| We have a theoretical winner that beats Standard Attention on recall speed and strength | |
| by embedding the arrow of time directly into the vector geometry. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| # --- 1. The Dataset: Associative Recall --- | |
| def get_batch(batch_size=32, seq_len=64, vocab_size=128, device='cpu'): | |
| """ | |
| Generates data where tokens repeat. | |
| Task: Given '... A ...', predict what followed 'A' last time. | |
| """ | |
| data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) | |
| # Force repeat: Copy the first half to the second half (roughly) to ensure recall is possible | |
| half = seq_len // 2 | |
| data[:, half:] = data[:, :half] | |
| # Shift targets by 1 for next-token prediction | |
| inputs = data[:, :-1] | |
| targets = data[:, 1:] | |
| return inputs, targets | |
| # --- 2. The Baseline: Standard Attention --- | |
| class StandardAttention(nn.Module): | |
| def __init__(self, d_model, head_dim): | |
| super().__init__() | |
| self.W_Q = nn.Linear(d_model, head_dim, bias=False) | |
| self.W_K = nn.Linear(d_model, head_dim, bias=False) | |
| self.W_V = nn.Linear(d_model, head_dim, bias=False) | |
| self.scale = head_dim ** -0.5 | |
| def forward(self, x, mask=True): | |
| B, T, C = x.shape | |
| Q = self.W_Q(x) | |
| K = self.W_K(x) | |
| V = self.W_V(x) | |
| # Standard Dot Product | |
| # Shape: (B, T, T) | |
| Attn = (Q @ K.transpose(-2, -1)) * self.scale | |
| if mask: | |
| m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool() | |
| Attn.masked_fill_(m, float('-inf')) | |
| return F.softmax(Attn, dim=-1) @ V | |
| # --- 3. The Challenger: Inverted Manifold (Context-Pulse) --- | |
| class ContextPulseAttention(nn.Module): | |
| def __init__(self, d_model, head_dim, decay=0.9): | |
| super().__init__() | |
| self.W_Q = nn.Linear(d_model, head_dim, bias=False) | |
| self.W_K = nn.Linear(d_model, head_dim, bias=False) | |
| self.W_V = nn.Linear(d_model, head_dim, bias=False) | |
| self.scale = head_dim ** -0.5 | |
| # We make decay a buffer or learnable param | |
| # For this test, we fix it to verify the structure works | |
| self.decay = decay | |
| def forward(self, x, mask=True): | |
| B, T, C = x.shape | |
| # 1. Project | |
| q_raw = self.W_Q(x) # These are essentially inputs/innovations | |
| k_raw = self.W_K(x) # These are inputs/innovations | |
| V = self.W_V(x) | |
| # 2. TRANSFORM THE MANIFOLD | |
| # Q becomes Context (Leaky Integration of inputs) | |
| # We apply the leaky scan on q_raw | |
| # Manual Leaky Scan (Not efficient for training, but correct for logic) | |
| # In production, use pscan or associative scan | |
| Q_context = torch.zeros_like(q_raw) | |
| running_q = torch.zeros(B, q_raw.shape[-1], device=x.device) | |
| for t in range(T): | |
| # Q[t] = alpha * Q[t-1] + (1-alpha) * input[t] | |
| running_q = self.decay * running_q + (1.0 - self.decay) * q_raw[:, t, :] | |
| Q_context[:, t, :] = running_q | |
| # K remains Pulse (Innovation) | |
| # We don't touch K. It represents the "Event". | |
| # 3. Compute Attention | |
| # "Does my Context contain this Pulse?" | |
| # Note: We still scale and mask, but the masking destroys less info now. | |
| Attn = (Q_context @ k_raw.transpose(-2, -1)) * self.scale | |
| if mask: | |
| m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool() | |
| Attn.masked_fill_(m, float('-inf')) | |
| return F.softmax(Attn, dim=-1) @ V | |
| # --- 4. The Test Harness --- | |
| class ToyTransformer(nn.Module): | |
| def __init__(self, vocab_size, d_model, attn_type='standard'): | |
| super().__init__() | |
| self.embed = nn.Embedding(vocab_size, d_model) | |
| if attn_type == 'standard': | |
| self.attn = StandardAttention(d_model, d_model) | |
| else: | |
| self.attn = ContextPulseAttention(d_model, d_model, decay=0.9) | |
| self.fc = nn.Linear(d_model, vocab_size) | |
| def forward(self, x): | |
| h = self.embed(x) | |
| h = self.attn(h) | |
| logits = self.fc(h) | |
| return logits | |
| def train_and_compare(): | |
| print("--- Starting Manifold Battle ---") | |
| # Params | |
| VOCAB = 64 | |
| DIM = 32 | |
| STEPS = 4000 | |
| LR = 3e-3 | |
| # Models | |
| model_std = ToyTransformer(VOCAB, DIM, 'standard') | |
| model_new = ToyTransformer(VOCAB, DIM, 'context_pulse') | |
| opt_std = optim.AdamW(model_std.parameters(), lr=LR) | |
| opt_new = optim.AdamW(model_new.parameters(), lr=LR) | |
| losses_std = [] | |
| losses_new = [] | |
| for i in range(STEPS): | |
| x, y = get_batch(seq_len=32, vocab_size=VOCAB) | |
| # Train Standard | |
| opt_std.zero_grad() | |
| logits = model_std(x) | |
| loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1)) | |
| loss.backward() | |
| opt_std.step() | |
| losses_std.append(loss.item()) | |
| # Train New | |
| opt_new.zero_grad() | |
| logits = model_new(x) | |
| loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1)) | |
| loss.backward() | |
| opt_new.step() | |
| losses_new.append(loss.item()) | |
| if i % 50 == 0: | |
| print(f"Step {i}: Std Loss {losses_std[-1]:.3f} | New Loss {losses_new[-1]:.3f}") | |
| # Plot | |
| plt.figure(figsize=(10, 5)) | |
| plt.plot(losses_std, label='Standard Attention', alpha=0.7) | |
| plt.plot(losses_new, label='Context-Pulse Attention', linewidth=2) | |
| plt.title("Learning Speed: Isotropic vs Causal Manifold") | |
| plt.xlabel("Training Steps") | |
| plt.ylabel("Loss") | |
| plt.legend() | |
| plt.grid(True, alpha=0.3) | |
| plt.show() | |
| train_and_compare() |
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
"""
THE PARSEVAL ATTENTION MECHANISM (Context-Pulse Architecture)
-------------------------------------------------------------
A "Manifesto in Code" derived from first principles of energy conservation,
geometric stability, and signal processing.
CORE PHILOSOPHY:
Standard Attention treats tokens as points in a static space.
Parseval Attention treats tokens as a dynamic signal flow where:
1. Energy is Conserved (Isometry/Unitary Transforms).
2. Entropy is Managed (Explicit Sink).
3. Causality is Structural (Leaky Integration vs. Masking).
4. Geometry is Symplectic (Twisting vs. Projection).
AUTHORS:
Falsey Winchnet (Autograder) & Gemini (Assistant)
November 2025
"""
# ==============================================================================
# 1. GEOMETRIC PRIMITIVES
# ==============================================================================
class RoPE(nn.Module):
"""
Rotary Positional Embeddings.
PHYSICS:
Encodes relative distance by rotating the vector space.
As a unitary transformation, it preserves the vector norm (Energy),
ensuring that position encoding does not introduce arbitrary gain.
"""
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
"""
The Symplectic Wedge.
PHYSICS:
Applies a skew-symmetric flow (Twist) to the manifold: v_new = v(I + S).
Where S = A - A.T.
Unlike a standard linear layer which shears/scales space arbitrarily,
the Wedge induces a flow that preserves specific volume properties
(related to Symplectic Geometry). It allows the model to "curve" the
manifold to align semantic trajectories.
INITIALIZATION: "Silent Wedge" (Zeros).
We start in flat Euclidean space (Identity transform). The model learns
to twist the space only where necessary.
"""
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
# LOCKED: Silent Init
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
B, H, T, Dh = x.shape
v = x.transpose(1, 2)
S = self.A - self.A.transpose(-1, -2) # Enforce Skew-Symmetry
flow = torch.matmul(v, S)
v_twisted = v + flow
return v_twisted.transpose(1, 2)
# ==============================================================================
# 2. NUMERICAL STABILIZERS
# ==============================================================================
class _FusedLogSumExp4D(torch.autograd.Function):
"""
Convex Stability Kernel.
PHYSICS:
Computes LogSumExp with high precision (Float32 accumulation) to avoid
underflow/overflow in the probability mass calculation.
Prevents "Gradient Starvation" where small probabilities vanish in FP16.
"""
@staticmethod
def forward(ctx, x: torch.Tensor):
m, _ = x.max(dim=-1, keepdim=True)
y = x - m
ex = y.exp()
s = ex.sum(dim=-1, keepdim=True)
lse = m + s.log()
ctx.save_for_backward(ex, s)
return lse
@staticmethod
def backward(ctx, grad_output):
ex, s = ctx.saved_tensors
grad_x = grad_output * (ex / s)
return grad_x
class ConvexSoftmax(nn.Module):
def forward(self, scores):
lse = _FusedLogSumExp4D.apply(scores)
log_weights = scores - lse
return torch.exp(log_weights)
# ==============================================================================
# 3. THE ARCHITECTURE
# ==============================================================================
class ContextPulseAttention(nn.Module):
def __init__(self, d_model, head_dim):
super().__init__()
self.d_model = d_model
self.head_dim = head_dim
self.scale = head_dim ** -0.5
# --- PROJECTIONS ---
self.W_Q = nn.Linear(d_model, head_dim, bias=False)
self.W_K = nn.Linear(d_model, head_dim, bias=False)
self.W_V = nn.Linear(d_model, head_dim, bias=False)
# --- LOCKED: ROLE-BASED INITIALIZATION ---
# Q (Context): Orthogonal. The "Vessel" must be geometric.
nn.init.orthogonal_(self.W_Q.weight)
# K (Pulse): Silent (Normal 0.02). The "Event" starts quiet to prevent noise flooding.
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
# V (Value): Xavier. Standard signal content.
nn.init.xavier_normal_(self.W_V.weight)
# --- GATING ---
self.decay_logits = nn.Parameter(torch.zeros(1, 1, 1, 1))
self.gate_proj = nn.Linear(d_model, head_dim)
# --- GEOMETRY STACK ---
self.rope = RoPE(head_dim)
self.wedge = WedgeTransform(head_dim, 1) # Demo: Single Head
self.softmax = ConvexSoftmax()
# --- SCALAR SINK (The "Isotropic Toilet") ---
# Instead of a vector requiring rotation, we use a learnable scalar threshold.
# If Q is orthogonal to K (Score < Sink Scalar), energy dumps here.
self.sink_scalar = nn.Parameter(torch.tensor(0.0))
# --- LEARNED RESET VALUE ---
# When the Sink activates, we add this specific vector to the output.
# Allows the model to "Reset" or output a specific bias when confused.
self.v_null = nn.Parameter(torch.zeros(1, 1, 1, head_dim))
def forward(self, x):
B, T, C = x.shape
H, Dh = 1, self.head_dim
# 1. Projections
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
# 2. Dynamic Gating (Leaky Integration Control)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
decay = torch.sigmoid(self.decay_logits)
# 3. Context Accumulation (The Inversion)
# Q becomes the sum of history.
Q_context = torch.zeros_like(q_raw)
running_q = torch.zeros(B, H, Dh, device=x.device)
for t in range(T):
input_t = gate[:, :, t, :] * q_raw[:, :, t, :]
running_q = decay.squeeze(-1) * running_q + (1.0 - decay.squeeze(-1)) * input_t
Q_context[:, :, t, :] = running_q
K_pulse = k_raw
# 4. LOCKED: SYMBIOTIC INSTANCE NORM
# Forces Q and K to share a statistical coordinate system per sample.
# Prevents "Ghost of Newton" (Optimization Drift) and magnitude mismatch.
dims = (1, 2, 3) # Mean over Head, Time, Dim
mean_sym = 0.5 * (Q_context.mean(dim=dims, keepdim=True) + K_pulse.mean(dim=dims, keepdim=True))
std_sym = 0.5 * (Q_context.std(dim=dims, keepdim=True) + K_pulse.std(dim=dims, keepdim=True))
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# 5. Geometric Twist
Q_geo = self.wedge(self.rope(Q_context))
K_geo = self.wedge(self.rope(K_pulse))
#By pitting them against each other:On the Diagonal (Self): RoPE dominates.
#Identity is perfectly preserved. The Wedge cannot corrupt the token's understanding of itself.
#Off the Diagonal (Relation): The Wedge wakes up. The Wedge provides the "semantic twist"
#needed to bridge the gap, fighting against the rigid rotational distance of RoPE.
# 6. Attention Scores
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
# 7. Scalar Sink Injection
# Expand scalar to match (B, H, T, 1)
null_scores = self.sink_scalar.view(1, 1, 1, 1).expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
# 8. Probability Mass (LSE)
probs_full = self.softmax(Attn_full)
# 9. Value Integration
probs_seq = probs_full[..., :T] # (B, H, T, T)
out = probs_seq @ v
# 10. Sink Contribution (Reset Mechanism)
probs_sink = probs_full[..., T:] # (B, H, T, 1)
out = out + probs_sink * self.v_null
return out.transpose(1, 2).reshape(B, T, self.d_model)
```pyimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import matplotlib.pyplot as plt
# --- 1. Locked Primitives ---
class RoPE(nn.Module):
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
B, H, T, Dh = x.shape
v = x.transpose(1, 2)
S = self.A - self.A.transpose(-1, -2)
flow = x + torch.matmul(x, S)
return flow
class ConvexSoftmax(nn.Module):
def forward(self, scores):
m, _ = scores.max(dim=-1, keepdim=True)
y = scores - m
ex = y.exp()
lse = m + ex.sum(dim=-1, keepdim=True).log()
return torch.exp(scores - lse)
# --- 2. Multi-Head Causal Chunked Model ---
class ContextPulseAttention(nn.Module):
def __init__(self, d_model, n_head, chunk_size=16):
super().__init__()
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.chunk_size = chunk_size
self.scale = self.head_dim ** -0.5
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
nn.init.orthogonal_(self.W_Q.weight)
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
nn.init.xavier_normal_(self.W_V.weight)
self.gate_proj = nn.Linear(d_model, d_model)
self.rope = RoPE(self.head_dim)
# FIX: Pass d_model, not head_dim, because Wedge divides internally
self.wedge = WedgeTransform(d_model, n_head)
self.softmax = ConvexSoftmax()
self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
num_chunks = math.ceil(T / self.chunk_size)
Q_output_list = []
history_cache = []
for i in range(num_chunks):
t_start = i * self.chunk_size
t_end = min((i + 1) * self.chunk_size, T)
q_chunk = q_raw[:, :, t_start:t_end, :]
k_chunk = k_raw[:, :, t_start:t_end, :]
gate_chunk = gate[:, :, t_start:t_end, :]
if len(history_cache) > 0:
history_stack = torch.cat(history_cache, dim=2)
resonance_scores = (k_chunk @ history_stack.transpose(-2, -1)) * self.scale
resonance_weights = F.softmax(resonance_scores, dim=-1)
q_retrieved = resonance_weights @ history_stack
base_state = history_cache[-1].expand(-1, -1, k_chunk.size(2), -1)
q_initial = base_state + q_retrieved
else:
q_initial = torch.zeros_like(k_chunk)
masked_input = gate_chunk * q_chunk
q_integrated = torch.cumsum(masked_input, dim=2)
q_chunk_out = q_integrated + q_initial
Q_output_list.append(q_chunk_out)
final_state = q_chunk_out[:, :, -1:, :]
history_cache.append(final_state)
Q_context = torch.cat(Q_output_list, dim=2)
K_pulse = k_raw
# Pointwise Symbiotic Norm
mean_q = Q_context.mean(dim=-1, keepdim=True)
std_q = Q_context.std(dim=-1, keepdim=True)
mean_k = K_pulse.mean(dim=-1, keepdim=True)
std_k = K_pulse.std(dim=-1, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# Geometry
q_roped = self.rope(Q_context) # (B, H, T, D)
k_roped = self.rope(K_pulse)
# TRANSPOSE SANDWICH for Wedge compatibility
# Input: (B, T, H, D) -> Wedge flips to (B, H, T, D) for matmul -> Returns (B, T, H, D)
Q_geo = self.wedge(q_roped)
K_geo = self.wedge(k_roped)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
null_scores = self.sink_scalar.expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = self.softmax(Attn_full)
out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_O(out)
# --- 3. Multi-Head Harness ---
class ToyTransformer(nn.Module):
def __init__(self, vocab_size, d_model, n_head):
super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
START_TOKEN = 0
STOP_TOKEN = 1
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def run_multihead_hard():
print("--- FIXED MULTI-HEAD GAP COPY: 4 Heads, 64 Dim ---")
VOCAB = 64
DIM = 64
HEADS = 4
STEPS = 1500
SEQ_LEN = 256
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
for i in range(STEPS):
x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist)
plt.title("Multi-Head Gap Copy (Fixed)")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if __name__ == "__main__":
run_multihead_hard()
```import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
--- 1. Primitives (Locked) ---
class RoPE(nn.Module):
def init(self, dim, max_len=2048):
super().init()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
def init(self, dim, heads):
super().init()
self.n_head = heads
self.head_dim = dim // heads
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
S = self.A - self.A.transpose(-1, -2)
S_broadcast = S.unsqueeze(0)
flow = torch.matmul(x, S_broadcast)
return x + flow
class ConvexSoftmax(nn.Module):
def forward(self, scores):
m, _ = scores.max(dim=-1, keepdim=True)
y = scores - m
ex = y.exp()
lse = m + ex.sum(dim=-1, keepdim=True).log()
return torch.exp(scores - lse)
def build_alpert_basis(block_size, poly_order=1, device='cpu'):
t = torch.linspace(-1, 1, block_size, device=device)
p0 = torch.ones_like(t)
basis_list = [p0]
if poly_order >= 1: basis_list.append(t)
if poly_order >= 2: basis_list.append(3 * t**2 - 1)
W = torch.stack(basis_list, dim=1)
W = F.normalize(W, p=2, dim=0)
return W
--- 2. Resonance-Gated Alpert Model ---
class ContextPulseAttention(nn.Module):
def init(self, d_model, n_head, chunk_size=16, poly_order=1):
super().init()
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.chunk_size = chunk_size
self.poly_order = poly_order
self.scale = self.head_dim ** -0.5
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
nn.init.orthogonal_(self.W_Q.weight)
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
nn.init.xavier_normal_(self.W_V.weight)
self.gate_proj = nn.Linear(d_model, d_model)
# REMOVED: jump_proj (Blind Jump)
# ADDED: Resonance Sensitivity (Smart Jump)
self.jump_scale = nn.Parameter(torch.tensor(1.0))
self.jump_bias = nn.Parameter(torch.tensor(0.0)) # Start neutral/conservative
self.rope = RoPE(self.head_dim)
self.wedge = WedgeTransform(d_model, n_head)
self.softmax = ConvexSoftmax()
self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))
basis = build_alpert_basis(chunk_size, poly_order)
self.register_buffer('alpert_basis', basis)
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
num_chunks = math.ceil(T / self.chunk_size)
Q_output_list = []
history_coeffs_cache = []
last_chunk =torch.zeros_like(q_raw[:, :, 0:self.chunk_size, :])
W_basis = self.alpert_basis
for i in range(num_chunks):
t_start = i * self.chunk_size
t_end = min((i + 1) * self.chunk_size, T)
q_chunk = q_raw[:, :, t_start:t_end, :]
k_chunk = k_raw[:, :, t_start:t_end, :]
gate_chunk = gate[:, :, t_start:t_end, :]
current_L = q_chunk.size(2)
if current_L < self.chunk_size:
W_curr = build_alpert_basis(current_L, self.poly_order, device=x.device)
else:
W_curr = W_basis
# 1. Encode Current K (Moments)
k_coeffs = torch.einsum('bhld, lk -> bhkd', k_chunk, W_curr)
# 2. Wave Retrieval
if len(history_coeffs_cache) > 0:
history_stack = torch.cat(history_coeffs_cache, dim=2)
# Resonance (Score)
resonance_scores = torch.einsum('bhkd, bhnkd -> bhn', k_coeffs, history_stack) * self.scale
resonance_scores = resonance_scores.unsqueeze(-2) # (B, H, 1, N)
# Weights
resonance_weights = F.softmax(resonance_scores, dim=-1)
# Retrieve Moments
coeffs_retrieved = torch.einsum('bhn, bhnkd -> bhkd', resonance_weights.squeeze(2), history_stack)
# Reconstruct Trajectory
q_reconstructed = torch.einsum('bhkd, lk -> bhld', coeffs_retrieved, W_curr)
# Additive Injection
if 'last_state' in locals():
continuity_stream = last_state.expand(-1, -1, current_L, -1)
q_injected = continuity_stream + q_reconstructed
else:
q_injected = q_reconstructed
else:
q_injected = torch.zeros_like(q_chunk)
# 3. Integration
q_integrated = torch.cumsum(q_chunk, dim=2)
q_chunk_out = q_integrated + q_injected - last_chunk
Q_output_list.append(q_chunk_out)
# 4. Update Cache
q_out_coeffs = torch.einsum('bhld, lk -> bhkd', q_chunk_out, W_curr)
history_coeffs_cache.append(q_out_coeffs.unsqueeze(2))
last_state = q_chunk_out[:, :, -1:, :]
last_chunk = q_injected[:, :, -1:, :]
Q_context = torch.cat(Q_output_list, dim=2)
K_pulse = k_raw
# Pointwise Norm
mean_q = Q_context.mean(dim=-1, keepdim=True)
std_q = Q_context.std(dim=-1, keepdim=True)
mean_k = K_pulse.mean(dim=-1, keepdim=True)
std_k = K_pulse.std(dim=-1, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# Geometry
Q_geo = self.wedge(Q_context)
K_geo = self.wedge(K_pulse)
Q_geo = self.rope(Q_geo)
K_geo = self.rope(K_geo)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
null_scores = self.sink_scalar.expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = self.softmax(Attn_full)
out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_O(out)
--- Harness ---
class ToyTransformer(nn.Module):
def init(self, vocab_size, d_model, n_head):
super().init()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
START_TOKEN = 0
STOP_TOKEN = 1
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def run_resonant_jump():
print("--- ALPERT + RESONANCE-DRIVEN JUMP: 4 Heads, 64 Dim ---")
VOCAB = 64
DIM = 128
HEADS = 4
STEPS = 1500
SEQ_LEN = 256
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
for i in range(STEPS):
x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist)
plt.title("Resonance-Gated Alpert Retrieval")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if name == "main":
run_resonant_jump()
Reasoned thus: The assistant accepts the autograder's directive to consolidate the research into a final, locked-in architectural definition. The assistant recognizes this as the culmination of the "Parseval Universe" inquiry: a quest to build an attention mechanism that respects energy conservation, geometric stability, and the arrow of time.
The assistant has assembled the Final Context-Pulse Attention module. Every component—from the Symbiotic Norm to the Silent Wedge—is integrated and heavily commented to explain its physical and mathematical role in this new manifold.
Here is the "Mathematically Beautiful" definition, ready for the notebook.
we observe with all changes convergence within 200 steps.
