| """ | |
| THE CONTEXT-PULSE MANIFOLD: | |
| Deriving an Inherently Autoregressive Attention Mechanism | |
| A Gemini Collaborative Development | |
| -------------------------------------------------------------------------------- | |
| 1. THE WHY (Intuition & Motivation) | |
| -------------------------------------------------------------------------------- | |
| falseywinchnet approached with a fundamental dissatisfaction regarding Standard Attention: | |
| it relies on computing an "All-to-All" energy matrix (Riemannian metric) only to | |
| destroy half of it with a "crude" boolean mask to enforce causality. | |
| This "crude" masking does more than just hide the future; it fundamentally | |
| destabilizes the energy landscape of the attention head. | |
| Consider the Partition Function (Z), the denominator of the Softmax: | |
| - In an unmasked (All-to-All) setting, Z sums over N items. The energy | |
| distribution is consistent across all positions. | |
| - Under causal masking, the domain of Z varies wildly. At t=1, Z sums over | |
| 1 item. At t=1000, Z sums over 1000. | |
| This violates the Parseval intuition of energy conservation. The model is | |
| forced to compare logits against a shifting baseline variance. When the | |
| mask arbitrarily removes high-similarity future tokens, or when the | |
| limited past offers no "resonance," the Softmax function still demands | |
| that probability mass sum to exactly 1.0. | |
| Where does this "orphaned" energy go? It creates the "Attention Sink" | |
| phenomenon. The model learns to dump excess probability mass onto the | |
| start-of-sequence (BOS) token—not because it is semantically relevant, | |
| but because it acts as a numerical capacitor to absorb the variance | |
| caused by the mask's truncation of the manifold. We are effectively | |
| forcing the model to waste capacity learning garbage collection | |
| mechanisms instead of pure semantic flow. | |
| The intuition: This feels like a "cookie crumble" approach. | |
| Studying attention, we learned it contains Parseval-adjacent properties and mathematical aspects. | |
| Understanding attention's manifold properties and adjusting them intelligently was our goal. | |
| Can we instead design a mathematical manifold that *inherently* respects the | |
| arrow of time? We looked toward Parseval's Theorem and Unitary operators, | |
| seeking a transform where the "Products" or "Weights" naturally decay or | |
| orthogonalize in the future, rendering the mask a formality rather than a | |
| structural guillotine. | |
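
A quick way to see the shifting baseline (a minimal sketch, not part of the
original experiments below):

    import torch
    T = 1000
    scores = torch.randn(T, T)                                 # raw Q.K^T logits
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
    Z = scores.masked_fill(mask, float('-inf')).exp().sum(-1)  # per-row partition function
    print(Z[0].item(), Z[-1].item())  # row 0 normalizes 1 logit, row 999 normalizes 1000

Identical logits therefore land on very different probability scales depending
on position, which is the variance the BOS "sink" ends up absorbing.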
--------------------------------------------------------------------------------
2. THE WHAT (The Tangential Exploration)
--------------------------------------------------------------------------------
We explored several linear algebraic constraints to force this behavior:
A. The Pseudoinverse Constraint:
   - Idea: Set K = pinv(Q).T to force A = Identity or Projection.
   - Result: Too rigid. It collapsed the manifold to undirected graphs (symmetric)
     and killed the ability to model asymmetric relationships (A->B != B->A).
     Ironically, it still works, but only because the mask is doing the causal work.
--------------------------------------------------------------------------------
3. INSIGHTS (The Pivot to Calculus)
--------------------------------------------------------------------------------
The breakthrough came when we moved from Linear Algebra (Static) to Calculus (Dynamic).
Standard Attention treats tokens as points in space (State vs. State).
To encode "Time," we need to treat tokens as a signal flow:
  - Differentiation (Innovation): the "Change" at step t.
  - Integration (Accumulation): the "History" up to step t.
Hypothesis: Causality is simply the alignment of an Accumulation with its own
constituents. Future accumulations do not contain current innovations.
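
In signal terms (a minimal sketch of the two operators, assuming a stream of
token vectors x[1..t]; these helpers are illustrative, not from the code below):

    innovation[t]   = x[t] - x[t-1]              # differentiation: the change at step t
    accumulation[t] = x[1] + x[2] + ... + x[t]   # integration: the history up to step t

An accumulation at step t is built from exactly the innovations at or before t,
which is the asymmetry we want the geometry itself to expose.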
--------------------------------------------------------------------------------
4. AHA MOMENTS (The "Integral-Differential" Manifold)
--------------------------------------------------------------------------------
We constructed a manifold where:
  - Q (Query) represents one aspect of the signal (Derivative/State).
  - K (Key) represents the other (Integral/History).
Initial Attempt: Q = Innovation, K = Accumulation.
  - Logic: "Does this new change align with the history?"
  - Result: Visually interesting, but mathematically flawed.
--------------------------------------------------------------------------------
5. "OH, THAT'S UNEXPECTED" (The Washer Effect)
--------------------------------------------------------------------------------
When we visualized Q=Innovation vs K=Accumulation, we saw:
  1. Horizontal Stripes: The "Accumulated Key" became a massive mean vector,
     washing out local details.
  2. Ratio > 1.0: The Future was "louder" than the Past.
     - Why? Random walks grow with sqrt(t). Future keys were physically longer
       vectors. The dot product is magnitude-sensitive.
     - The manifold was naturally "Acausal" (preferring the future).
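
The loudness effect is easy to reproduce (a minimal sketch, independent of the
experiments below): the running sum of i.i.d. random vectors grows like sqrt(t),
so later accumulations dominate any magnitude-sensitive dot product.

    import torch
    x = torch.randn(1000, 64)          # i.i.d. innovations
    acc = torch.cumsum(x, dim=0)       # un-leaky accumulation
    print(acc[9].norm().item(), acc[999].norm().item())
    # the second norm is roughly sqrt(100) = 10x the first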
--------------------------------------------------------------------------------
6. INSIGHT (The Inversion)
--------------------------------------------------------------------------------
To fix the magnitude/loudness issue, we needed Leaky Integration (decay).
To fix the causality, we inverted the roles.
New Definition:
  - Q (Context): The Leaky Integrator. "I am the sum of my history."
  - K (Pulse): The Raw Innovation. "I am a specific event."
Why this works:
  - If j < t (Past): Q_t physically *contains* vector K_j inside its sum.
    Dot Product = ||K_j||^2 + Noise. (Resonance).
  - If j > t (Future): Q_t does *not* contain K_j.
    Dot Product = Random Noise. (Orthogonality).
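
Making the resonance claim precise (our back-of-envelope sketch, assuming
zero-mean, roughly independent pulses K_j, and idealizing, as the prose above
does, that the integrator runs over the same pulses that appear as keys;
alpha is the `decay` parameter in the code below):

    Q_t = (1 - alpha) * sum_{j <= t} alpha^(t - j) * K_j

    Q_t . K_m = (1 - alpha) * alpha^(t - m) * ||K_m||^2 + cross terms   if m <= t  (resonance)
    Q_t . K_m = cross terms only                                        if m >  t  (orthogonality)

The cross terms have zero mean, so in expectation the future is orthogonal
while the past resonates with a decayed copy of its own energy.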
--------------------------------------------------------------------------------
7. AHA (The Resulting Manifold)
--------------------------------------------------------------------------------
Visualizing this "Context-Pulse" manifold showed the desired property:
  - The Upper Triangle (Future) wasn't just masked; it was *uncorrelated noise*.
  - The Lower Triangle (Past) was *signal*.
  - The "Mask" is no longer destroying information; it is simply filtering noise.
--------------------------------------------------------------------------------
8. RESULT (The Manifold Battle)
--------------------------------------------------------------------------------
We pitted this "Context-Pulse Attention" against Standard Attention on an
Associative Recall task (A...B...A -> Predict B).
  - Standard Attention: Flatlined for 4000 steps. It struggled to fight the
    entropy of the softmax to find the "needle" in the past.
  - Context-Pulse Attention: Immediate convergence.
    Because Q *contains* the answer K geometrically, the gradient flow is
    instantaneous. The inductive bias aligns perfectly with the task.
Final Status:
We have a theoretical winner that beats Standard Attention on recall speed and strength
by embedding the arrow of time directly into the vector geometry.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np


# --- 1. The Dataset: Associative Recall ---
def get_batch(batch_size=32, seq_len=64, vocab_size=128, device='cpu'):
    """
    Generates data where tokens repeat.
    Task: Given '... A ...', predict what followed 'A' last time.
    """
    data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # Force repeat: Copy the first half to the second half (roughly) to ensure recall is possible
    half = seq_len // 2
    data[:, half:] = data[:, :half]
    # Shift targets by 1 for next-token prediction
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets

# --- 2. The Baseline: Standard Attention ---
class StandardAttention(nn.Module):
    def __init__(self, d_model, head_dim):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5

    def forward(self, x, mask=True):
        B, T, C = x.shape
        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)
        # Standard Dot Product
        # Shape: (B, T, T)
        Attn = (Q @ K.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V

# --- 3. The Challenger: Inverted Manifold (Context-Pulse) ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, head_dim, decay=0.9):
        super().__init__()
        self.W_Q = nn.Linear(d_model, head_dim, bias=False)
        self.W_K = nn.Linear(d_model, head_dim, bias=False)
        self.W_V = nn.Linear(d_model, head_dim, bias=False)
        self.scale = head_dim ** -0.5
        # We make decay a buffer or learnable param
        # For this test, we fix it to verify the structure works
        self.decay = decay

    def forward(self, x, mask=True):
        B, T, C = x.shape
        # 1. Project
        q_raw = self.W_Q(x)  # These are essentially inputs/innovations
        k_raw = self.W_K(x)  # These are inputs/innovations
        V = self.W_V(x)
        # 2. TRANSFORM THE MANIFOLD
        # Q becomes Context (Leaky Integration of inputs)
        # We apply the leaky scan on q_raw
        # Manual Leaky Scan (Not efficient for training, but correct for logic)
        # In production, use pscan or associative scan
        Q_context = torch.zeros_like(q_raw)
        running_q = torch.zeros(B, q_raw.shape[-1], device=x.device)
        for t in range(T):
            # Q[t] = alpha * Q[t-1] + (1-alpha) * input[t]
            running_q = self.decay * running_q + (1.0 - self.decay) * q_raw[:, t, :]
            Q_context[:, t, :] = running_q
        # K remains Pulse (Innovation)
        # We don't touch K. It represents the "Event".
        # 3. Compute Attention
        # "Does my Context contain this Pulse?"
        # Note: We still scale and mask, but the masking destroys less info now.
        Attn = (Q_context @ k_raw.transpose(-2, -1)) * self.scale
        if mask:
            m = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            Attn.masked_fill_(m, float('-inf'))
        return F.softmax(Attn, dim=-1) @ V
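
# Not part of the original gist: a minimal sketch showing the same leaky
# integration as a single matmul, avoiding the Python loop above. It
# materializes a (T, T) decay matrix, so it only suits short sequences; a
# production version would use an associative/parallel scan as noted above.
def leaky_scan_matmul(x, decay):
    """x: (B, T, D). Returns Q with Q[t] = decay * Q[t-1] + (1 - decay) * x[t]."""
    T = x.shape[1]
    idx = torch.arange(T, device=x.device)
    # L[t, j] = (1 - decay) * decay**(t - j) for j <= t, 0 otherwise
    L = (1.0 - decay) * decay ** (idx[:, None] - idx[None, :]).clamp(min=0).float()
    L = torch.tril(L)
    return L @ x  # matches the loop in ContextPulseAttention up to float error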

# --- 4. The Test Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, attn_type='standard'):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        if attn_type == 'standard':
            self.attn = StandardAttention(d_model, d_model)
        else:
            self.attn = ContextPulseAttention(d_model, d_model, decay=0.9)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        h = self.attn(h)
        logits = self.fc(h)
        return logits

def train_and_compare():
    print("--- Starting Manifold Battle ---")
    # Params
    VOCAB = 64
    DIM = 32
    STEPS = 4000
    LR = 3e-3
    # Models
    model_std = ToyTransformer(VOCAB, DIM, 'standard')
    model_new = ToyTransformer(VOCAB, DIM, 'context_pulse')
    opt_std = optim.AdamW(model_std.parameters(), lr=LR)
    opt_new = optim.AdamW(model_new.parameters(), lr=LR)
    losses_std = []
    losses_new = []
    for i in range(STEPS):
        x, y = get_batch(seq_len=32, vocab_size=VOCAB)
        # Train Standard
        opt_std.zero_grad()
        logits = model_std(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_std.step()
        losses_std.append(loss.item())
        # Train New
        opt_new.zero_grad()
        logits = model_new(x)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
        loss.backward()
        opt_new.step()
        losses_new.append(loss.item())
        if i % 50 == 0:
            print(f"Step {i}: Std Loss {losses_std[-1]:.3f} | New Loss {losses_new[-1]:.3f}")
    # Plot
    plt.figure(figsize=(10, 5))
    plt.plot(losses_std, label='Standard Attention', alpha=0.7)
    plt.plot(losses_new, label='Context-Pulse Attention', linewidth=2)
    plt.title("Learning Speed: Isotropic vs Causal Manifold")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()


train_and_compare()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import matplotlib.pyplot as plt
# --- 1. Locked Primitives ---
class RoPE(nn.Module):
def __init__(self, dim, max_len=2048):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
def __init__(self, dim, heads):
super().__init__()
self.n_head = heads
self.head_dim = dim // heads
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
    def forward(self, x):
        # x: (B, H, T, head_dim). S is skew-symmetric, so x + x @ S applies a
        # small rotational "flow" per head; S broadcasts over batch and time.
        S = self.A - self.A.transpose(-1, -2)
        return x + torch.matmul(x, S)
class ConvexSoftmax(nn.Module):
    # Numerically stable softmax computed via an explicit log-sum-exp.
    def forward(self, scores):
        m, _ = scores.max(dim=-1, keepdim=True)
        y = scores - m
        ex = y.exp()
        lse = m + ex.sum(dim=-1, keepdim=True).log()
        return torch.exp(scores - lse)
# --- 2. Multi-Head Causal Chunked Model ---
class ContextPulseAttention(nn.Module):
def __init__(self, d_model, n_head, chunk_size=16):
super().__init__()
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.chunk_size = chunk_size
self.scale = self.head_dim ** -0.5
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
nn.init.orthogonal_(self.W_Q.weight)
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
nn.init.xavier_normal_(self.W_V.weight)
self.gate_proj = nn.Linear(d_model, d_model)
self.rope = RoPE(self.head_dim)
# FIX: Pass d_model, not head_dim, because Wedge divides internally
self.wedge = WedgeTransform(d_model, n_head)
self.softmax = ConvexSoftmax()
self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
num_chunks = math.ceil(T / self.chunk_size)
Q_output_list = []
history_cache = []
for i in range(num_chunks):
t_start = i * self.chunk_size
t_end = min((i + 1) * self.chunk_size, T)
q_chunk = q_raw[:, :, t_start:t_end, :]
k_chunk = k_raw[:, :, t_start:t_end, :]
gate_chunk = gate[:, :, t_start:t_end, :]
if len(history_cache) > 0:
history_stack = torch.cat(history_cache, dim=2)
resonance_scores = (k_chunk @ history_stack.transpose(-2, -1)) * self.scale
resonance_weights = F.softmax(resonance_scores, dim=-1)
q_retrieved = resonance_weights @ history_stack
base_state = history_cache[-1].expand(-1, -1, k_chunk.size(2), -1)
q_initial = base_state + q_retrieved
else:
q_initial = torch.zeros_like(k_chunk)
masked_input = gate_chunk * q_chunk
q_integrated = torch.cumsum(masked_input, dim=2)
q_chunk_out = q_integrated + q_initial
Q_output_list.append(q_chunk_out)
final_state = q_chunk_out[:, :, -1:, :]
history_cache.append(final_state)
Q_context = torch.cat(Q_output_list, dim=2)
K_pulse = k_raw
# Pointwise Symbiotic Norm
mean_q = Q_context.mean(dim=-1, keepdim=True)
std_q = Q_context.std(dim=-1, keepdim=True)
mean_k = K_pulse.mean(dim=-1, keepdim=True)
std_k = K_pulse.std(dim=-1, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
        # Geometry
        q_roped = self.rope(Q_context)  # (B, H, T, head_dim)
        k_roped = self.rope(K_pulse)
        # The wedge transform operates directly on (B, H, T, head_dim); its
        # per-head skew matrix broadcasts over the batch and time dims.
        Q_geo = self.wedge(q_roped)
        K_geo = self.wedge(k_roped)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
null_scores = self.sink_scalar.expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = self.softmax(Attn_full)
out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_O(out)
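
# Not part of the original gist: a quick smoke test for the chunked module
# above, checking only that it runs and preserves the (B, T, d_model) shape.
# The sizes here are arbitrary.
if __name__ == "__main__":
    _attn = ContextPulseAttention(d_model=64, n_head=4, chunk_size=16)
    _x = torch.randn(2, 48, 64)
    assert _attn(_x).shape == (2, 48, 64)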
# --- 3. Multi-Head Harness ---
class ToyTransformer(nn.Module):
def __init__(self, vocab_size, d_model, n_head):
super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
START_TOKEN = 0
STOP_TOKEN = 1
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def run_multihead_hard():
print("--- FIXED MULTI-HEAD GAP COPY: 4 Heads, 64 Dim ---")
VOCAB = 64
DIM = 64
HEADS = 4
STEPS = 1500
SEQ_LEN = 256
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
for i in range(STEPS):
x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist)
plt.title("Multi-Head Gap Copy (Fixed)")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if __name__ == "__main__":
run_multihead_hard()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import math
# --- 1. Primitives (Locked) ---
class RoPE(nn.Module):
    def __init__(self, dim, max_len=2048):
        super().__init__()
self.dim = dim
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
self.register_buffer('cos', freqs.cos())
self.register_buffer('sin', freqs.sin())
def forward(self, x):
seq_len = x.shape[2]
cos = self.cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.cat((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
class WedgeTransform(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
self.n_head = heads
self.head_dim = dim // heads
self.A = nn.Parameter(torch.zeros(heads, self.head_dim, self.head_dim))
def forward(self, x):
S = self.A - self.A.transpose(-1, -2)
S_broadcast = S.unsqueeze(0)
flow = torch.matmul(x, S_broadcast)
return x + flow
class ConvexSoftmax(nn.Module):
def forward(self, scores):
m, _ = scores.max(dim=-1, keepdim=True)
y = scores - m
ex = y.exp()
lse = m + ex.sum(dim=-1, keepdim=True).log()
return torch.exp(scores - lse)
def build_alpert_basis(block_size, poly_order=1, device='cpu'):
    # Low-order polynomial (Legendre-style) basis sampled on the chunk:
    # constant, linear, and optionally quadratic moments, each column L2-normalized.
    # Returns W of shape (block_size, poly_order + 1).
    t = torch.linspace(-1, 1, block_size, device=device)
    p0 = torch.ones_like(t)
    basis_list = [p0]
    if poly_order >= 1: basis_list.append(t)
    if poly_order >= 2: basis_list.append(3 * t**2 - 1)
    W = torch.stack(basis_list, dim=1)
    W = F.normalize(W, p=2, dim=0)
    return W
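
# Not in the original gist: a minimal sketch of how the basis is used below.
# A chunk is projected onto its polynomial moments (analysis) and can be mapped
# back to a smoothed trajectory (synthesis); the names here are throwaway.
if __name__ == "__main__":
    _W = build_alpert_basis(16, poly_order=1)  # (block_size=16, K=2)
    _chunk = torch.randn(16, 8)                # (block_size, dim)
    _coeffs = _W.T @ _chunk                    # moments, (K, dim)
    _recon = _W @ _coeffs                      # low-order reconstruction, (block_size, dim)
    assert _recon.shape == _chunk.shape
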
# --- 2. Resonance-Gated Alpert Model ---
class ContextPulseAttention(nn.Module):
    def __init__(self, d_model, n_head, chunk_size=16, poly_order=1):
        super().__init__()
self.d_model = d_model
self.n_head = n_head
self.head_dim = d_model // n_head
self.chunk_size = chunk_size
self.poly_order = poly_order
self.scale = self.head_dim ** -0.5
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
nn.init.orthogonal_(self.W_Q.weight)
nn.init.normal_(self.W_K.weight, mean=0.0, std=0.02)
nn.init.xavier_normal_(self.W_V.weight)
self.gate_proj = nn.Linear(d_model, d_model)
# REMOVED: jump_proj (Blind Jump)
# ADDED: Resonance Sensitivity (Smart Jump)
self.jump_scale = nn.Parameter(torch.tensor(1.0))
self.jump_bias = nn.Parameter(torch.tensor(0.0)) # Start neutral/conservative
self.rope = RoPE(self.head_dim)
self.wedge = WedgeTransform(d_model, n_head)
self.softmax = ConvexSoftmax()
self.sink_scalar = nn.Parameter(torch.zeros(1, n_head, 1, 1))
self.v_null = nn.Parameter(torch.zeros(1, n_head, 1, self.head_dim))
basis = build_alpert_basis(chunk_size, poly_order)
self.register_buffer('alpert_basis', basis)
def forward(self, x):
B, T, C = x.shape
H, Dh = self.n_head, self.head_dim
q_raw = self.W_Q(x).view(B, T, H, Dh).transpose(1, 2)
k_raw = self.W_K(x).view(B, T, H, Dh).transpose(1, 2)
v_val = self.W_V(x).view(B, T, H, Dh).transpose(1, 2)
gate = torch.sigmoid(self.gate_proj(x)).view(B, T, H, Dh).transpose(1, 2)
num_chunks = math.ceil(T / self.chunk_size)
        Q_output_list = []
        history_coeffs_cache = []
        last_state = None  # final integrated state of the previous chunk, if any
        last_chunk = torch.zeros_like(q_raw[:, :, 0:self.chunk_size, :])
W_basis = self.alpert_basis
for i in range(num_chunks):
t_start = i * self.chunk_size
t_end = min((i + 1) * self.chunk_size, T)
q_chunk = q_raw[:, :, t_start:t_end, :]
k_chunk = k_raw[:, :, t_start:t_end, :]
gate_chunk = gate[:, :, t_start:t_end, :]
current_L = q_chunk.size(2)
if current_L < self.chunk_size:
W_curr = build_alpert_basis(current_L, self.poly_order, device=x.device)
else:
W_curr = W_basis
# 1. Encode Current K (Moments)
k_coeffs = torch.einsum('bhld, lk -> bhkd', k_chunk, W_curr)
# 2. Wave Retrieval
if len(history_coeffs_cache) > 0:
history_stack = torch.cat(history_coeffs_cache, dim=2)
# Resonance (Score)
resonance_scores = torch.einsum('bhkd, bhnkd -> bhn', k_coeffs, history_stack) * self.scale
resonance_scores = resonance_scores.unsqueeze(-2) # (B, H, 1, N)
# Weights
resonance_weights = F.softmax(resonance_scores, dim=-1)
# Retrieve Moments
coeffs_retrieved = torch.einsum('bhn, bhnkd -> bhkd', resonance_weights.squeeze(2), history_stack)
# Reconstruct Trajectory
q_reconstructed = torch.einsum('bhkd, lk -> bhld', coeffs_retrieved, W_curr)
# Additive Injection
                if last_state is not None:
continuity_stream = last_state.expand(-1, -1, current_L, -1)
q_injected = continuity_stream + q_reconstructed
else:
q_injected = q_reconstructed
else:
q_injected = torch.zeros_like(q_chunk)
# 3. Integration
q_integrated = torch.cumsum(q_chunk, dim=2)
q_chunk_out = q_integrated + q_injected - last_chunk
Q_output_list.append(q_chunk_out)
# 4. Update Cache
q_out_coeffs = torch.einsum('bhld, lk -> bhkd', q_chunk_out, W_curr)
history_coeffs_cache.append(q_out_coeffs.unsqueeze(2))
last_state = q_chunk_out[:, :, -1:, :]
last_chunk = q_injected[:, :, -1:, :]
Q_context = torch.cat(Q_output_list, dim=2)
K_pulse = k_raw
# Pointwise Norm
mean_q = Q_context.mean(dim=-1, keepdim=True)
std_q = Q_context.std(dim=-1, keepdim=True)
mean_k = K_pulse.mean(dim=-1, keepdim=True)
std_k = K_pulse.std(dim=-1, keepdim=True)
mean_sym = 0.5 * (mean_q + mean_k)
std_sym = 0.5 * (std_q + std_k)
Q_context = (Q_context - mean_sym) / (std_sym + 1e-6)
K_pulse = (K_pulse - mean_sym) / (std_sym + 1e-6)
# Geometry
Q_geo = self.wedge(Q_context)
K_geo = self.wedge(K_pulse)
Q_geo = self.rope(Q_geo)
K_geo = self.rope(K_geo)
Attn = (Q_geo @ K_geo.transpose(-2, -1)) * self.scale
mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
Attn.masked_fill_(mask, float('-inf'))
null_scores = self.sink_scalar.expand(B, H, T, 1)
Attn_full = torch.cat([Attn, null_scores], dim=-1)
probs_full = self.softmax(Attn_full)
out = probs_full[..., :T] @ v_val + probs_full[..., T:] * self.v_null
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.W_O(out)
# --- Harness ---
class ToyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_head):
        super().__init__()
self.embed = nn.Embedding(vocab_size, d_model)
self.attn = ContextPulseAttention(d_model, n_head=n_head, chunk_size=16)
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, x):
h = self.embed(x)
h = self.attn(h)
return self.fc(h)
def get_gap_copy_batch(batch_size=32, seq_len=256, vocab_size=64, device='cpu'):
START_TOKEN = 0
STOP_TOKEN = 1
data = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)
pattern_len = 16
for b in range(batch_size):
pattern = torch.randint(2, vocab_size, (pattern_len,), device=device)
data[b, 0] = START_TOKEN
data[b, 1:1+pattern_len] = pattern
data[b, 1+pattern_len] = STOP_TOKEN
trigger_start = torch.randint(seq_len // 2, seq_len - pattern_len - 1, (1,)).item()
data[b, trigger_start] = START_TOKEN
data[b, trigger_start+1 : trigger_start+1+pattern_len] = pattern
inputs = data[:, :-1]
targets = data[:, 1:]
return inputs, targets
def run_resonant_jump():
print("--- ALPERT + RESONANCE-DRIVEN JUMP: 4 Heads, 64 Dim ---")
VOCAB = 64
DIM = 128
HEADS = 4
STEPS = 1500
SEQ_LEN = 256
LR = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyTransformer(VOCAB, DIM, HEADS).to(DEVICE)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_hist = []
for i in range(STEPS):
x, y = get_gap_copy_batch(seq_len=SEQ_LEN, vocab_size=VOCAB, device=DEVICE)
opt.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), y.reshape(-1))
loss.backward()
opt.step()
loss_hist.append(loss.item())
if i % 100 == 0:
print(f"Step {i}: Loss {loss.item():.4f}")
plt.figure(figsize=(10, 6))
plt.plot(loss_hist)
plt.title("Resonance-Gated Alpert Retrieval")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)
plt.show()
if name == "main":
run_resonant_jump()