@falseywinchnet
Created November 24, 2025 07:07
"""
THE CONTEXT-PULSE MANIFOLD:
Deriving an Inherently Autoregressive Attention Mechanism
A Gemini Collaborative Development
--------------------------------------------------------------------------------
1. THE WHY (Intuition & Motivation)
--------------------------------------------------------------------------------
falseywinchnet approached this work with a fundamental dissatisfaction regarding Standard Attention:
it relies on computing an "All-to-All" energy matrix (Riemannian metric) only to
destroy half of it with a "crude" boolean mask to enforce causality.
This "crude" masking does more than just hide the future; it fundamentally
destabilizes the energy landscape of the attention head.
Consider the Partition Function (Z), the denominator of the Softmax:
- In an unmasked (All-to-All) setting, Z sums over N items. The energy
distribution is consistent across all positions.
- Under causal masking, the domain of Z varies wildly. At t=1, Z sums over
1 item. At t=1000, Z sums over 1000.
This violates the Parseval intuition of energy conservation. The model is
forced to compare logits against a shifting baseline variance. When the
mask arbitrarily removes high-similarity future tokens, or when the
limited past offers no "resonance," the Softmax function still demands
that probability mass sum to exactly 1.0.
Where does this "orphaned" energy go? It creates the "Attention Sink"
phenomenon. The model learns to dump excess probability mass onto the
start-of-sequence (BOS) token—not because it is semantically relevant,
but because it acts as a numerical capacitor to absorb the variance
caused by the mask's truncation of the manifold. We are effectively
forcing the model to waste capacity learning garbage collection
mechanisms instead of pure semantic flow.
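A small sketch of the complaint (illustrative only; assume torch/math are imported
and per-position queries/keys Q, K of shape (T, d)). Under the causal mask, the
partition function at position t sums over only t surviving terms:

    scores = (Q @ K.T) / math.sqrt(d)                                     # (T, T) logits
    scores = scores.masked_fill(torch.triu(torch.ones(T, T), 1).bool(), float('-inf'))
    Z = scores.exp().sum(dim=-1)   # Z[0] sums 1 term, Z[T-1] sums T terms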
The intuition: this feels like a "cookie crumble" approach. Studying attention,
we learned that it carries Parseval-adjacent properties and related mathematical
structure; our goal was to understand the manifold those properties define and
to adjust it intelligently.
Can we instead design a mathematical manifold that *inherently* respects the
arrow of time? We looked toward Parseval's Theorem and Unitary operators,
seeking a transform where the "Products" or "Weights" naturally decay or
orthogonalize in the future, rendering the mask a formality rather than a
structural guillotine.
--------------------------------------------------------------------------------
2. THE WHAT (The Tangential Exploration)
--------------------------------------------------------------------------------
We explored several linear algebraic constraints to force this behavior:
A. The Pseudoinverse Constraint:
- Idea: Set K = pinv(Q).T to force A = Identity or Projection.
- Result: Too rigid. Collapsed the manifold to undirected graphs (Symmetric).
Killed the ability to model asymmetric relationships (A->B != B->A).
Ironically, it still works in practice, but only because the mask is doing the causal work.
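(An illustrative sketch of that constraint, with Q of shape (T, d):
     K = torch.linalg.pinv(Q).T            # so K.T = pinv(Q)
     A = Q @ K.T                           # = Q @ pinv(Q), a symmetric projector
 hence the undirected / symmetric collapse described above.)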
--------------------------------------------------------------------------------
3. INSIGHTS (The Pivot to Calculus)
--------------------------------------------------------------------------------
The breakthrough came when we moved from Linear Algebra (Static) to Calculus (Dynamic).
Standard Attention treats tokens as points in space (State vs. State).
To encode "Time," we need to treat tokens as a signal flow.
- Differentiation (Innovation): The "Change" at step t.
- Integration (Accumulation): The "History" up to step t.
Hypothesis: Causality is simply the alignment of an Accumulation with its own
constituents. Future accumulations do not contain current innovations.
--------------------------------------------------------------------------------
4. AHA MOMENTS (The "Integral-Differential" Manifold)
--------------------------------------------------------------------------------
We constructed a manifold where:
- Q (Query) represents one aspect of the signal (Derivative/State).
- K (Key) represents the other (Integral/History).
Initial Attempt: Q = Innovation, K = Accumulation.
- Logic: "Does this new change align with the history?"
- Result: Visually interesting, but mathematically flawed.
--------------------------------------------------------------------------------
5. "OH, THAT'S UNEXPECTED" (The Washer Effect)
--------------------------------------------------------------------------------
When we visualized Q=Innovation vs K=Accumulation, we saw:
1. Horizontal Stripes: The "Accumulated Key" became a massive mean vector,
washing out local details.
2. Ratio > 1.0: The Future was "louder" than the Past.
- Why? Random walks grow with sqrt(t). Future keys were physically longer
vectors. The dot product is magnitude-sensitive.
- The manifold was naturally "Acausal" (preferring the future).
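(A quick numeric check of the loudness effect, illustrative only:
     x = torch.randn(10000, 64)
     norms = x.cumsum(dim=0).norm(dim=-1)   # grows roughly like sqrt(t)
 Later accumulations are physically longer vectors, so raw dot products favour them.)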
--------------------------------------------------------------------------------
6. INSIGHT (The Inversion)
--------------------------------------------------------------------------------
To fix the magnitude/loudness issue, we needed Leaky Integration (decay).
To fix the causality, we inverted the roles.
New Definition:
- Q (Context): The Leaky Integrator. "I am the sum of my history."
- K (Pulse): The Raw Innovation. "I am a specific event."
Why this works:
- If j < t (Past): Q_t physically *contains* vector K_j inside its sum.
Dot Product = ||K_j||^2 + Noise. (Resonance).
- If j > t (Future): Q_t does *not* contain K_j.
Dot Product = Random Noise. (Orthogonality).
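(A toy check of this geometry, illustrative only, with pulses K of shape (T, d):
     K = torch.randn(T, d)
     Q = torch.cumsum(K, dim=0)             # perfect (non-leaky) context
     A = Q @ K.T                            # A[t, j] = sum_{i<=t} K[i] . K[j]
 For j <= t, A[t, j] contains the ||K[j]||^2 term (resonance); for j > t it is a
 sum of independent dot products with zero mean (noise).)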
--------------------------------------------------------------------------------
7. AHA (The Resulting Manifold)
--------------------------------------------------------------------------------
Visualizing this "Context-Pulse" manifold showed the desired property:
- The Upper Triangle (Future) wasn't just masked; it was *uncorrelated noise*.
- The Lower Triangle (Past) was *signal*.
- The "Mask" is no longer destroying information; it is simply filtering noise.
--------------------------------------------------------------------------------
8. RESULT (The Manifold Battle)
--------------------------------------------------------------------------------
We pitted this "Context-Pulse Attention" against Standard Attention on an
Associative Recall task (A...B...A -> Predict B).
- Standard Attention: Flatlined for 4000 steps. It struggled to fight the
entropy of the softmax to find the "needle" in the past.
- Context-Pulse Attention: Immediate convergence.
Because Q *contains* the answer K geometrically, the gradient flow is
instantaneous. The inductive bias aligns perfectly with the task.
Final Status:
We have a theoretical winner that beats Standard Attention on recall speed and strength
by embedding the arrow of time directly into the vector geometry.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
# --- 1. The Dataset: Associative Recall ---
def get_batch(batch_size=32, seq_len=64, vocab_size=128, device='cpu'):
    """
    Generates data where tokens repeat.
    Task: Given '... A ...', predict what followed 'A' last time.
    """
    data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # Force repeat: Copy the first half to the second half (roughly) to ensure recall is possible
    half = seq_len // 2
    data[:, half:] = data[:, :half]
    # Shift targets by 1 for next-token prediction
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets

# --- 2. The Baseline: Standard Attention ---
class StandardAttention(nn.Module):
    def __init__(self, d_model, head_dim):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5

    def forward(self, x, mask=True):
        B, T, C = x.shape
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Standard Dot Product
        # Shape: (B, T, T)
        Attn = (Q @ K.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V

# --- 3. The Challenger: Inverted Manifold (Context-Pulse) ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, head_dim, decay=0.9):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5
        # We make decay a buffer or learnable param
        # For this test, we fix it to verify the structure works
        self.decay = decay

    def forward(self, x, mask=True):
        B, T, C = x.shape
        # 1. Project
        q_raw = self.W_Q(x)  # These are essentially inputs/innovations
        k_raw = self.W_K(x)  # These are inputs/innovations
        V = self.W_V(x)
        # 2. TRANSFORM THE MANIFOLD
        # Q becomes Context (Leaky Integration of inputs)
        # We apply the leaky scan on q_raw
        # Manual Leaky Scan (Not efficient for training, but correct for logic)
        # In production, use pscan or associative scan
        Q_context = torch.zeros_like(q_raw)
        running_q = torch.zeros(B, q_raw.shape[-1], device=x.device)
        for t in range(T):
            # Q[t] = alpha * Q[t-1] + (1-alpha) * input[t]
            running_q = self.decay * running_q + (1.0 - self.decay) * q_raw[:, t, :]
            Q_context[:, t, :] = running_q
        # K remains Pulse (Innovation)
        # We don't touch K. It represents the "Event".
        # 3. Compute Attention
        # "Does my Context contain this Pulse?"
        # Note: We still scale and mask, but the masking destroys less info now.
        Attn = (Q_context @ k_raw.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V
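
# Added sketch (not part of the original gist): the leaky scan above is a strictly
# sequential Python loop. For these toy lengths the same Q_context can be formed in
# a single matmul with a lower-triangular decay-weight matrix (O(T^2) memory), which
# is often fast enough before reaching for pscan / associative-scan kernels.
def leaky_scan_matrix(x, decay):
    """x: (B, T, D). Returns Q with Q[t] = sum_{j<=t} decay^(t-j) * (1 - decay) * x[j]."""
    B, T, D = x.shape
    idx = torch.arange(T, device=x.device)
    delta = idx.view(-1, 1) - idx.view(1, -1)              # (T, T), value t - j
    weights = (1.0 - decay) * decay ** delta.clamp(min=0).float()
    weights = weights * (delta >= 0)                       # zero out future positions (j > t)
    return weights @ x                                     # (T, T) @ (B, T, D) -> (B, T, D)
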
# --- 4. The Test Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, attn_type='standard'):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        if attn_type == 'standard':
            self.attn = StandardAttention(d_model, d_model)
        else:
            self.attn = ContextPulseAttention(d_model, d_model, decay=0.9)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        logits = self.fc(h)
        return logits

def train_and_compare():
    print("--- Starting Manifold Battle ---")
    # Params
    VOCAB = 64
    DIM = 32
    STEPS = 4000
    LR = 3e-3
    # Models
    model_std = ToyTransformer(VOCAB, DIM, 'standard')
    model_new = ToyTransformer(VOCAB, DIM, 'context_pulse')
    opt_std = optim.AdamW(model_std.parameters(), lr=LR)
    opt_new = optim.AdamW(model_new.parameters(), lr=LR)
    losses_std = []
    losses_new = []
    for i in range(STEPS):
        x, y = get_batch(seq_len=32, vocab_size=VOCAB)
        # Train Standard
        opt_std.zero_grad()
        logits = model_std(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_std.step()
        losses_std.append(loss.item())
        # Train New
        opt_new.zero_grad()
        logits = model_new(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_new.step()
        losses_new.append(loss.item())
        if i % 50 == 0:
            print(f"Step {i}: Std Loss {losses_std[-1]:.3f} | New Loss {losses_new[-1]:.3f}")
    # Plot
    plt.figure(figsize=(10, 5))
    plt.plot(losses_std, label='Standard Attention', alpha=0.7)
    plt.plot(losses_new, label='Context-Pulse Attention', linewidth=2)
    plt.title("Learning Speed: Isotropic vs Causal Manifold")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()


train_and_compare()

@falseywinchnet (Author)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import matplotlib.pyplot as plt

# --- 1. Locked Primitives ---
class RoPE(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())
    def forward(self, x):
        seq_len = x.shape[2]
        cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)

class WedgeTransform(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.n_head = heads
        self.head_dim = dim // heads 
        self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
    def forward(self, x):
        # x: (B, H, T, Dh). S = A - A^T is skew-symmetric; it broadcasts over the
        # batch and time dims, so the transform is x -> x + x @ S per head.
        S = self.A - self.A.transpose(-1, -2)
        flow = x + torch.matmul(x, S)
        return flow
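
# Note (added, not in the original gist): because S = A - A^T is skew-symmetric,
# every row satisfies x_i . (x_i @ S) = 0, so the added "flow" is orthogonal to the
# state it perturbs: the wedge rotates direction without directly growing or
# shrinking the existing component. Quick check, assuming the class above:
#   w = WedgeTransform(dim=8, heads=2); torch.nn.init.normal_(w.A)
#   x = torch.randn(1, 2, 5, 4)
#   S = w.A - w.A.transpose(-1, -2)
#   (x * (x @ S)).sum(-1).abs().max()       # ~0 up to float error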

class ConvexSoftmax(nn.Module):
    def forward(self, scores):
        m, _ = scores.max(dim=-1, keepdim=True)
        y = scores - m
        ex = y.exp()
        lse = m + ex.sum(dim=-1, keepdim=True).log()
        return torch.exp(scores - lse)
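
# Note (added): ConvexSoftmax is softmax written through an explicit log-sum-exp
# (subtract max, exponentiate, take the lse); on finite inputs it reproduces
# F.softmax(scores, dim=-1). Writing the normaliser out explicitly makes it easier
# to reason about where probability mass goes once an extra learned "sink" column
# is appended to the scores before normalisation (see sink_scalar / v_null below).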

# --- 2. Multi-Head Causal Chunked Model ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, n_head, chunk_size=16):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.chunk_size = chunk_size
        self.scale = self.head_dim ** -0.5

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        
        nn.init.orthogonal_(self.W_Q.weight)
        nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
        nn.init.xavier_normal_(self.W_V.weight)

        self.gate_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)
        
        # FIX: Pass d_model, not head_dim, because Wedge divides internally
        self.wedge = WedgeTransform(d_model, n_head)
        
        self.softmax = ConvexSoftmax()
        
        self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
        self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim)) 

    def forward(self, x):
        B, T, C = x.shape
        H, Dh = self.n_head, self.head_dim
        
        q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
        k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
        v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
        gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)

        num_chunks = math.ceil(T / self.chunk_size)
        Q_output_list = []
        history_cache = [] 
        
        for i in range(num_chunks):
            t_start = i * self.chunk_size
            t_end = min((i + 1) * self.chunk_size, T)
            
            q_chunk = q_raw[:, :, t_start:t_end, :]
            k_chunk = k_raw[:, :, t_start:t_end, :]
            gate_chunk = gate[:, :, t_start:t_end, :]
            
            if len(history_cache) > 0:
                history_stack = torch.cat(history_cache, dim=2)
                resonance_scores = (k_chunk @ history_stack.transpose(-2, -1)) * self.scale
                resonance_weights = F.softmax(resonance_scores, dim=-1)
                q_retrieved = resonance_weights @ history_stack
                base_state = history_cache[-1].expand(-1, -1, k_chunk.size(2), -1)
                q_initial = base_state + q_retrieved
            else:
                q_initial = torch.zeros_like(k_chunk)

            masked_input = gate_chunk * q_chunk
            q_integrated = torch.cumsum(masked_input, dim=2)
            q_chunk_out = q_integrated + q_initial
            Q_output_list.append(q_chunk_out)
            final_state = q_chunk_out[:, :, -1:, :]
            history_cache.append(final_state)
            
        Q_context = torch.cat(Q_output_list, dim=2)
        K_pulse = k_raw
        
        # Pointwise Symbiotic Norm
        mean_q = Q_context.mean(dim=-1, keepdim=True)
        std_q = Q_context.std(dim=-1, keepdim=True)
        mean_k = K_pulse.mean(dim=-1, keepdim=True)
        std_k = K_pulse.std(dim=-1, keepdim=True)
        mean_sym = 0.5 * (mean_q + mean_k)
        std_sym = 0.5 * (std_q + std_k)
        Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
        K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
            
        # Geometry
        q_roped = self.rope(Q_context) # (B, H, T, D)
        k_roped = self.rope(K_pulse)
        
        # Wedge transform: operates directly on (B, H, T, Dh); the skew-symmetric
        # generator S broadcasts across batch and time.
        Q_geo = self.wedge(q_roped)
        K_geo = self.wedge(k_roped)
        
        Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        Attn.masked_fill_(mask, float('-inf'))
        
        null_scores = self.sink_scalar.expand(B, H, T, 1)
        Attn_full = torch.cat([Attn, null_scores], dim=-1)
        probs_full = self.softmax(Attn_full)
        
        out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_O(out)

# --- 3. Multi-Head Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_head):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
        self.fc = nn.Linear(d_model, vocab_size)
    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        return self.fc(h)

def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
    START_TOKEN = 0
    STOP_TOKEN = 1
    data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
    pattern_len = 16
    for b in range(batch_size):
        pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
        data[b, 0] = START_TOKEN
        data[b, 1:1+pattern_len] = pattern
        data[b, 1+pattern_len] = STOP_TOKEN
        trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
        data[b, trigger_start] = START_TOKEN
        data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets
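
# Example layout of one generated row (pattern values hypothetical):
#   [START, p_1 .. p_16, STOP, noise ..., START, p_1 .. p_16, noise ...]
# The only consistently predictable span is the 16 tokens after the second START
# (a copy across a long gap); the rest of the next-token loss is essentially
# irreducible noise, so the floor of the curve reflects copy accuracy.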

def run_multihead_hard():
    print("--- FIXED MULTI-HEAD GAP COPY: 4 Heads, 64 Dim ---")
    VOCAB = 64
    DIM = 64
    HEADS = 4
    STEPS = 1500
    SEQ_LEN = 256
    LR = 3e-3
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
    opt = optim.AdamW(model.parameters(), lr=LR)
    
    loss_hist = []
    
    for i in range(STEPS):
        x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
        opt.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt.step()
        loss_hist.append(loss.item())
        
        if i % 100 == 0:
            print(f"Step {i}: Loss {loss.item():.4f}")

    plt.figure(figsize=(10, 6))
    plt.plot(loss_hist)
    plt.title("Multi-Head Gap Copy (Fixed)")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.grid(True, alpha=0.3)
    plt.show()

if __name__ == "__main__":
    run_multihead_hard()

@falseywinchnet (Author)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math

# --- 1. Primitives (Locked) ---

class RoPE(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def forward(self, x):
        seq_len = x.shape[2]
        cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)

class WedgeTransform(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.n_head = heads
        self.head_dim = dim // heads
        self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))

    def forward(self, x):
        S = self.A - self.A.transpose(-1, -2)
        S_broadcast = S.unsqueeze(0)
        flow = torch.matmul(x, S_broadcast)
        return x + flow

class ConvexSoftmax(nn.Module):
    def forward(self, scores):
        m, _ = scores.max(dim=-1, keepdim=True)
        y = scores - m
        ex = y.exp()
        lse = m + ex.sum(dim=-1, keepdim=True).log()
        return torch.exp(scores - lse)

def build_alpert_basis(block_size, poly_order=1, device='cpu'):
    t = torch.linspace(-1, 1, block_size, device=device)
    p0 = torch.ones_like(t)
    basis_list = [p0]
    if poly_order >= 1: basis_list.append(t)
    if poly_order >= 2: basis_list.append(3 * t**2 - 1)
    W = torch.stack(basis_list, dim=1)
    W = F.normalize(W, p=2, dim=0)
    return W
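
# Added sketch (not part of the original gist): how the basis is consumed below.
# Each chunk of length L is compressed into (poly_order + 1) "moments" per feature
# and can be expanded back into a low-order trend over the chunk:
#   W = build_alpert_basis(block_size=16, poly_order=1)       # (16, 2)
#   k_chunk = torch.randn(2, 4, 16, 8)                        # (B, H, L, D)
#   moments = torch.einsum('bhld, lk -> bhkd', k_chunk, W)    # (B, H, 2, D)
#   recon   = torch.einsum('bhkd, lk -> bhld', moments, W)    # (B, H, 16, D) trend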

# --- 2. Resonance-Gated Alpert Model ---

class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, n_head, chunk_size=16, poly_order=1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.chunk_size = chunk_size
        self.poly_order = poly_order
        self.scale = self.head_dim ** -0.5

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        nn.init.orthogonal_(self.W_Q.weight)
        nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
        nn.init.xavier_normal_(self.W_V.weight)

        self.gate_proj = nn.Linear(d_model, d_model)

        # REMOVED: jump_proj (Blind Jump)
        # ADDED: Resonance Sensitivity (Smart Jump)
        self.jump_scale = nn.Parameter(torch.tensor(1.0))
        self.jump_bias = nn.Parameter(torch.tensor(0.0))  # Start neutral/conservative

        self.rope = RoPE(self.head_dim)
        self.wedge = WedgeTransform(d_model, n_head)
        self.softmax = ConvexSoftmax()

        self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
        self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))

        basis = build_alpert_basis(chunk_size, poly_order)
        self.register_buffer('alpert_basis', basis)

    def forward(self, x):
        B, T, C = x.shape
        H, Dh = self.n_head, self.head_dim

        q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
        k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
        v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)

        gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)

        num_chunks = math.ceil(T / self.chunk_size)
        Q_output_list = []
        history_coeffs_cache = []
        last_chunk = torch.zeros_like(q_raw[:, :, 0:self.chunk_size, :])

        W_basis = self.alpert_basis
        for i in range(num_chunks):
            t_start = i * self.chunk_size
            t_end = min((i + 1) * self.chunk_size, T)

            q_chunk = q_raw[:, :, t_start:t_end, :]
            k_chunk = k_raw[:, :, t_start:t_end, :]
            gate_chunk = gate[:, :, t_start:t_end, :]

            current_L = q_chunk.size(2)
            if current_L < self.chunk_size:
                W_curr = build_alpert_basis(current_L, self.poly_order, device=x.device)
            else:
                W_curr = W_basis

            # 1. Encode Current K (Moments)
            k_coeffs = torch.einsum('bhld, lk -> bhkd', k_chunk, W_curr)

            # 2. Wave Retrieval
            if len(history_coeffs_cache) > 0:
                history_stack = torch.cat(history_coeffs_cache, dim=2)

                # Resonance (Score)
                resonance_scores = torch.einsum('bhkd, bhnkd -> bhn', k_coeffs, history_stack) * self.scale
                resonance_scores = resonance_scores.unsqueeze(-2)  # (B, H, 1, N)

                # Weights
                resonance_weights = F.softmax(resonance_scores, dim=-1)

                # Retrieve Moments
                coeffs_retrieved = torch.einsum('bhn, bhnkd -> bhkd', resonance_weights.squeeze(2), history_stack)

                # Reconstruct Trajectory
                q_reconstructed = torch.einsum('bhkd, lk -> bhld', coeffs_retrieved, W_curr)

                # Additive Injection
                if 'last_state' in locals():
                    continuity_stream = last_state.expand(-1, -1, current_L, -1)
                    q_injected = continuity_stream + q_reconstructed
                else:
                    q_injected = q_reconstructed
            else:
                q_injected = torch.zeros_like(q_chunk)

            # 3. Integration
            q_integrated = torch.cumsum(q_chunk, dim=2)
            q_chunk_out = q_integrated + q_injected - last_chunk

            Q_output_list.append(q_chunk_out)

            # 4. Update Cache
            q_out_coeffs = torch.einsum('bhld, lk -> bhkd', q_chunk_out, W_curr)
            history_coeffs_cache.append(q_out_coeffs.unsqueeze(2))
            last_state = q_chunk_out[:, :, -1:, :]
            last_chunk = q_injected[:, :, -1:, :]

        Q_context = torch.cat(Q_output_list, dim=2)
        K_pulse = k_raw

        # Pointwise Norm
        mean_q = Q_context.mean(dim=-1, keepdim=True)
        std_q = Q_context.std(dim=-1, keepdim=True)
        mean_k = K_pulse.mean(dim=-1, keepdim=True)
        std_k = K_pulse.std(dim=-1, keepdim=True)
        mean_sym = 0.5 * (mean_q + mean_k)
        std_sym = 0.5 * (std_q + std_k)
        Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
        K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)

        # Geometry
        Q_geo = self.wedge(Q_context)
        K_geo = self.wedge(K_pulse)

        Q_geo = self.rope(Q_geo)
        K_geo = self.rope(K_geo)

        Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        Attn.masked_fill_(mask, float('-inf'))

        null_scores = self.sink_scalar.expand(B, H, T, 1)
        Attn_full = torch.cat([Attn, null_scores], dim=-1)
        probs_full = self.softmax(Attn_full)

        out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_O(out)

# --- Harness ---

class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_head):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        return self.fc(h)

def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
    START_TOKEN = 0
    STOP_TOKEN = 1
    data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
    pattern_len = 16
    for b in range(batch_size):
        pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
        data[b, 0] = START_TOKEN
        data[b, 1:1+pattern_len] = pattern
        data[b, 1+pattern_len] = STOP_TOKEN
        trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
        data[b, trigger_start] = START_TOKEN
        data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets

def run_resonant_jump():
    print("--- ALPERT + RESONANCE-DRIVEN JUMP: 4 Heads, 128 Dim ---")
    VOCAB = 64
    DIM = 128
    HEADS = 4
    STEPS = 1500
    SEQ_LEN = 256
    LR = 3e-3
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
    opt = optim.AdamW(model.parameters(), lr=LR)

    loss_hist = []

    for i in range(STEPS):
        x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
        opt.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt.step()
        loss_hist.append(loss.item())

        if i % 100 == 0:
            print(f"Step {i}: Loss {loss.item():.4f}")

    plt.figure(figsize=(10, 6))
    plt.plot(loss_hist)
    plt.title("Resonance-Gated Alpert Retrieval")
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.grid(True, alpha=0.3)
    plt.show()

if __name__ == "__main__":
    run_resonant_jump()
