Created November 24, 2025 07:07
| """ | |
| THE CONTEXT-PULSE MANIFOLD: | |
| Deriving an Inherently Autoregressive Attention Mechanism | |
| A Gemini Collaborative Development | |
| -------------------------------------------------------------------------------- | |
| 1. THE WHY (Intuition & Motivation) | |
| -------------------------------------------------------------------------------- | |
| falseywinchnet approached with a fundamental dissatisfaction regarding Standard Attention: | |
| it relies on computing an "All-to-All" energy matrix (Riemannian metric) only to | |
| destroy half of it with a "crude" boolean mask to enforce causality. | |
| This "crude" masking does more than just hide the future; it fundamentally | |
| destabilizes the energy landscape of the attention head. | |
| Consider the Partition Function (Z), the denominator of the Softmax: | |
| - In an unmasked (All-to-All) setting, Z sums over N items. The energy | |
| distribution is consistent across all positions. | |
| - Under causal masking, the domain of Z varies wildly. At t=1, Z sums over | |
| 1 item. At t=1000, Z sums over 1000. | |
| This violates the Parseval intuition of energy conservation. The model is | |
| forced to compare logits against a shifting baseline variance. When the | |
| mask arbitrarily removes high-similarity future tokens, or when the | |
| limited past offers no "resonance," the Softmax function still demands | |
| that probability mass sum to exactly 1.0. | |
| Where does this "orphaned" energy go? It creates the "Attention Sink" | |
| phenomenon. The model learns to dump excess probability mass onto the | |
| start-of-sequence (BOS) token—not because it is semantically relevant, | |
| but because it acts as a numerical capacitor to absorb the variance | |
| caused by the mask's truncation of the manifold. We are effectively | |
| forcing the model to waste capacity learning garbage collection | |
| mechanisms instead of pure semantic flow. | |
| The intuition: This feels like a "cookie crumble" approach. | |
| Studying attention, we learned it contains Parseval-adjacent properties and mathematical aspects. | |
| Understanding attention's manifold properties and adjusting them intelligently was our goal. | |
| Can we instead design a mathematical manifold that *inherently* respects the | |
| arrow of time? We looked toward Parseval's Theorem and Unitary operators, | |
| seeking a transform where the "Products" or "Weights" naturally decay or | |
| orthogonalize in the future, rendering the mask a formality rather than a | |
| structural guillotine. | |
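
A rough worked example of the shifting-baseline problem (illustrative numbers,
not taken from the experiments below): suppose a query scores a logit of 5.0
against its best match and roughly 1.0 against everything else. At t=2 the
denominator is Z = e^5 + e^1 ~ 151, so the match receives about 0.98 of the
probability mass. At t=1000 the same logit competes with 999 noise logits,
Z ~ e^5 + 999*e^1 ~ 2864, and the match receives about 0.05. The identical
pairwise energy yields wildly different probability purely because the mask
changes how many terms Z sums over.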
--------------------------------------------------------------------------------
2. THE WHAT (The Tangential Exploration)
--------------------------------------------------------------------------------
We explored several linear-algebraic constraints to force this behavior:

A. The Pseudoinverse Constraint:
   - Idea: Set K = pinv(Q).T to force A = Identity or a projection.
   - Result: Too rigid. It collapsed the manifold to undirected graphs
     (symmetric attention) and killed the ability to model asymmetric
     relationships (A->B != B->A). Ironically, it still works, but only
     because of the masking.
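
   (A short note on why it collapses to symmetry: with K = pinv(Q).T the
    logits become Q @ K.T = Q @ pinv(Q), the orthogonal projection onto the
    column space of Q, which is a symmetric matrix, so before masking the
    pattern cannot distinguish A->B from B->A.)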
--------------------------------------------------------------------------------
3. INSIGHTS (The Pivot to Calculus)
--------------------------------------------------------------------------------
The breakthrough came when we moved from Linear Algebra (static) to Calculus
(dynamic). Standard Attention treats tokens as points in space (State vs.
State). To encode "Time," we need to treat tokens as a signal flow:
- Differentiation (Innovation): the "Change" at step t.
- Integration (Accumulation): the "History" up to step t.

Hypothesis: Causality is simply the alignment of an Accumulation with its own
constituents. Future accumulations do not contain current innovations.

--------------------------------------------------------------------------------
4. AHA MOMENTS (The "Integral-Differential" Manifold)
--------------------------------------------------------------------------------
We constructed a manifold where:
- Q (Query) represents one aspect of the signal (Derivative/State).
- K (Key) represents the other (Integral/History).

Initial attempt: Q = Innovation, K = Accumulation.
- Logic: "Does this new change align with the history?"
- Result: Visually interesting, but mathematically flawed.

--------------------------------------------------------------------------------
5. "OH, THAT'S UNEXPECTED" (The Washer Effect)
--------------------------------------------------------------------------------
When we visualized Q = Innovation vs. K = Accumulation, we saw:
1. Horizontal stripes: the "Accumulated Key" became a massive mean vector,
   washing out local details.
2. Ratio > 1.0: the Future was "louder" than the Past.
   - Why? Random walks grow with sqrt(t), so future keys were physically
     longer vectors, and the dot product is magnitude-sensitive.
   - The manifold was naturally "Acausal" (preferring the future).

--------------------------------------------------------------------------------
6. INSIGHT (The Inversion)
--------------------------------------------------------------------------------
To fix the magnitude/loudness issue, we needed Leaky Integration (decay).
To fix the causality, we inverted the roles.

New definition:
- Q (Context): the Leaky Integrator. "I am the sum of my history."
- K (Pulse): the Raw Innovation. "I am a specific event."

Why this works:
- If j < t (Past): Q_t physically *contains* vector K_j inside its sum.
  Dot Product = ||K_j||^2 + Noise. (Resonance.)
- If j > t (Future): Q_t does *not* contain K_j.
  Dot Product = Random Noise. (Orthogonality.)
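
A rough back-of-the-envelope version of this claim (an idealization assuming
the integrated stream and the pulses are the same zero-mean, unit-variance,
d-dimensional signal, accumulated with leaky weights (1-a)*a^(t-j)):
- For j <= t: E[Q_t . K_j] ~ (1-a) * a^(t-j) * d (resonance that decays with age).
- For j > t:  E[Q_t . K_j] = 0, with standard deviation on the order of
  sqrt(d * (1-a)/(1+a)) (pure noise).
With a = 0.9 and d = 32 this gives roughly 3.2 at j = t versus a noise floor
near 1.3.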
--------------------------------------------------------------------------------
7. AHA (The Resulting Manifold)
--------------------------------------------------------------------------------
Visualizing this "Context-Pulse" manifold showed the desired property:
- The Upper Triangle (Future) wasn't just masked; it was *uncorrelated noise*.
- The Lower Triangle (Past) was *signal*.
- The "Mask" is no longer destroying information; it is simply filtering noise.
(A small standalone check of this geometry appears right after the imports
below.)

--------------------------------------------------------------------------------
8. RESULT (The Manifold Battle)
--------------------------------------------------------------------------------
We pitted this "Context-Pulse Attention" against Standard Attention on an
Associative Recall task (A...B...A -> predict B).
- Standard Attention: flatlined for 4000 steps. It struggled to fight the
  entropy of the softmax to find the "needle" in the past.
- Context-Pulse Attention: immediate convergence.
  Because Q geometrically *contains* the answer's key K, the gradient signal
  is available immediately; the inductive bias aligns perfectly with the task.

Final Status:
We have a theoretical winner that beats Standard Attention on recall speed and
reliability by embedding the arrow of time directly into the vector geometry.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
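
# --- 0. Standalone geometry check (an illustrative addition, not part of the
# original experiment): leaky-integrate a stream of random pulses and compare
# the Context-Pulse dot products for the past vs. the future. This uses the
# Section 6 idealization that the context and the pulses come from the same
# stream; the self-resonance should sit several times above the future noise
# floor (roughly 3.2 vs 1.3 for decay=0.9, D=32).
def resonance_check(T=64, D=32, decay=0.9, seed=0):
    torch.manual_seed(seed)
    pulses = torch.randn(T, D)                # K: raw innovations ("events")
    context = torch.zeros(T, D)               # Q: leaky integral of the same stream
    running = torch.zeros(D)
    for t in range(T):
        running = decay * running + (1.0 - decay) * pulses[t]
        context[t] = running
    scores = context @ pulses.T               # scores[t, j] = Q_t . K_j
    future = torch.triu(torch.ones(T, T), diagonal=1).bool()
    print(f"mean Q_t.K_t (self-resonance):    {scores.diagonal().mean():.3f}")
    print(f"mean |Q_t.K_j| for j > t (noise): {scores[future].abs().mean():.3f}")
# resonance_check()  # uncomment to print the two statistics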
# --- 1. The Dataset: Associative Recall ---
def get_batch(batch_size=32, seq_len=64, vocab_size=128, device='cpu'):
    """
    Generates data where tokens repeat.
    Task: Given '... A ...', predict what followed 'A' last time.
    """
    data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # Force repeat: copy the first half to the second half to ensure recall is possible
    half = seq_len // 2
    data[:, half:] = data[:, :half]
    # Shift targets by 1 for next-token prediction
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets
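
# Quick inspection of the data format (an illustrative aside, not used by the
# training loop below): with seq_len=8 each row has the form [a b c d a b c d],
# and targets are inputs shifted left by one, so the second half of every row
# is predictable by recalling what followed the same token in the first half.
# x, y = get_batch(batch_size=1, seq_len=8, vocab_size=10)
# print(x)   # e.g. [[a, b, c, d, a, b, c]]  (length seq_len - 1)
# print(y)   # e.g. [[b, c, d, a, b, c, d]]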
# --- 2. The Baseline: Standard Attention ---
class StandardAttention(nn.Module):
    def __init__(self, d_model, head_dim):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5

    def forward(self, x, mask=True):
        B, T, C = x.shape
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Standard dot product, shape (B, T, T)
        Attn = (Q @ K.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V
# --- 3. The Challenger: Inverted Manifold (Context-Pulse) ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, head_dim, decay=0.9):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5
        # We make decay a buffer or learnable param.
        # For this test, we fix it to verify the structure works.
        self.decay = decay

    def forward(self, x, mask=True):
        B, T, C = x.shape
        # 1. Project
        q_raw = self.W_Q(x)  # These are essentially inputs/innovations
        k_raw = self.W_K(x)  # These are inputs/innovations
        V = self.W_V(x)
        # 2. TRANSFORM THE MANIFOLD
        # Q becomes Context (leaky integration of the inputs).
        # Manual leaky scan (not efficient for training, but correct for logic).
        # In production, use pscan or an associative scan.
        Q_context = torch.zeros_like(q_raw)
        running_q = torch.zeros(B, q_raw.shape[-1], device=x.device)
        for t in range(T):
            # Q[t] = alpha * Q[t-1] + (1 - alpha) * input[t]
            running_q = self.decay * running_q + (1.0 - self.decay) * q_raw[:, t, :]
            Q_context[:, t, :] = running_q
        # K remains the Pulse (Innovation). We don't touch it; it represents the "Event".
        # 3. Compute attention: "Does my Context contain this Pulse?"
        # Note: we still scale and mask, but the masking destroys less info now.
        Attn = (Q_context @ k_raw.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V
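
# A loop-free formulation of the same leaky scan (an illustrative sketch, not
# the pscan/associative-scan mentioned above): materialize the decay kernel
# alpha^(t-j) as a lower-triangular (T, T) matrix and apply it with a single
# matmul. O(T^2) memory, so only suitable for modest sequence lengths.
def leaky_scan_matmul(x, decay):
    """x: (B, T, D) -> y with y[:, t] = sum_{j<=t} decay^(t-j) * (1-decay) * x[:, j]."""
    B, T, D = x.shape
    idx = torch.arange(T, device=x.device)
    exponents = (idx[:, None] - idx[None, :]).clamp(min=0).float()
    kernel = (decay ** exponents) * (idx[None, :] <= idx[:, None]).float()
    return torch.einsum('tj,bjd->btd', kernel, (1.0 - decay) * x)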
# --- 4. The Test Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, attn_type='standard'):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        if attn_type == 'standard':
            self.attn = StandardAttention(d_model, d_model)
        else:
            self.attn = ContextPulseAttention(d_model, d_model, decay=0.9)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        logits = self.fc(h)
        return logits

def train_and_compare():
    print("--- Starting Manifold Battle ---")
    # Params
    VOCAB = 64
    DIM = 32
    STEPS = 4000
    LR = 3e-3
    # Models
    model_std = ToyTransformer(VOCAB, DIM, 'standard')
    model_new = ToyTransformer(VOCAB, DIM, 'context_pulse')
    opt_std = optim.AdamW(model_std.parameters(), lr=LR)
    opt_new = optim.AdamW(model_new.parameters(), lr=LR)
    losses_std = []
    losses_new = []
    for i in range(STEPS):
        x, y = get_batch(seq_len=32, vocab_size=VOCAB)
        # Train Standard
        opt_std.zero_grad()
        logits = model_std(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_std.step()
        losses_std.append(loss.item())
        # Train New
        opt_new.zero_grad()
        logits = model_new(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_new.step()
        losses_new.append(loss.item())
        if i % 50 == 0:
            print(f"Step {i}: Std Loss {losses_std[-1]:.3f} | New Loss {losses_new[-1]:.3f}")
    # Plot
    plt.figure(figsize=(10, 5))
    plt.plot(losses_std, label='Standard Attention', alpha=0.7)
    plt.plot(losses_new, label='Context-Pulse Attention', linewidth=2)
    plt.title("Learning Speed: Isotropic vs Causal Manifold")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

train_and_compare()
Author comment:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
# --- 1. Primitives (Locked) ---
class RoPE(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def forward(self, x):
        seq_len = x.shape[2]
        cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)

class WedgeTransform(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.n_head = heads
        self.head_dim = dim // heads
        self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))

    def forward(self, x):
        S = self.A - self.A.transpose(-1, -2)
        S_broadcast = S.unsqueeze(0)
        flow = torch.matmul(x, S_broadcast)
        return x + flow

class ConvexSoftmax(nn.Module):
    def forward(self, scores):
        m, _ = scores.max(dim=-1, keepdim=True)
        y = scores - m
        ex = y.exp()
        lse = m + ex.sum(dim=-1, keepdim=True).log()
        return torch.exp(scores - lse)

def build_alpert_basis(block_size, poly_order=1, device='cpu'):
    t = torch.linspace(-1, 1, block_size, device=device)
    p0 = torch.ones_like(t)
    basis_list = [p0]
    if poly_order >= 1: basis_list.append(t)
    if poly_order >= 2: basis_list.append(3 * t**2 - 1)
    W = torch.stack(basis_list, dim=1)
    W = F.normalize(W, p=2, dim=0)
    return W
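
# Small usage sketch of the basis (illustrative only; how the forward pass,
# which is not included in this comment, actually consumes it is not shown):
# build a two-column polynomial basis over a 16-step block and project a chunk
# of hidden states onto it to get per-block coefficients.
# W = build_alpert_basis(block_size=16, poly_order=1)      # (16, 2), unit-norm columns
# chunk = torch.randn(4, 16, 32)                           # (batch, block, dim)
# coeffs = torch.einsum('btd,tk->bkd', chunk, W)           # (batch, 2, dim)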
# --- 2. Resonance-Gated Alpert Model ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, n_head, chunk_size=16, poly_order=1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.chunk_size = chunk_size
        self.poly_order = poly_order
        self.scale = self.head_dim ** -0.5

# --- Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_head):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        return self.fc(h)

def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
    START_TOKEN = 0
    STOP_TOKEN = 1
    data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
    pattern_len = 16
    for b in range(batch_size):
        pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
        data[b, 0] = START_TOKEN
        data[b, 1:1+pattern_len] = pattern
        data[b, 1+pattern_len] = STOP_TOKEN
        trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
        data[b, trigger_start] = START_TOKEN
        data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets

def run_resonant_jump():
    print("--- ALPERT + RESONANCE-DRIVEN JUMP: 4 Heads, 64 Dim ---")
    VOCAB = 64
    DIM = 128
    HEADS = 4
    STEPS = 1500
    SEQ_LEN = 256
    LR = 3e-3
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

if __name__ == "__main__":
    run_resonant_jump()