#copyright joshuah.rainstar@gmail.com 2025
#MIT with attribution
#Getting real tired of teams at Samsung and elsewhere declaring they have "reasoning"
#models just because they reuse a set of weights and learn a state-space system.
#Attention is Bayesian coordinate transport to begin with.
#They declare "oh, we do it with fewer params" - yes, and with more compute,
#and with extras like convolution bolted on without a clear account of what is actually going on.
#TRM, HRM, URM authors: weight reuse alone does not make a reasoning model.
#Anyway, here's what amounts to a little bit more of a reasoning module. Go nuts.
import math
import copy
from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


class RoPE(nn.Module):
    def __init__(self, dim, max_len=4096):
        super().__init__()
        assert dim % 2 == 0
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        self.register_buffer('cos', freqs.cos())
        self.register_buffer('sin', freqs.sin())

    def forward(self, x):
        # x: (B, *, T, D)
        T = x.shape[-2]
        cos = self.cos[:T, :].unsqueeze(0).unsqueeze(0)  # [1,1,T,D/2]
        sin = self.sin[:T, :].unsqueeze(0).unsqueeze(0)
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos
        return torch.stack((y1, y2), dim=-1).flatten(-2)
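
# A minimal sanity check, not part of the original gist: RoPE applies a pure
# rotation to each (even, odd) coordinate pair, so it should preserve per-token
# norms exactly. The helper name `_rope_norm_check` and the sizes below are
# illustrative assumptions, not required by the module.
def _rope_norm_check(dim: int = 16, seq: int = 8) -> None:
    rope = RoPE(dim, max_len=seq)
    x = torch.randn(2, 1, seq, dim)  # (B, branch, T, D) layout used below
    y = rope(x)
    assert y.shape == x.shape
    # rotations preserve the norm of every token vector
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)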
class Attention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_branch: int,
        palette_hw: int = 32,
        max_k: int = 12,
        rel_hid: int = 64,
        bias: bool = False,
        use_delta: bool = True,
    ):
        super().__init__()
        assert d_model % 2 == 0, "RoPE needs even dim"
        self.d = d_model
        self.BR = n_branch
        self.Ph = self.Pw = palette_hw
        self.Kmax = max_k
        self.use_delta = use_delta
        # Generate BR full-width impressions per token: [d] -> [BR*d]
        self.i_proj = nn.Linear(d_model, n_branch * d_model, bias=bias)
        # One shared potential per token: [d] -> [d]
        self.p_proj = nn.Linear(d_model, d_model, bias=bias)
        self.rope = RoPE(d_model)
        # Persistent palette: [d, Ph, Pw] shared across branches
        self.palette = nn.Parameter(torch.randn(d_model, self.Ph, self.Pw) * (d_model ** -0.5))
        # Relational feature MLP: per node i in constellation:
        # features = [G_row (Kmax), anchor_dot (1), (optional) delta (1)]
        rel_in = self.Kmax + 1 + (1 if use_delta else 0)
        self.rel_mlp = nn.Sequential(
            nn.Linear(rel_in, rel_hid),
            nn.GELU(),
            nn.Linear(rel_hid, rel_hid),
            nn.GELU(),
        )
        self.coord_head = nn.Linear(rel_hid, 2)  # z in [-1,1]^2
        self.mix_head = nn.Linear(rel_hid, 1)    # mix logits over constellation
        # Branch-local output projections W_O^{(b)}: [BR, d, d]
        self.Wo = nn.Parameter(torch.randn(n_branch, d_model, d_model) * (d_model ** -0.5))

    def forward(self, x, attn_mask=None, strict_causal: bool = False, q_chunk: int = 64, k_chunk: int = 256, eps: float = 1e-9):
        """
        attn_mask: optional bool mask broadcastable to [B, 1, T, T] (True = allowed).
        strict_causal=False allows attending to self (recommended when minimum support is 1).
        eps is reserved and currently unused.
        """
        B, T, D = x.shape
        device = x.device
        inv_sqrt = 1.0 / math.sqrt(D)
        Kmax = self.Kmax
        # I: [B, BR, T, D]
        I = self.i_proj(x).view(B, T, self.BR, D).permute(0, 2, 1, 3).contiguous()
        # P: [B, T, D]
        P = self.p_proj(x)
        # RoPE
        I = self.rope(I)                          # [B,BR,T,D]
        P = self.rope(P.unsqueeze(1)).squeeze(1)  # [B,T,D]
        # Normalize mask to [B,1,T,T]
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask[None, None, :, :]
            elif attn_mask.dim() == 3:
                attn_mask = attn_mask[:, None, :, :]
        # Palette expanded for grid_sample: [B*BR, D, Ph, Pw]
        pal = self.palette.unsqueeze(0).expand(B * self.BR, -1, -1, -1)
        # -----------------------------
        # Pass A: streaming causal top-k
        # -----------------------------
        topk_idx = torch.zeros((B, self.BR, T, Kmax), device=device, dtype=torch.long)
        topk_val = x.new_full((B, self.BR, T, Kmax), float("-inf"))
        for t0 in range(0, T, q_chunk):
            t1 = min(T, t0 + q_chunk)
            tq = t1 - t0
            I_blk = I[:, :, t0:t1, :]  # [B,BR,tq,D]
            best_vals = x.new_full((B, self.BR, tq, Kmax), float("-inf"))
            best_idx = torch.zeros((B, self.BR, tq, Kmax), device=device, dtype=torch.long)
            for s0 in range(0, T, k_chunk):
                s1 = min(T, s0 + k_chunk)
                sk = s1 - s0
                P_blk = P[:, s0:s1, :]  # [B,sk,D]
                # logits: [B,BR,tq,sk]
                logits = torch.einsum("nrtd,nsd->nrts", I_blk, P_blk) * inv_sqrt
                qpos = torch.arange(t0, t1, device=device).view(1, 1, tq, 1)
                kpos = torch.arange(s0, s1, device=device).view(1, 1, 1, sk)
                allow = (kpos < qpos) if strict_causal else (kpos <= qpos)
                # With strict_causal, t=0 has no allowed keys; since minimum support
                # is 1, force-allow self at t=0:
                if strict_causal and t0 == 0:
                    allow = allow | ((qpos == 0) & (kpos == 0))
                if attn_mask is not None:
                    allow = allow & attn_mask[:, :, t0:t1, s0:s1]
                logits = logits.masked_fill(~allow, float("-inf"))
                cand_k = min(Kmax, sk)
                cand_vals, cand_pos = torch.topk(logits, k=cand_k, dim=-1)
                cand_idx = cand_pos + s0
                merged_vals = torch.cat([best_vals, cand_vals], dim=-1)
                merged_idx = torch.cat([best_idx, cand_idx], dim=-1)
                best_vals, sel = torch.topk(merged_vals, k=Kmax, dim=-1)
                best_idx = torch.gather(merged_idx, dim=-1, index=sel)
            topk_val[:, :, t0:t1, :] = best_vals
            topk_idx[:, :, t0:t1, :] = best_idx
        # -----------------------------
        # Pass B: relational manifold -> palette -> V^{(b)}
        # -----------------------------
        V = x.new_zeros((B, self.BR, T, D))
        Pn = F.normalize(P, dim=-1)  # [B,T,D]
        for t0 in range(0, T, q_chunk):
            t1 = min(T, t0 + q_chunk)
            tq = t1 - t0
            I_blk = I[:, :, t0:t1, :]  # [B,BR,tq,D]
            In_blk = F.normalize(I_blk, dim=-1)
            idx = topk_idx[:, :, t0:t1, :]  # [B,BR,tq,Kmax]
            # k_eff = min(Kmax, history_len), enforce min 1
            t_abs = torch.arange(t0, t1, device=device).view(1, 1, tq, 1)
            hist = (t_abs if strict_causal else (t_abs + 1)).clamp(min=1, max=Kmax)
            kk = torch.arange(Kmax, device=device).view(1, 1, 1, Kmax)
            keep = kk < hist  # [1,1,tq,Kmax]
            keep_f = keep.float()
            # gather selected P and normalized P
            P_sel = torch.gather(
                P[:, None, None, :, :].expand(B, self.BR, tq, T, D),
                3, idx[..., None].expand(B, self.BR, tq, Kmax, D)
            )  # [B,BR,tq,K,D]
            Pn_sel = torch.gather(
                Pn[:, None, None, :, :].expand(B, self.BR, tq, T, D),
                3, idx[..., None].expand(B, self.BR, tq, Kmax, D)
            )
            # anchor alignment (bounded)
            # In_blk: [n,r,t,d], Pn_sel: [n,r,t,k,d] -> a: [n,r,t,k]
            a = torch.einsum("nrtd,nrtkd->nrtk", In_blk, Pn_sel).clamp(-1.0, 1.0)
            # Pn_sel: [n,r,t,k,d] -> G: [n,r,t,k,k]
            G = torch.einsum("nrtid,nrtjd->nrtij", Pn_sel, Pn_sel).clamp(-1.0, 1.0)
            # structural masking (no -inf features!)
            G = G * keep_f.unsqueeze(-1) * keep_f.unsqueeze(-2)
            a = a * keep_f
            feats = [G, a.unsqueeze(-1)]
            if self.use_delta:
                delta = (t_abs - idx).float().clamp_min(0.0) / max(1.0, float(T))  # [B,BR,tq,K]
                delta = delta * keep_f
                feats.append(delta.unsqueeze(-1))
            rel = torch.cat(feats, dim=-1)  # [B,BR,tq,K, K+1(+1)]
            # MLP -> coords + mix logits
            h = self.rel_mlp(rel)
            z = torch.tanh(self.coord_head(h))         # [B,BR,tq,K,2]
            mix_logits = self.mix_head(h).squeeze(-1)  # [B,BR,tq,K]
            mix_logits = mix_logits.masked_fill(~keep, float("-inf"))
            w = torch.softmax(mix_logits, dim=-1)      # [B,BR,tq,K]
            # palette sample at constellation positions
            grid = z.reshape(B * self.BR, tq, Kmax, 2)
            samp = F.grid_sample(pal, grid, mode="bilinear", padding_mode="border", align_corners=True)
            samp = samp.view(B, self.BR, D, tq, Kmax).permute(0, 1, 3, 4, 2)  # [B,BR,tq,K,D]
            v = torch.einsum("nrtk,nrtkd->nrtd", w, samp)
            V[:, :, t0:t1, :] = v
        # Branch-specific W_O then mean across branches:
        # y_branch = V @ Wo[branch]
        y_br = torch.einsum("nrtd,rdm->nrtm", V, self.Wo)
        y = y_br.mean(dim=1)  # [B,T,D]
        return y
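
# Minimal usage sketch, not part of the original gist: the streaming Attention
# module maps [B, T, d_model] -> [B, T, d_model], chunking the causal top-k
# search so memory scales with q_chunk * k_chunk rather than T^2. The function
# name and all sizes below are illustrative assumptions.
def _streaming_attention_demo() -> None:
    torch.manual_seed(0)
    attn = Attention(d_model=64, n_branch=4, palette_hw=16, max_k=8)
    x = torch.randn(2, 128, 64)
    y = attn(x, q_chunk=32, k_chunk=64)
    assert y.shape == (2, 128, 64)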
class VectorizedConstellationAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_branch: int,
        palette_hw: int = 32,
        max_k: int = 12,
        rel_hid: int = 64,
        bias: bool = False,
        use_delta: bool = True,
        rope_max_len: int = 4096,
    ):
        super().__init__()
        assert d_model % 2 == 0
        self.d = d_model
        self.BR = n_branch
        self.Ph = self.Pw = palette_hw
        self.Kmax = max_k
        self.use_delta = use_delta
        # Projections
        self.i_proj = nn.Linear(d_model, n_branch * d_model, bias=bias)
        self.p_proj = nn.Linear(d_model, d_model, bias=bias)
        # Shared RoPE applied to both impressions (I) and potentials (P)
        self.rope = RoPE(d_model, max_len=rope_max_len)
        # Persistent palette: [D, Ph, Pw]
        self.palette = nn.Parameter(torch.randn(d_model, self.Ph, self.Pw) * (d_model ** -0.5))
        # Relational MLP: input = G_row (K) + anchor_dot (1) + optional delta (1)
        rel_in = self.Kmax + 1 + (1 if use_delta else 0)
        self.rel_mlp = nn.Sequential(
            nn.Linear(rel_in, rel_hid),
            nn.GELU(),
            nn.Linear(rel_hid, rel_hid),
            nn.GELU(),
        )
        self.coord_head = nn.Linear(rel_hid, 2)
        self.mix_head = nn.Linear(rel_hid, 1)
        # Output projection
        self.Wo = nn.Parameter(torch.randn(n_branch, d_model, d_model) * (d_model ** -0.5))

    def _normalize_attn_mask(self, attn_mask, B, T, device):
        """
        Returns a bool mask broadcastable to [B,1,T,T], True = allowed.
        Supported:
          - None
          - [B,T] key padding (True = keep key)
          - [B,T,T] or [B,1,T,T]
        """
        if attn_mask is None:
            return None
        if attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.bool()
        if attn_mask.dim() == 2:
            # key padding: [B,T] -> [B,1,T,T] (broadcast over queries)
            attn_mask = attn_mask[:, None, None, :].expand(B, 1, T, T)
        elif attn_mask.dim() == 3:
            attn_mask = attn_mask[:, None, :, :]
        elif attn_mask.dim() == 4:
            pass
        else:
            raise ValueError(f"attn_mask must be None, [B,T], [B,T,T], or [B,1,T,T]; got {attn_mask.shape}")
        return attn_mask.to(device)

    def forward(self, x, attn_mask=None):
        B, T, D = x.shape
        K = self.Kmax
        scale = D ** -0.5
        device = x.device
        # ------------------------------------------------------------
        # 1) Projections + RoPE
        # ------------------------------------------------------------
        I = self.i_proj(x).view(B, T, self.BR, D).transpose(1, 2).contiguous()  # [B,BR,T,D]
        P = self.p_proj(x)                        # [B,T,D]
        I = self.rope(I)                          # [B,BR,T,D]
        P = self.rope(P.unsqueeze(1)).squeeze(1)  # [B,T,D]
        # ------------------------------------------------------------
        # 2) Pass A: logits + causal + optional mask + topk
        # ------------------------------------------------------------
        logits = torch.matmul(I, P.unsqueeze(1).transpose(-1, -2)) * scale  # [B,BR,T,T]
        causal = torch.tril(torch.ones((T, T), device=device, dtype=torch.bool)).view(1, 1, T, T)
        allow = causal
        attn_mask = self._normalize_attn_mask(attn_mask, B, T, device)
        if attn_mask is not None:
            allow = allow & attn_mask
        logits = logits.masked_fill(~allow, float("-inf"))
        k_eff = min(K, T)
        topk_val, topk_idx = torch.topk(logits, k=k_eff, dim=-1)  # [B,BR,T,k_eff]
        # Pad to K so downstream dims stay constant (MLP expects K features)
        if k_eff < K:
            pad = K - k_eff
            topk_val = torch.cat([topk_val, topk_val.new_full((B, self.BR, T, pad), float("-inf"))], dim=-1)
            topk_idx = torch.cat([topk_idx, topk_idx.new_zeros((B, self.BR, T, pad))], dim=-1)
        # Valid neighbors: avoid poisoning G/z/w with -inf slots
        keep = torch.isfinite(topk_val)  # [B,BR,T,K]
        keep_f = keep.float()
        # If some query has *no* valid neighbor (possible if the user mask removes
        # every key for that row), fall back to attending to self.
        all_bad = ~keep.any(dim=-1, keepdim=True)  # [B,BR,T,1]
        if all_bad.any():
            t_idx = torch.arange(T, device=device).view(1, 1, T, 1).expand(B, self.BR, T, 1)
            topk_idx = torch.where(all_bad, t_idx, topk_idx)
            topk_val = torch.where(all_bad, topk_val.new_zeros(()), topk_val)  # give it a finite score
            keep = torch.isfinite(topk_val)
            keep_f = keep.float()
        # ------------------------------------------------------------
        # 3) Gather neighbors + relational features (masked)
        # ------------------------------------------------------------
        b_idx = torch.arange(B, device=device).view(B, 1, 1, 1)
        P_sel = P[b_idx, topk_idx]  # [B,BR,T,K,D]
        P_sel_norm = F.normalize(P_sel, dim=-1) * keep_f.unsqueeze(-1)
        I_norm = F.normalize(I, dim=-1).unsqueeze(3)  # [B,BR,T,1,D]
        feat_a = (I_norm * P_sel_norm).sum(dim=-1).clamp(-1.0, 1.0)
        feat_a = feat_a * keep_f  # [B,BR,T,K]
        G = torch.matmul(P_sel_norm, P_sel_norm.transpose(-1, -2)).clamp(-1.0, 1.0)  # [B,BR,T,K,K]
        G = G * keep_f.unsqueeze(-1) * keep_f.unsqueeze(-2)
        feats = [G, feat_a.unsqueeze(-1)]
        if self.use_delta:
            t_range = torch.arange(T, device=device).view(1, 1, T, 1)
            delta = (t_range - topk_idx).float().clamp_min(0.0) / max(1.0, float(T))
            delta = delta * keep_f
            feats.append(delta.unsqueeze(-1))
        rel_input = torch.cat(feats, dim=-1)  # [B,BR,T,K, K+1(+1)]
        # ------------------------------------------------------------
        # 4) MLP -> coords + weights (masked softmax)
        # ------------------------------------------------------------
        h = self.rel_mlp(rel_input)                # [B,BR,T,K,rel_hid]
        z = torch.tanh(self.coord_head(h))         # [B,BR,T,K,2]
        mix_logits = self.mix_head(h).squeeze(-1)  # [B,BR,T,K]
        mix_logits = mix_logits.masked_fill(~keep, float("-inf"))
        w = torch.softmax(mix_logits, dim=-1)
        w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)  # extra safety
        # ------------------------------------------------------------
        # 5) Sample palette + aggregate
        # ------------------------------------------------------------
        batch_pal = self.palette.unsqueeze(0).expand(B * self.BR, -1, -1, -1)  # [B*BR,D,Ph,Pw]
        grid = z.reshape(B * self.BR, T, K, 2)  # [B*BR,T,K,2]
        samples = F.grid_sample(
            batch_pal,
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )  # [B*BR, D, T, K]
        samples = samples.view(B, self.BR, D, T, K).permute(0, 1, 3, 4, 2)  # [B,BR,T,K,D]
        V_out = (samples * w.unsqueeze(-1)).sum(dim=3)  # [B,BR,T,D]
        y = torch.einsum("nrtd,rdm->nrtm", V_out, self.Wo).mean(dim=1)  # [B,T,D]
        return y
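
# Minimal usage sketch, not part of the original gist: the vectorized variant
# has the same [B, T, d_model] -> [B, T, d_model] signature and accepts an
# optional [B, T] key-padding mask (True = keep key), which exercises the
# _normalize_attn_mask path above. Names and sizes are illustrative assumptions.
def _vectorized_attention_demo() -> None:
    torch.manual_seed(0)
    attn = VectorizedConstellationAttention(d_model=64, n_branch=4, palette_hw=16, max_k=8)
    x = torch.randn(2, 32, 64)
    key_keep = torch.ones(2, 32, dtype=torch.bool)
    key_keep[:, 24:] = False  # pretend the last 8 positions are padding
    y = attn(x, attn_mask=key_keep)
    assert y.shape == (2, 32, 64)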
class LayerNorm(nn.Module):
    def __init__(self, ndim: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.use_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = self.bias if self.use_bias else None
        return F.layer_norm(x, self.weight.shape, self.weight, b, 1e-5)


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.scale = math.pi / math.sqrt(3.0)
        # Note: this LayerNorm is defined but never applied in forward().
        self.ln = LayerNorm(config.n_embd * 4, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = x * torch.sigmoid(self.scale * x)  # SiLU scaled by pi/sqrt(3)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = VectorizedConstellationAttention(config.n_embd, config.n_head)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the total number of parameters in the model.
        There are no learned position embeddings to subtract (positions are
        encoded with RoPE), so `non_embedding` currently has no effect.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def forward(self, idx, targets=None):
        b, T = idx.size()
        x = self.transformer.wte(idx)  # token embeddings
        # forward the GPT model itself
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None
        return logits, loss
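
# Minimal end-to-end sketch, not part of the original gist: build a small GPT
# with assumed hyperparameters, run one cross-entropy step on random token ids,
# then hit the inference path, which only returns logits for the last position.
def _tiny_gpt_demo() -> None:
    torch.manual_seed(0)
    cfg = GPTConfig(block_size=64, vocab_size=256, n_layer=2, n_head=2, n_embd=64)
    model = GPT(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 32))
    targets = torch.randint(0, cfg.vocab_size, (2, 32))
    logits, loss = model(idx, targets)
    assert logits.shape == (2, 32, cfg.vocab_size) and loss.requires_grad
    loss.backward()
    # inference path: only the last position's logits are computed
    logits_last, _ = model(idx)
    assert logits_last.shape == (2, 1, cfg.vocab_size)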