Skip to content

Instantly share code, notes, and snippets.

@scturtle
Last active March 6, 2026 10:12
Show Gist options
  • Select an option

  • Save scturtle/1930047017717cd7f56d4a2b726aeb20 to your computer and use it in GitHub Desktop.

Select an option

Save scturtle/1930047017717cd7f56d4a2b726aeb20 to your computer and use it in GitHub Desktop.
qwen 3.5
# pip install torch huggingface_hub safetensors tokenizers transformers
import sys, re, time
from pathlib import Path
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from safetensors import safe_open
from huggingface_hub import snapshot_download
# ─────────────────────────────── config ───────────────────────────────
REPO_ID = "Qwen/Qwen3.5-0.8B"  # Hugging Face checkpoint to download
MAX_NEW_TOKENS = 500  # generation budget per prompt
# Model hyperparameters — presumably mirrors the checkpoint's config.json
# (TODO confirm against the repo's config file).
CFG = {
    "vocab_size": 248_320,
    "context_length": 262_144,  # RoPE tables are precomputed to this length
    "emb_dim": 1_024,  # hidden size
    "n_heads": 8,  # query heads (full attention)
    "n_layers": 24,
    "hidden_dim": 3_584,  # MLP intermediate size
    "head_dim": 256,
    "qk_norm": True,
    "n_kv_groups": 2,  # KV heads for grouped-query attention
    "rope_base": 10_000_000.0,
    "partial_rotary_factor": 0.25,  # only 1/4 of head_dim is rotated
    "rms_norm_eps": 1e-6,
    "linear_conv_kernel_dim": 4,  # short causal conv in GatedDeltaNet
    "linear_key_head_dim": 128,
    "linear_value_head_dim": 128,
    "linear_num_key_heads": 16,
    "linear_num_value_heads": 16,
    "dtype": torch.bfloat16,
    # Repeating pattern: 3 linear-attention layers then 1 full-attention,
    # six times = 24 layers total.
    "layer_types": (["linear_attention"] * 3 + ["full_attention"]) * 6,
}
# ──────────────────────────────── KVCache ─────────────────────────────
class KVCache:
    """Mutable per-generation cache shared across all layers.

    kv       : layer_idx -> (keys, values) for full-attention layers.
    lin      : layer_idx -> (conv_state, recurrent_state) for linear layers.
    position : absolute count of tokens already fed through the model.
    """

    def __init__(self):
        def empty_pair():
            return (None, None)

        self.kv = defaultdict(empty_pair)
        self.lin = defaultdict(empty_pair)
        self.position = 0
# ──────────────────────────── linear attn ─────────────────────────────
def l2norm(x, dim=-1, eps=1e-6):
    """Scale x to (approximately) unit L2 norm along `dim`, eps-stabilized."""
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
):
    """Chunkwise-parallel gated delta rule (linear attention), used for prefill.

    Inputs appear to be laid out (B, L, H, D) and are transposed to
    (B, H, L, D) float32 internally — assumed from the transpose(1, 2)
    calls and the caller in GatedDeltaNet; TODO confirm. `g` is a per-step
    log-decay, `beta` a per-step write strength. Returns (out, state):
    out in the original layout/dtype, state (B, H, Dk, Dv) float32 for
    handoff to the recurrent decode path.
    """
    initial_dtype = query.dtype
    query, key = l2norm(query), l2norm(key)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    B, H, L, Dk = key.shape
    Dv = value.shape[-1]
    # Right-pad the sequence to a multiple of chunk_size.
    pad = (chunk_size - L % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad))
    key = F.pad(key, (0, 0, 0, pad))
    value = F.pad(value, (0, 0, 0, pad))
    beta = F.pad(beta, (0, pad))
    g = F.pad(g, (0, pad))
    T = L + pad
    query = query / query.shape[-1] ** 0.5  # 1/sqrt(Dk) scaling
    # Pre-scale value/key by the write strength beta.
    vb = value * beta[..., None]
    kb = key * beta[..., None]
    # Split the time axis into chunks: (B, H, T/chunk, chunk, D).
    query, key, value, kb, vb = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, kb, vb)
    ]
    # Cumulative log-decay within each chunk; dm[i, j] = exp(g_i - g_j) on
    # the lower triangle (decay applied between positions j <= i).
    g = g.reshape(B, H, -1, chunk_size).cumsum(-1)
    dm = ((g[..., None] - g[..., None, :]).tril().exp()).tril()
    up = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device)
    )
    # Build the intra-chunk interaction matrix and invert (I - strictly
    # lower triangular) via forward substitution.
    A = -(kb @ key.transpose(-1, -2) * dm).masked_fill(up, 0)
    for i in range(1, chunk_size):
        r = A[..., i, :i].clone()
        s = A[..., :i, :i].clone()
        A[..., i, :i] = r + (r[..., None] * s).sum(-2)
    A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
    value = A @ vb  # per-chunk corrected values
    kcd = A @ (kb * g.exp()[..., None])  # decay-weighted keys for state read
    state = (
        torch.zeros(B, H, Dk, Dv, device=query.device, dtype=value.dtype)
        if initial_state is None
        else initial_state.to(value)
    )
    out = torch.zeros_like(value)
    up2 = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 1
    )
    # Sequential pass over chunks: combine inter-chunk state contribution
    # with the intra-chunk (strictly causal) attention, then advance state.
    for i in range(T // chunk_size):
        qi, ki, vi = query[:, :, i], key[:, :, i], value[:, :, i]
        a = (qi @ ki.transpose(-1, -2) * dm[:, :, i]).masked_fill_(up2, 0)
        vp = kcd[:, :, i] @ state  # value predicted from carried-over state
        vn = vi - vp  # residual new information
        out[:, :, i] = (qi * g[:, :, i, :, None].exp()) @ state + a @ vn
        # Decay state to the chunk end and add the chunk's contribution.
        state = (
            state * g[:, :, i, -1, None, None].exp()
            + (ki * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
                -1, -2
            )
            @ vn
        )
    out = out.reshape(B, H, -1, Dv)[:, :, :L]  # drop padding
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
def torch_recurrent_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    initial_state,
):
    """Token-by-token (recurrent) gated delta rule, used for decode steps.

    Same contract as torch_chunk_gated_delta_rule: inputs appear to be
    (B, L, H, D), transposed internally to (B, H, L, D) float32 — assumed
    from the transposes; TODO confirm. `g` is a log-decay, `beta` a write
    strength in (0, 1) (sigmoid upstream). Returns (out, state) with out
    in the caller's layout/dtype and state (B, H, Dk, Dv) float32.
    """
    initial_dtype = query.dtype
    query, key = l2norm(query), l2norm(key)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    B, H, L, Dk = key.shape
    Dv = value.shape[-1]
    query = query / query.shape[-1] ** 0.5  # 1/sqrt(Dk) scaling
    state = (
        torch.zeros(B, H, Dk, Dv, device=query.device, dtype=value.dtype)
        if initial_state is None
        else initial_state.to(value)
    )
    out = torch.zeros(B, H, L, Dv, device=query.device, dtype=value.dtype)
    for i in range(L):
        # Decay the memory, apply the delta-rule write, then read out.
        state = state * g[:, :, i].exp()[..., None, None]
        kv_mem = (state * key[:, :, i].unsqueeze(-1)).sum(-2)  # current recall
        delta = (value[:, :, i] - kv_mem) * beta[:, :, i].unsqueeze(-1)
        state = state + key[:, :, i].unsqueeze(-1) * delta.unsqueeze(-2)
        out[:, :, i] = (state * query[:, :, i].unsqueeze(-1)).sum(-2)
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
# ─────────────────────────── model layers ─────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a zero-initialized weight applied
    as (1 + w), computed in float32 and cast back to the input dtype."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        x32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        scale = 1.0 + self.weight.float()
        return (x32 * inv_rms * scale).to(x.dtype)
class Qwen3_5RMSNormGated(nn.Module):
    """RMSNorm (ones-initialized weight) followed by an optional SiLU gate.

    Bug fix: the original declared ``gate=None`` as the default but called
    ``gate.float()`` unconditionally, crashing whenever the default was
    used. The gate is now genuinely optional; gated behavior is unchanged.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.variance_epsilon = eps

    def forward(self, x, gate=None):
        dtype = x.dtype
        # Normalize in float32 for stability, then rescale in input dtype.
        h = x.float()
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        h = self.weight * h.to(dtype)
        if gate is None:
            return h  # already in the input dtype
        return (h * F.silu(gate.float())).to(dtype)
def compute_rope_params(
    head_dim, theta_base, context_length, partial_rotary_factor=1.0
):
    """Precompute RoPE cos/sin tables for a partially-rotary head.

    Returns (cos, sin), each of shape (context_length, rot), where rot is
    the even rotary sub-dimension head_dim * partial_rotary_factor (>= 2).
    """
    rot = int(head_dim * partial_rotary_factor)
    rot = max(2, rot - rot % 2)  # force even, at least 2
    exponents = torch.arange(0, rot, 2).float() / rot
    inv_freq = 1.0 / (theta_base ** exponents)
    positions = torch.arange(context_length).float()
    half_angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat((half_angles, half_angles), dim=1)
    return torch.cos(angles), torch.sin(angles)
def apply_rope(x, cos, sin, offset=0):
    """Rotate the first `rot` channels of x with RoPE; pass the rest through.

    x is (B, H, L, D) with D >= rot == cos.shape[-1]; `offset` positions the
    first token inside the precomputed cos/sin tables.
    """
    seq_len = x.shape[2]
    rot = cos.shape[-1]
    rotary, passthrough = x[..., :rot], x[..., rot:]
    half = rot // 2
    first, second = rotary[..., :half], rotary[..., half:]
    cos_t = cos[offset : offset + seq_len][None, None]
    sin_t = sin[offset : offset + seq_len][None, None]
    rotated = rotary * cos_t + torch.cat((-second, first), dim=-1) * sin_t
    return torch.cat((rotated, passthrough), dim=-1).to(x.dtype)
# Renamed to match checkpoint: q_proj / k_proj / v_proj / o_proj
class GroupedQueryAttention(nn.Module):
    """Full softmax attention with grouped KV heads, partial RoPE, and a
    per-channel sigmoid output gate (q_proj packs query + gate halves)."""

    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx  # key into cache.kv
        nh, nkv, hd, d, dt = (
            CFG["n_heads"],
            CFG["n_kv_groups"],
            CFG["head_dim"],
            CFG["emb_dim"],
            CFG["dtype"],
        )
        self.num_heads, self.num_kv_groups, self.head_dim = nh, nkv, hd
        self.d_out = nh * hd
        # q_proj is 2× head_dim (query + gate) to match checkpoint shape
        self.q_proj = nn.Linear(d, self.d_out * 2, bias=False, dtype=dt, device="meta")
        self.k_proj = nn.Linear(d, nkv * hd, bias=False, dtype=dt, device="meta")
        self.v_proj = nn.Linear(d, nkv * hd, bias=False, dtype=dt, device="meta")
        self.o_proj = nn.Linear(self.d_out, d, bias=False, dtype=dt, device="meta")
        self.q_norm = RMSNorm(hd)  # per-head norm applied before RoPE
        self.k_norm = RMSNorm(hd)

    def forward(self, x, mask, cos, sin, start_pos, cache):
        # x: (B, L, emb_dim); mask is True where a query may attend.
        B, L, _ = x.shape
        # Split the doubled projection into query and gate halves per head.
        qg = self.q_proj(x).view(B, L, self.num_heads, self.head_dim * 2)
        queries = qg[..., : self.head_dim].transpose(1, 2)
        gate = qg[..., self.head_dim :].reshape(B, L, self.d_out)
        keys = (
            self.k_proj(x).view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        )
        values = (
            self.v_proj(x).view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        )
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        # Rotate the partial-rotary slice, offset by tokens already cached.
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys = apply_rope(keys, cos, sin, offset=start_pos)
        # Extend this layer's KV cache along the sequence dimension.
        prev_k, prev_v = cache.kv[self.layer_idx]
        if prev_k is not None:
            keys = torch.cat([prev_k, keys], dim=2)
            values = torch.cat([prev_v, values], dim=2)
        cache.kv[self.layer_idx] = (keys, values)
        ctx = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=mask,
            scale=self.head_dim**-0.5,
            enable_gqa=True,  # broadcasts nkv KV heads across nh query heads
        )
        # Sigmoid output gate taken from the second half of q_proj.
        ctx = ctx.transpose(1, 2).reshape(B, L, self.d_out) * torch.sigmoid(gate)
        return self.o_proj(ctx)
def causal_conv1d_update(x, state, weight):
    """One causal depthwise-conv step over a rolling window.

    Appends x (B, C, L) to `state` (B, C, S), shifts `state` left IN PLACE
    to keep only the last S columns, and returns SiLU(conv1d(window)) in
    x's dtype. `weight` is a depthwise kernel (groups == C).
    """
    window = state.shape[-1]
    history = torch.cat((state, x.to(weight.dtype)), dim=-1)
    state.copy_(history[:, :, -window:])  # in-place rolling update
    conv_out = F.conv1d(history, weight, padding=0, groups=x.shape[1])
    return F.silu(conv_out).to(x.dtype)
class GatedDeltaNet(nn.Module):
    """Linear-attention block (gated delta rule): a short depthwise causal
    conv over the packed q/k/v stream, chunked (prefill) or recurrent
    (decode) delta-rule core, gated RMSNorm, and an output projection.
    Parameters are created in float32 then cast to CFG dtype."""

    def __init__(self):
        super().__init__()
        Hv, Hk = CFG["linear_num_value_heads"], CFG["linear_num_key_heads"]
        Dk, Dv = CFG["linear_key_head_dim"], CFG["linear_value_head_dim"]
        ks, d = CFG["linear_conv_kernel_dim"], CFG["emb_dim"]
        self.num_v_heads, self.num_k_heads = Hv, Hk
        self.head_k_dim, self.head_v_dim = Dk, Dv
        self.key_dim = Dk * Hk
        self.value_dim = Dv * Hv
        self.conv_kernel_size = ks
        # q, k, v share one conv channel dim: 2 * key_dim + value_dim.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            self.conv_dim,
            self.conv_dim,
            bias=False,
            kernel_size=ks,
            groups=self.conv_dim,  # depthwise
            padding=ks - 1,  # causal left padding
        )
        self.dt_bias = nn.Parameter(torch.ones(Hv))  # shift inside softplus
        self.A_log = nn.Parameter(torch.log(torch.empty(Hv).uniform_(0, 16)))
        self.norm = Qwen3_5RMSNormGated(Dv, eps=CFG["rms_norm_eps"])
        self.out_proj = nn.Linear(self.value_dim, d, bias=False)
        self.in_proj_qkv = nn.Linear(d, self.conv_dim, bias=False)
        self.in_proj_z = nn.Linear(d, self.value_dim, bias=False)  # norm gate
        self.in_proj_b = nn.Linear(d, Hv, bias=False)  # beta (write strength)
        self.in_proj_a = nn.Linear(d, Hv, bias=False)  # decay input
        self.to(dtype=CFG["dtype"])

    def forward(self, hidden_states, layer_idx, cache):
        B, L, _ = hidden_states.shape
        conv_state, rec_state = cache.lin[layer_idx]
        qkv_raw = self.in_proj_qkv(hidden_states)  # [B, L, conv_dim]
        qkv_raw = qkv_raw.transpose(1, 2)  # [B, conv_dim, L]
        is_decode = L == 1
        K = self.conv_kernel_size
        if is_decode:
            # conv_state: [B, conv_dim, K]
            # Roll the window one step and convolve only the latest frame.
            state = torch.cat([conv_state[:, :, 1:], qkv_raw], dim=-1)
            mixed_qkv = F.silu(
                F.conv1d(state, self.conv1d.weight, padding=0, groups=self.conv_dim)
            ).transpose(1, 2)
            new_conv_state = state
        else:
            # Prefill: full causal conv; trailing padding outputs are cut.
            mixed_qkv = F.silu(
                F.conv1d(
                    qkv_raw,
                    self.conv1d.weight,
                    padding=K - 1,
                    groups=self.conv_dim,
                )[:, :, :L]
            ).transpose(1, 2)
            # Keep the last K raw frames (left-padded with zeros if L < K).
            new_conv_state = F.pad(qkv_raw, (max(0, K - L), 0))[:, :, -K:]
        z = self.in_proj_z(hidden_states).reshape(B, L, -1, self.head_v_dim)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)
        q, k, v = torch.split(
            mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
        )
        q = q.reshape(B, L, -1, self.head_k_dim)
        k = k.reshape(B, L, -1, self.head_k_dim)
        v = v.reshape(B, L, -1, self.head_v_dim)
        beta = b.sigmoid()  # write strength in (0, 1)
        # Log-decay g <= 0: -exp(A_log) * softplus(a + dt_bias).
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Expand key heads to match value heads when Hv > Hk.
        ratio = self.num_v_heads // self.num_k_heads
        if ratio > 1:
            q = q.repeat_interleave(ratio, dim=2)
            k = k.repeat_interleave(ratio, dim=2)
        # Chunked math for prefill, step-recurrence for single-token decode.
        fn = (
            torch_recurrent_gated_delta_rule
            if is_decode
            else torch_chunk_gated_delta_rule
        )
        core_out, new_rec_state = fn(q, k, v, g=g, beta=beta, initial_state=rec_state)
        cache.lin[layer_idx] = (new_conv_state, new_rec_state)
        # Gated RMSNorm applied per (token, head), then flatten heads.
        core_out = self.norm(
            core_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)
        ).reshape(B, L, -1)
        return self.out_proj(core_out)
class FeedForward(nn.Module):
    """SwiGLU MLP: down(silu(gate(x)) * up(x)). Weights live on the meta
    device until checkpoint loading materializes them."""

    def __init__(self):
        super().__init__()
        dim = CFG["emb_dim"]
        hidden = CFG["hidden_dim"]
        linear_kwargs = dict(bias=False, dtype=CFG["dtype"], device="meta")
        self.gate_proj = nn.Linear(dim, hidden, **linear_kwargs)
        self.up_proj = nn.Linear(dim, hidden, **linear_kwargs)
        self.down_proj = nn.Linear(hidden, dim, **linear_kwargs)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class TransformerBlock(nn.Module):
    """Pre-norm residual block: full or linear attention, then SwiGLU MLP."""

    def __init__(self, layer_type, layer_idx):
        super().__init__()
        self.layer_type = layer_type
        self.layer_idx = layer_idx
        if layer_type == "full_attention":
            self.self_attn = GroupedQueryAttention(layer_idx)
        else:
            self.linear_attn = GatedDeltaNet()
        self.mlp = FeedForward()
        eps = CFG["rms_norm_eps"]
        self.input_layernorm = RMSNorm(CFG["emb_dim"], eps)
        self.post_attention_layernorm = RMSNorm(CFG["emb_dim"], eps)

    def forward(self, x, mask, cos, sin, start_pos, cache):
        normed = self.input_layernorm(x)
        if self.layer_type == "full_attention":
            attn_out = self.self_attn(normed, mask, cos, sin, start_pos, cache)
        else:
            attn_out = self.linear_attn(normed, self.layer_idx, cache)
        residual = x + attn_out
        return residual + self.mlp(self.post_attention_layernorm(residual))
class Qwen3_5Model(nn.Module):
    """Decoder-only hybrid stack: blocks follow CFG["layer_types"]
    (3 linear-attention layers per full-attention layer), with shared
    precomputed RoPE tables and an untied-by-default LM head."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(
            CFG["vocab_size"], CFG["emb_dim"], dtype=CFG["dtype"], device="meta"
        )
        self.layers = nn.ModuleList(
            [TransformerBlock(lt, i) for i, lt in enumerate(CFG["layer_types"])]
        )
        self.norm = RMSNorm(CFG["emb_dim"], CFG["rms_norm_eps"])
        self.lm_head = nn.Linear(
            CFG["emb_dim"],
            CFG["vocab_size"],
            bias=False,
            dtype=CFG["dtype"],
            device="meta",
        )
        # RoPE tables are built once and shared by all full-attention layers.
        cos, sin = compute_rope_params(
            head_dim=CFG["head_dim"],
            theta_base=CFG["rope_base"],
            context_length=CFG["context_length"],
            partial_rotary_factor=CFG["partial_rotary_factor"],
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _create_mask(self, device, start_pos, seq_len):
        # Causal mask against the full cached history: queries sit at
        # absolute positions [start_pos, start_pos + seq_len), keys at
        # [0, start_pos + seq_len).
        q_idx = torch.arange(start_pos, start_pos + seq_len, device=device)[:, None]
        k_idx = torch.arange(start_pos + seq_len, device=device)[None, :]
        return (k_idx <= q_idx)[None, None, :, :]  # True = attend

    def forward(self, input_ids, cache):
        # input_ids: (B, L) token ids. Mutates cache: advances its absolute
        # position, and each block appends its own KV / linear states.
        _, seq_len = input_ids.shape
        x = self.embed_tokens(input_ids)
        start_pos = cache.position
        cache.position += seq_len  # advance before running the blocks
        mask = self._create_mask(x.device, start_pos, seq_len)
        for block in self.layers:
            x = block(x, mask, self.cos, self.sin, start_pos, cache)
        return self.lm_head(self.norm(x))
def load_weights_into_model(model, model_dir: Path):
    """Load all *.safetensors shards from model_dir into model.

    Checkpoint keys are stripped of the "model.language_model." / "model."
    prefixes; keys with no matching parameter (e.g. a visual encoder) are
    skipped. Meta-device parameters are materialized via
    ``load_state_dict(..., assign=True)``. If the checkpoint provides no
    lm_head weight, it is tied to embed_tokens.
    """
    state_dict = {}

    def strip_prefix(key):
        # Longest prefix listed first so "model.language_model." wins.
        for prefix in ("model.language_model.", "model."):
            if key.startswith(prefix):
                return key[len(prefix):]
        return key

    for shard in sorted(model_dir.glob("*.safetensors")):
        with safe_open(shard, framework="pt", device="cpu") as sf:
            for raw_key in sf.keys():
                key = strip_prefix(raw_key)
                try:
                    param = model.get_parameter(key)
                except AttributeError:
                    continue  # skip visual encoder / unused keys
                tensor = sf.get_tensor(raw_key)
                assert param.shape == tensor.shape, f"shape mismatch: {key}"
                state_dict[key] = tensor
    # Bug fix: tie embed_tokens -> lm_head ONLY when the checkpoint did not
    # ship its own lm_head weight (the original overwrote it unconditionally,
    # clobbering untied heads).
    state_dict.setdefault("lm_head.weight", state_dict["embed_tokens.weight"])
    model.load_state_dict(state_dict, assign=True)
class Qwen3_5Tokenizer:
    """Minimal chat-template tokenizer around a HF ``tokenizers`` file.

    Special tokens are split out of the text with a regex so they map to
    their reserved ids instead of being BPE-encoded. ``encode`` wraps the
    prompt in the Qwen chat template, optionally opening a <think> block.
    """

    # Special tokens routed around the BPE model (only those present in the
    # tokenizer file are kept — see __init__).
    _SPECIALS = [
        "<|endoftext|>",
        "<|im_start|>",
        "<|im_end|>",
        "<|object_ref_start|>",
        "<|object_ref_end|>",
        "<|box_start|>",
        "<|box_end|>",
        "<|quad_start|>",
        "<|quad_end|>",
        "<|vision_start|>",
        "<|vision_end|>",
        "<|vision_pad|>",
        "<|image_pad|>",
        "<|video_pad|>",
        "<think>",
        "</think>",
    ]
    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

    def __init__(self, tok_file, thinking=True):
        self._tok = Tokenizer.from_file(str(tok_file))
        # Map each known special token to its id, skipping absent ones.
        self._sid = {
            t: self._tok.token_to_id(t)
            for t in self._SPECIALS
            if self._tok.token_to_id(t) is not None
        }
        self.eos_token_id = self._sid.get("<|im_end|>", self._sid["<|endoftext|>"])
        self.thinking = thinking

    def encode(self, text):
        """Return token ids for `text` wrapped in the chat template."""
        wrapped = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
        wrapped += "<think>\n" if self.thinking else "<think>\n\n</think>\n\n"
        ids = []
        for part in filter(None, self._SPLIT_RE.split(wrapped)):
            # Idiom fix: plain if/else instead of the original parenthesized
            # conditional expression used purely for its side effects.
            if part in self._sid:
                ids.append(self._sid[part])
            else:
                ids.extend(self._tok.encode(part).ids)
        return ids

    def decode(self, ids):
        return self._tok.decode(ids, skip_special_tokens=False)
def generate(model, token_ids, max_new_tokens, eos_id, device):
    """Greedy decoding generator: yields one (B, 1) token tensor per step.

    Prefills the whole prompt once, then feeds each argmax token back
    until `max_new_tokens` tokens were produced or every sequence in the
    batch emitted `eos_id` (the eos token itself is still yielded).
    """
    model.eval()
    cache = KVCache()
    ids = token_ids.to(device)
    with torch.no_grad():
        logits = model(ids, cache)  # prefill
        next_tok = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
        yield next_tok
        remaining = max_new_tokens - 1
        while remaining > 0:
            if eos_id is not None and torch.all(next_tok == eos_id):
                break
            logits = model(next_tok, cache)  # single-token decode step
            next_tok = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
            yield next_tok
            remaining -= 1
def main():
    """CLI entry point: download the checkpoint, load it on CPU, and stream
    a greedy completion of the prompt (argv or a default) to stdout."""
    if len(sys.argv) > 1:
        prompt = " ".join(sys.argv[1:])
    else:
        prompt = "Give me a short introduction to large language models."
    print("Downloading model...")
    local_dir = Path(REPO_ID).parts[-1]
    repo_dir = Path(snapshot_download(repo_id=REPO_ID, local_dir=str(local_dir)))
    print("Loading model...")
    model = Qwen3_5Model()
    load_weights_into_model(model, repo_dir)
    device = torch.device("cpu")
    model.to(device)
    tokenizer = Qwen3_5Tokenizer(repo_dir / "tokenizer.json", thinking=True)
    print(f"\nPrompt: {prompt}\n" + "─" * 60)
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    token_stream = generate(
        model, input_ids, MAX_NEW_TOKENS, tokenizer.eos_token_id, device
    )
    for tok in token_stream:
        print(tokenizer.decode(tok.squeeze(0).tolist()), end="", flush=True)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment