#copyright joshuah.rainstar@gmail.com 2025
#MIT with attribution
#Commentary: a number of recent "reasoning" models (TRM, HRM, URM and the like) claim
#reasoning merely because they reuse a set of weights and learn a state-space system,
#trading fewer parameters for more compute and bolting on extras like convolution.
#Attention is Bayesian coordinate transport to begin with.
#Anyway, here is what amounts to a little bit more of a reasoning module. Go nuts.
import math
import copy
from dataclasses import dataclass
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class RoPE(nn.Module):
    def __init__(self, dim, max_len=4096):
        super().__init__()
        assert dim % 2 == 0
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def forward(self, x):
        # x: (B, *, T, D)
        T = x.shape[-2]
        cos = self.cos[:T, :].unsqueeze(0).unsqueeze(0)  # [1,1,T,D/2]
        sin = self.sin[:T, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos
        return torch.stack((y1, y2), dim=-1).flatten(-2)

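
# ---------------------------------------------------------------------------
# Hedged sketch: a minimal sanity check for the RoPE module above, assuming the
# [B, *, T, D] layout used throughout this file. The helper name
# `_rope_sanity_check` is illustrative only, not part of the module's API.
# ---------------------------------------------------------------------------
def _rope_sanity_check(dim: int = 64, T: int = 16) -> None:
    rope = RoPE(dim)
    x = torch.randn(2, 3, T, dim)  # [B, BR, T, D], the shape I takes below
    y = rope(x)
    # A pure rotation preserves the norm of every token vector...
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)
    # ...and position 0 is rotated by angle 0, i.e. left unchanged.
    assert torch.allclose(x[..., 0, :], y[..., 0, :], atol=1e-6)
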
class Attention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_branch: int,
        palette_hw: int = 32,
        max_k: int = 12,
        rel_hid: int = 64,
        bias: bool = False,
        use_delta: bool = True,
    ):
        super().__init__()
        assert d_model % 2 == 0, "RoPE needs even dim"
        self.d = d_model
        self.BR = n_branch
        self.Ph = self.Pw = palette_hw
        self.Kmax = max_k
        self.use_delta = use_delta
        # Generate B full-width impressions per token: [d] -> [B*d]
        self.i_proj = nn.Linear(d_model, n_branch * d_model, bias=bias)
        # One shared potential per token: [d] -> [d]
        self.p_proj = nn.Linear(d_model, d_model, bias=bias)
        self.rope = RoPE(d_model)
        # Persistent palette: [d, Ph, Pw] shared across branches
        self.palette = nn.Parameter(torch.randn(d_model, self.Ph, self.Pw) * (d_model ** -0.5))
        # Relational feature MLP: per node i in constellation:
        #   features = [G_row (Kmax), anchor_dot (1), (optional) delta (1)]
        rel_in = self.Kmax + 1 + (1 if use_delta else 0)
        self.rel_mlp = nn.Sequential(
            nn.Linear(rel_in, rel_hid),
            nn.GELU(),
            nn.Linear(rel_hid, rel_hid),
            nn.GELU(),
        )
        self.coord_head = nn.Linear(rel_hid, 2)  # z in [-1,1]^2
        self.mix_head = nn.Linear(rel_hid, 1)    # mix logits over constellation
        # Branch-local output projections W_O^{(b)}: [BR, d, d]
        self.Wo = nn.Parameter(torch.randn(n_branch, d_model, d_model) * (d_model ** -0.5))

    def forward(self, x, attn_mask=None, strict_causal: bool = False, q_chunk: int = 64, k_chunk: int = 256, eps: float = 1e-9):
        """
        attn_mask: optional bool mask broadcastable to [B, 1, T, T] (True=allowed).
        strict_causal=False allows attending to self (recommended, since each query needs support of at least 1).
        """
        B, T, D = x.shape
        device = x.device
        inv_sqrt = 1.0 / math.sqrt(D)
        Kmax = self.Kmax
        # I: [B, BR, T, D]
        I = self.i_proj(x).view(B, T, self.BR, D).permute(0, 2, 1, 3).contiguous()
        # P: [B, T, D]
        P = self.p_proj(x)
        # RoPE
        I = self.rope(I)                          # [B,BR,T,D]
        P = self.rope(P.unsqueeze(1)).squeeze(1)  # [B,T,D]
        # Normalize mask to [B,1,T,T]
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask[None, None, :, :]
            elif attn_mask.dim() == 3:
                attn_mask = attn_mask[:, None, :, :]
        # Palette expanded for grid_sample: [B*BR, D, Ph, Pw]
        pal = self.palette.unsqueeze(0).expand(B * self.BR, -1, -1, -1)
        # -----------------------------
        # Pass A: streaming causal top-k
        # -----------------------------
        topk_idx = torch.zeros((B, self.BR, T, Kmax), device=device, dtype=torch.long)
        topk_val = x.new_full((B, self.BR, T, Kmax), float("-inf"))
        for t0 in range(0, T, q_chunk):
            t1 = min(T, t0 + q_chunk)
            tq = t1 - t0
            I_blk = I[:, :, t0:t1, :]  # [B,BR,tq,D]
            best_vals = x.new_full((B, self.BR, tq, Kmax), float("-inf"))
            best_idx = torch.zeros((B, self.BR, tq, Kmax), device=device, dtype=torch.long)
            for s0 in range(0, T, k_chunk):
                s1 = min(T, s0 + k_chunk)
                sk = s1 - s0
                P_blk = P[:, s0:s1, :]  # [B,sk,D]
                # logits: [B,BR,tq,sk]
                logits = torch.einsum("nrtd,nsd->nrts", I_blk, P_blk) * inv_sqrt
                qpos = torch.arange(t0, t1, device=device).view(1, 1, tq, 1)
                kpos = torch.arange(s0, s1, device=device).view(1, 1, 1, sk)
                allow = (kpos < qpos) if strict_causal else (kpos <= qpos)
                # With strict_causal, t=0 has no allowed keys, but every query needs
                # support of at least 1, so force-allow self at t=0:
                if strict_causal and t0 == 0:
                    allow = allow | ((qpos == 0) & (kpos == 0))
                if attn_mask is not None:
                    allow = allow & attn_mask[:, :, t0:t1, s0:s1]
                logits = logits.masked_fill(~allow, float("-inf"))
                cand_k = min(Kmax, sk)
                cand_vals, cand_pos = torch.topk(logits, k=cand_k, dim=-1)
                cand_idx = cand_pos + s0
                merged_vals = torch.cat([best_vals, cand_vals], dim=-1)
                merged_idx = torch.cat([best_idx, cand_idx], dim=-1)
                best_vals, sel = torch.topk(merged_vals, k=Kmax, dim=-1)
                best_idx = torch.gather(merged_idx, dim=-1, index=sel)
            topk_val[:, :, t0:t1, :] = best_vals
            topk_idx[:, :, t0:t1, :] = best_idx
        # -----------------------------
        # Pass B: relational manifold → palette → V^{(b)}
        # -----------------------------
        V = x.new_zeros((B, self.BR, T, D))
        Pn = F.normalize(P, dim=-1)  # [B,T,D]
        for t0 in range(0, T, q_chunk):
            t1 = min(T, t0 + q_chunk)
            tq = t1 - t0
            I_blk = I[:, :, t0:t1, :]  # [B,BR,tq,D]
            In_blk = F.normalize(I_blk, dim=-1)
            idx = topk_idx[:, :, t0:t1, :]  # [B,BR,tq,Kmax]
            # k_eff = min(Kmax, history_len), enforce min 1
            t_abs = torch.arange(t0, t1, device=device).view(1, 1, tq, 1)
            hist = (t_abs if strict_causal else (t_abs + 1)).clamp(min=1, max=Kmax)
            kk = torch.arange(Kmax, device=device).view(1, 1, 1, Kmax)
            keep = kk < hist  # [1,1,tq,Kmax]
            keep_f = keep.float()
            # gather selected P and normalized P
            P_sel = torch.gather(
                P[:, None, None, :, :].expand(B, self.BR, tq, T, D),
                3, idx[..., None].expand(B, self.BR, tq, Kmax, D)
            )  # [B,BR,tq,K,D]
            Pn_sel = torch.gather(
                Pn[:, None, None, :, :].expand(B, self.BR, tq, T, D),
                3, idx[..., None].expand(B, self.BR, tq, Kmax, D)
            )
            # anchor alignment (bounded)
            # In_blk: [n,r,t,d], Pn_sel: [n,r,t,k,d] -> a: [n,r,t,k]
            a = torch.einsum("nrtd,nrtkd->nrtk", In_blk, Pn_sel).clamp(-1.0, 1.0)
            # Pn_sel: [n,r,t,k,d] -> G: [n,r,t,k,k]
            G = torch.einsum("nrtid,nrtjd->nrtij", Pn_sel, Pn_sel).clamp(-1.0, 1.0)
            # structural masking (NO -inf features!)
            G = G * keep_f.unsqueeze(-1) * keep_f.unsqueeze(-2)
            a = a * keep_f
            feats = [G, a.unsqueeze(-1)]
            if self.use_delta:
                delta = (t_abs - idx).float().clamp_min(0.0) / max(1.0, float(T))  # [B,BR,tq,K]
                delta = delta * keep_f
                feats.append(delta.unsqueeze(-1))
            rel = torch.cat(feats, dim=-1)  # [B,BR,tq,K, K+1(+1)]
            # MLP → coords + mix logits
            h = self.rel_mlp(rel)
            z = torch.tanh(self.coord_head(h))         # [B,BR,tq,K,2]
            mix_logits = self.mix_head(h).squeeze(-1)  # [B,BR,tq,K]
            mix_logits = mix_logits.masked_fill(~keep, float("-inf"))
            w = torch.softmax(mix_logits, dim=-1)      # [B,BR,tq,K]
            # palette sample at constellation positions
            grid = z.reshape(B * self.BR, tq, Kmax, 2)
            samp = F.grid_sample(pal, grid, mode="bilinear", padding_mode="border", align_corners=True)
            samp = samp.view(B, self.BR, D, tq, Kmax).permute(0, 1, 3, 4, 2)  # [B,BR,tq,K,D]
            v = torch.einsum("nrtk,nrtkd->nrtd", w, samp)
            V[:, :, t0:t1, :] = v
        # Branch-specific W_O then mean across branches:
        #   y_branch = V @ Wo[branch]
        y_br = torch.einsum("nrtd,rdm->nrtm", V, self.Wo)
        y = y_br.mean(dim=1)  # [B,T,D]
        return y

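
# ---------------------------------------------------------------------------
# Hedged sketch: minimal usage of the chunked/streaming Attention above. The
# sizes (d_model=32, n_branch=2, T=20, q_chunk=k_chunk=8) are arbitrary, chosen
# only so the chunk loops actually tile the sequence; the helper name is
# illustrative. The output shape should match the input, and a causal module
# should leave earlier positions untouched when a later token changes.
# ---------------------------------------------------------------------------
def _streaming_attention_demo() -> None:
    torch.manual_seed(0)
    attn = Attention(d_model=32, n_branch=2, max_k=4, palette_hw=8)
    x = torch.randn(2, 20, 32)  # [B, T, D]
    y = attn(x, q_chunk=8, k_chunk=8)
    assert y.shape == x.shape
    # Causality spot-check: perturbing the last token must not change earlier outputs.
    x2 = x.clone()
    x2[:, -1, :] += 1.0
    y2 = attn(x2, q_chunk=8, k_chunk=8)
    assert torch.allclose(y[:, :-1], y2[:, :-1], atol=1e-5)
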
class VectorizedConstellationAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_branch: int,
        palette_hw: int = 32,
        max_k: int = 12,
        rel_hid: int = 64,
        bias: bool = False,
        use_delta: bool = True,
        rope_max_len: int = 4096,
    ):
        super().__init__()
        assert d_model % 2 == 0
        self.d = d_model
        self.BR = n_branch
        self.Ph = self.Pw = palette_hw
        self.Kmax = max_k
        self.use_delta = use_delta
        # Projections
        self.i_proj = nn.Linear(d_model, n_branch * d_model, bias=bias)
        self.p_proj = nn.Linear(d_model, d_model, bias=bias)
        # Same fixed RoPE as above, applied to both I and P
        self.rope = RoPE(d_model, max_len=rope_max_len)
        # Persistent palette: [D, Ph, Pw]
        self.palette = nn.Parameter(torch.randn(d_model, self.Ph, self.Pw) * (d_model ** -0.5))
        # Relational MLP: input = G_row (K) + anchor_dot (1) + optional delta (1)
        rel_in = self.Kmax + 1 + (1 if use_delta else 0)
        self.rel_mlp = nn.Sequential(
            nn.Linear(rel_in, rel_hid),
            nn.GELU(),
            nn.Linear(rel_hid, rel_hid),
            nn.GELU(),
        )
        self.coord_head = nn.Linear(rel_hid, 2)
        self.mix_head = nn.Linear(rel_hid, 1)
        # Output projection
        self.Wo = nn.Parameter(torch.randn(n_branch, d_model, d_model) * (d_model ** -0.5))

    def _normalize_attn_mask(self, attn_mask, B, T, device):
        """
        Returns bool mask broadcastable to [B,1,T,T], True = allowed.
        Supported:
          - None
          - [B,T] key padding (True = keep key)
          - [B,T,T] or [B,1,T,T]
        """
        if attn_mask is None:
            return None
        if attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.bool()
        if attn_mask.dim() == 2:
            # key padding: [B,T] -> [B,1,T,T] (broadcast over queries)
            attn_mask = attn_mask[:, None, None, :].expand(B, 1, T, T)
        elif attn_mask.dim() == 3:
            attn_mask = attn_mask[:, None, :, :]
        elif attn_mask.dim() == 4:
            pass
        else:
            raise ValueError(f"attn_mask must be None, [B,T], [B,T,T], or [B,1,T,T]; got {attn_mask.shape}")
        return attn_mask.to(device)

    def forward(self, x, attn_mask=None):
        B, T, D = x.shape
        K = self.Kmax
        scale = D ** -0.5
        device = x.device
        # ------------------------------------------------------------
        # 1) Projections + fixed RoPE
        # ------------------------------------------------------------
        I = self.i_proj(x).view(B, T, self.BR, D).transpose(1, 2).contiguous()  # [B,BR,T,D]
        P = self.p_proj(x)  # [B,T,D]
        I = self.rope(I)                          # [B,BR,T,D]
        P = self.rope(P.unsqueeze(1)).squeeze(1)  # [B,T,D]
        # ------------------------------------------------------------
        # 2) Pass A: logits + causal + optional mask + topk
        # ------------------------------------------------------------
        logits = torch.matmul(I, P.unsqueeze(1).transpose(-1, -2)) * scale  # [B,BR,T,T]
        causal = torch.tril(torch.ones((T, T), device=device, dtype=torch.bool)).view(1, 1, T, T)
        allow = causal
        attn_mask = self._normalize_attn_mask(attn_mask, B, T, device)
        if attn_mask is not None:
            allow = allow & attn_mask
        logits = logits.masked_fill(~allow, float("-inf"))
        k_eff = min(K, T)
        topk_val, topk_idx = torch.topk(logits, k=k_eff, dim=-1)  # [B,BR,T,k_eff]
        # Pad to K so downstream dims stay constant (MLP expects K features)
        if k_eff < K:
            pad = K - k_eff
            topk_val = torch.cat([topk_val, topk_val.new_full((B, self.BR, T, pad), float("-inf"))], dim=-1)
            topk_idx = torch.cat([topk_idx, topk_idx.new_zeros((B, self.BR, T, pad))], dim=-1)
        # Valid neighbors: avoid poisoning G/z/w with -inf slots
        keep = torch.isfinite(topk_val)  # [B,BR,T,K]
        keep_f = keep.float()
        # If some query has *no* valid neighbor (can happen if the user mask removes everything),
        # force self as a fallback.
        all_bad = ~keep.any(dim=-1, keepdim=True)  # [B,BR,T,1]
        if all_bad.any():
            t_idx = torch.arange(T, device=device).view(1, 1, T, 1).expand(B, self.BR, T, 1)
            topk_idx = torch.where(all_bad, t_idx, topk_idx)
            topk_val = torch.where(all_bad, topk_val.new_zeros(()), topk_val)  # give it a finite score
            keep = torch.isfinite(topk_val)
            keep_f = keep.float()
        # ------------------------------------------------------------
        # 3) Gather neighbors + relational features (masked)
        # ------------------------------------------------------------
        b_idx = torch.arange(B, device=device).view(B, 1, 1, 1)
        P_sel = P[b_idx, topk_idx]  # [B,BR,T,K,D]
        P_sel_norm = F.normalize(P_sel, dim=-1) * keep_f.unsqueeze(-1)
        I_norm = F.normalize(I, dim=-1).unsqueeze(3)  # [B,BR,T,1,D]
        feat_a = (I_norm * P_sel_norm).sum(dim=-1).clamp(-1.0, 1.0)
        feat_a = feat_a * keep_f  # [B,BR,T,K]
        G = torch.matmul(P_sel_norm, P_sel_norm.transpose(-1, -2)).clamp(-1.0, 1.0)  # [B,BR,T,K,K]
        G = G * keep_f.unsqueeze(-1) * keep_f.unsqueeze(-2)
        feats = [G, feat_a.unsqueeze(-1)]
        if self.use_delta:
            t_range = torch.arange(T, device=device).view(1, 1, T, 1)
            delta = (t_range - topk_idx).float().clamp_min(0.0) / max(1.0, float(T))
            delta = delta * keep_f
            feats.append(delta.unsqueeze(-1))
        rel_input = torch.cat(feats, dim=-1)  # [B,BR,T,K, K+1(+1)]
        # ------------------------------------------------------------
        # 4) MLP -> coords + weights (masked softmax)
        # ------------------------------------------------------------
        h = self.rel_mlp(rel_input)  # [B,BR,T,K,rel_hid]
        z = torch.tanh(self.coord_head(h))  # [B,BR,T,K,2]
        mix_logits = self.mix_head(h).squeeze(-1)  # [B,BR,T,K]
        mix_logits = mix_logits.masked_fill(~keep, float("-inf"))
        w = torch.softmax(mix_logits, dim=-1)
        w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)  # extra safety
        # ------------------------------------------------------------
        # 5) Sample palette + aggregate
        # ------------------------------------------------------------
        batch_pal = self.palette.unsqueeze(0).expand(B * self.BR, -1, -1, -1)  # [B*BR,D,Ph,Pw]
        grid = z.reshape(B * self.BR, T, K, 2)  # [B*BR,T,K,2]
        samples = F.grid_sample(
            batch_pal,
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )  # [B*BR, D, T, K]
        samples = samples.view(B, self.BR, D, T, K).permute(0, 1, 3, 4, 2)  # [B,BR,T,K,D]
        V_out = (samples * w.unsqueeze(-1)).sum(dim=3)  # [B,BR,T,D]
        y = torch.einsum("nrtd,rdm->nrtm", V_out, self.Wo).mean(dim=1)  # [B,T,D]
        return y

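
# ---------------------------------------------------------------------------
# Hedged sketch: usage of the vectorized variant with a [B, T] key-padding mask
# (True = keep), one of the formats _normalize_attn_mask accepts, followed by a
# rough numerical comparison against the streaming Attention class. The two
# classes expose identical parameter names, so the state_dict can be copied over;
# the assumption that they compute the same function is mine, not a claim made
# in the gist, and the tolerance is deliberately loose.
# ---------------------------------------------------------------------------
def _vectorized_attention_demo() -> None:
    torch.manual_seed(0)
    B, T, D, BR = 2, 12, 32, 2
    fast = VectorizedConstellationAttention(d_model=D, n_branch=BR, max_k=4, palette_hw=8)
    x = torch.randn(B, T, D)
    pad_mask = torch.ones(B, T, dtype=torch.bool)
    pad_mask[1, -3:] = False  # treat the last 3 tokens of sample 1 as padding
    y = fast(x, attn_mask=pad_mask)
    assert y.shape == (B, T, D)
    # Cross-check against the streaming implementation with shared weights.
    slow = Attention(d_model=D, n_branch=BR, max_k=4, palette_hw=8)
    slow.load_state_dict(fast.state_dict())
    diff = (fast(x) - slow(x, q_chunk=4, k_chunk=4)).abs().max().item()
    assert diff < 1e-3, f"streaming vs vectorized disagree: {diff}"
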
class LayerNorm(nn.Module):
    def __init__(self, ndim: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.use_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = self.bias if self.use_bias else None
        return F.layer_norm(x, self.weight.shape, self.weight, b, 1e-5)

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.scale = math.pi / math.sqrt(3.0)
        # NOTE: this LayerNorm is constructed but never applied in forward();
        # it is kept here so the parameter/state_dict layout stays unchanged.
        self.ln = LayerNorm(config.n_embd * 4, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = x * torch.sigmoid(self.scale * x)  # sigmoid gate; see the note below
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

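
# ---------------------------------------------------------------------------
# Hedged note on the gate above: with scale = pi/sqrt(3), the logistic sigmoid
# is a variance-matched approximation of the standard normal CDF, so
# x * sigmoid(pi/sqrt(3) * x) behaves like a GELU (a SiLU-family gate). The
# tolerance below is my choice, not a bound stated in the gist.
# ---------------------------------------------------------------------------
def _gate_vs_gelu_check() -> None:
    x = torch.linspace(-4.0, 4.0, 201)
    gate = x * torch.sigmoid((math.pi / math.sqrt(3.0)) * x)
    assert (gate - F.gelu(x)).abs().max().item() < 0.05
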
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = VectorizedConstellationAttention(config.n_embd, config.n_head)
        self.mlp = MLP(config)

    def forward(self, x):
        B, T, C = x.shape
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12  # used as n_branch by VectorizedConstellationAttention
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the total number of parameters in the model.
        There are no positional embeddings here (RoPE is applied inside attention),
        so the non_embedding flag has nothing to subtract and the full count is returned.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def forward(self, idx, targets=None):
        device = idx.device
        b, T = idx.size()
        assert T <= self.config.block_size, f"sequence length {T} exceeds block_size {self.config.block_size}"
        x = self.transformer.drop(self.transformer.wte(idx))  # token embeddings (+ embedding dropout)
        # forward the GPT model itself
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None
        return logits, loss

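
# ---------------------------------------------------------------------------
# Hedged sketch: an end-to-end smoke test with a deliberately tiny config (the
# numbers are arbitrary, not the defaults above): one training-style forward
# with targets, one inference-style forward without. The __main__ guard simply
# runs the illustrative checks defined alongside each module; none of them are
# required to use the model.
# ---------------------------------------------------------------------------
def _gpt_smoke_test() -> None:
    cfg = GPTConfig(block_size=64, vocab_size=256, n_layer=2, n_head=2, n_embd=64, dropout=0.0, bias=False)
    model = GPT(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 16))
    tgt = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, loss = model(idx, targets=tgt)
    assert logits.shape == (2, 16, cfg.vocab_size) and loss.requires_grad
    logits, loss = model(idx)  # inference path projects only the last position
    assert logits.shape == (2, 1, cfg.vocab_size) and loss is None


if __name__ == "__main__":
    _rope_sanity_check()
    _streaming_attention_demo()
    _vectorized_attention_demo()
    _gate_vs_gelu_check()
    _gpt_smoke_test()
    print("all sketch checks passed")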