Last active
March 6, 2026 10:12
-
-
Save scturtle/1930047017717cd7f56d4a2b726aeb20 to your computer and use it in GitHub Desktop.
qwen 3.5
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # pip install torch huggingface_hub safetensors tokenizers transformers | |
| import sys, re, time | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from tokenizers import Tokenizer | |
| from safetensors import safe_open | |
| from huggingface_hub import snapshot_download | |
| # ─────────────────────────────── config ─────────────────────────────── | |
REPO_ID = "Qwen/Qwen3.5-0.8B"  # Hugging Face checkpoint to download
MAX_NEW_TOKENS = 500  # greedy-decode budget per run
# Architecture hyperparameters consumed by the modules below.
CFG = {
    "vocab_size": 248_320,
    "context_length": 262_144,  # max positions the RoPE tables cover
    "emb_dim": 1_024,  # model hidden size
    "n_heads": 8,  # query heads in full-attention layers
    "n_layers": 24,
    "hidden_dim": 3_584,  # MLP intermediate size
    "head_dim": 256,
    "qk_norm": True,
    "n_kv_groups": 2,  # KV heads (grouped-query attention)
    "rope_base": 10_000_000.0,
    "partial_rotary_factor": 0.25,  # only 1/4 of head_dim is rotated
    "rms_norm_eps": 1e-6,
    # Gated DeltaNet (linear attention) dimensions.
    "linear_conv_kernel_dim": 4,
    "linear_key_head_dim": 128,
    "linear_value_head_dim": 128,
    "linear_num_key_heads": 16,
    "linear_num_value_heads": 16,
    "dtype": torch.bfloat16,
    # Repeating pattern: 3 linear-attention layers + 1 full-attention, x6 = 24.
    "layer_types": (["linear_attention"] * 3 + ["full_attention"]) * 6,
}
| # ──────────────────────────────── KVCache ───────────────────────────── | |
class KVCache:
    """Mutable per-run decoding state shared by all layers.

    kv[layer]  -> (keys, values) tensors for full-attention layers,
    lin[layer] -> (conv window, recurrent state) for linear-attention layers;
    both default to (None, None) on first access. `position` counts how many
    tokens the model has consumed so far (used as the RoPE offset).
    """

    def __init__(self):
        def empty_pair():
            return (None, None)

        self.kv = defaultdict(empty_pair)
        self.lin = defaultdict(empty_pair)
        self.position = 0
| # ──────────────────────────── linear attn ───────────────────────────── | |
def l2norm(x, dim=-1, eps=1e-6):
    """Scale x to unit L2 norm along `dim`; eps keeps all-zero rows finite."""
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
):
    """Chunked (parallel-within-chunk) gated delta-rule linear attention.

    Used for the prefill pass. Inputs arrive as [B, L, H, D] and are moved to
    [B, H, L, D] here. `g` is a per-position log decay (negative at the call
    site), `beta` a write strength in (0, 1). Returns (output [B, L, H, Dv]
    in the input dtype, final recurrent state [B, H, Dk, Dv] in float32).
    """
    initial_dtype = query.dtype
    # Queries/keys are L2-normalized along features before the delta rule.
    query, key = l2norm(query), l2norm(key)
    # Heads in front of sequence; do all the math in float32 for stability.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    B, H, L, Dk = key.shape
    Dv = value.shape[-1]
    # Right-pad the sequence so it divides evenly into chunks.
    pad = (chunk_size - L % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad))
    key = F.pad(key, (0, 0, 0, pad))
    value = F.pad(value, (0, 0, 0, pad))
    beta = F.pad(beta, (0, pad))
    g = F.pad(g, (0, pad))
    T = L + pad
    query = query / query.shape[-1] ** 0.5  # 1/sqrt(Dk) attention scaling
    vb = value * beta[..., None]  # beta-weighted values
    kb = key * beta[..., None]  # beta-weighted keys
    # [B, H, T, D] -> [B, H, T/chunk, chunk, D]
    query, key, value, kb, vb = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, kb, vb)
    ]
    # Cumulative log-decay within each chunk; dm[..., i, j] = exp(g_i - g_j)
    # on the lower triangle (j <= i), zero above it.
    g = g.reshape(B, H, -1, chunk_size).cumsum(-1)
    dm = ((g[..., None] - g[..., None, :]).tril().exp()).tril()
    up = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device)
    )
    # Intra-chunk mixing matrix, completed by forward substitution below
    # (row i folds in the already-solved rows < i).
    A = -(kb @ key.transpose(-1, -2) * dm).masked_fill(up, 0)
    for i in range(1, chunk_size):
        r = A[..., i, :i].clone()
        s = A[..., :i, :i].clone()
        A[..., i, :i] = r + (r[..., None] * s).sum(-2)
    A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
    value = A @ vb  # mixed ("pseudo") values after the intra-chunk solve
    kcd = A @ (kb * g.exp()[..., None])  # decay-weighted keys for state reads
    state = (
        torch.zeros(B, H, Dk, Dv, device=query.device, dtype=value.dtype)
        if initial_state is None
        else initial_state.to(value)
    )
    out = torch.zeros_like(value)
    up2 = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 1
    )
    # Sequential scan over chunks: intra-chunk attention plus reads of the
    # carried inter-chunk state.
    for i in range(T // chunk_size):
        qi, ki, vi = query[:, :, i], key[:, :, i], value[:, :, i]
        a = (qi @ ki.transpose(-1, -2) * dm[:, :, i]).masked_fill_(up2, 0)
        vp = kcd[:, :, i] @ state  # state's prediction of this chunk's values
        vn = vi - vp  # delta-rule residual actually written to memory
        out[:, :, i] = (qi * g[:, :, i, :, None].exp()) @ state + a @ vn
        # Decay state to the end of the chunk, then add the residual writes
        # (each key decayed from its own position to the chunk end).
        state = (
            state * g[:, :, i, -1, None, None].exp()
            + (ki * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
                -1, -2
            )
            @ vn
        )
    out = out.reshape(B, H, -1, Dv)[:, :, :L]  # drop the padded tail
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
def torch_recurrent_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    initial_state,
):
    """Step-by-step gated delta rule, used for single-token decode (L == 1).

    Same input/output contract as torch_chunk_gated_delta_rule, but computed
    as a plain recurrence over the L positions.
    """
    initial_dtype = query.dtype
    query, key = l2norm(query), l2norm(key)
    # Heads in front of sequence; compute in float32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]
    B, H, L, Dk = key.shape
    Dv = value.shape[-1]
    query = query / query.shape[-1] ** 0.5  # 1/sqrt(Dk) scaling
    state = (
        torch.zeros(B, H, Dk, Dv, device=query.device, dtype=value.dtype)
        if initial_state is None
        else initial_state.to(value)
    )
    out = torch.zeros(B, H, L, Dv, device=query.device, dtype=value.dtype)
    for i in range(L):
        # Decay the memory, read its prediction for key_i, then write back the
        # beta-scaled prediction error (the delta rule), and read out with q_i.
        state = state * g[:, :, i].exp()[..., None, None]
        kv_mem = (state * key[:, :, i].unsqueeze(-1)).sum(-2)
        delta = (value[:, :, i] - kv_mem) * beta[:, :, i].unsqueeze(-1)
        state = state + key[:, :, i].unsqueeze(-1) * delta.unsqueeze(-2)
        out[:, :, i] = (state * query[:, :, i].unsqueeze(-1)).sum(-2)
    return out.transpose(1, 2).contiguous().to(initial_dtype), state
| # ─────────────────────────── model layers ───────────────────────────── | |
class RMSNorm(nn.Module):
    """RMS normalization with a zero-centered gain: y = rmsnorm(x) * (1 + w).

    The weight is initialized to zero (identity scale); checkpoint loading
    replaces it, so the stored parameter is an offset from 1.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        fx = x.float()
        inv_rms = torch.rsqrt(fx.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = fx * inv_rms
        return (normed * (1.0 + self.weight.float())).to(x.dtype)
class Qwen3_5RMSNormGated(nn.Module):
    """RMSNorm followed by an optional SiLU gate: y = (w * rmsnorm(x)) * silu(gate).

    Bug fix: the signature advertised `gate=None`, but the body called
    `gate.float()` unconditionally and crashed on the default. A None gate now
    skips the gating and returns the plain weighted norm.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.variance_epsilon = eps

    def forward(self, x, gate=None):
        """x: [..., dim]; gate: same shape as x, or None to skip gating."""
        dtype = x.dtype
        h = x.float()
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        h = self.weight * h.to(dtype)
        if gate is None:
            return h
        # SiLU gate computed in float32, result cast back to the input dtype.
        return (h * F.silu(gate.float())).to(dtype)
def compute_rope_params(
    head_dim, theta_base, context_length, partial_rotary_factor=1.0
):
    """Precompute RoPE tables.

    Returns (cos, sin), each of shape [context_length, rot], where
    rot = head_dim * partial_rotary_factor rounded down to an even number
    (minimum 2). The angle row is the half-frequency vector repeated twice.
    """
    rot = int(head_dim * partial_rotary_factor)
    rot = max(2, rot - rot % 2)
    freq_index = torch.arange(0, rot, 2).float()
    inv_freq = 1.0 / (theta_base ** (freq_index / rot))
    positions = torch.arange(context_length).float()
    half_angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat((half_angles, half_angles), dim=1)
    return torch.cos(angles), torch.sin(angles)
def apply_rope(x, cos, sin, offset=0):
    """Apply rotary embeddings to the first cos.shape[-1] feature dims of x.

    x: [B, H, L, D]; cos/sin: [context, rot] tables from compute_rope_params.
    Only the leading `rot` features are rotated; the rest pass through.
    `offset` is the absolute position of x's first token.
    """
    seq_len = x.shape[2]
    rot = cos.shape[-1]
    rotary_part, passthrough = x[..., :rot], x[..., rot:]
    half = rot // 2
    lo, hi = rotary_part[..., :half], rotary_part[..., half:]
    cos_t = cos[offset : offset + seq_len][None, None]
    sin_t = sin[offset : offset + seq_len][None, None]
    swapped = torch.cat((-hi, lo), dim=-1)
    rotated = swapped * sin_t + rotary_part * cos_t
    return torch.cat((rotated, passthrough), dim=-1).to(x.dtype)
# Renamed to match checkpoint: q_proj / k_proj / v_proj / o_proj
class GroupedQueryAttention(nn.Module):
    """Full softmax attention with grouped KV heads, QK RMSNorm, partial RoPE
    and a per-channel sigmoid output gate packed into q_proj's second half."""

    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx  # key into cache.kv
        nh, nkv, hd, d, dt = (
            CFG["n_heads"],
            CFG["n_kv_groups"],
            CFG["head_dim"],
            CFG["emb_dim"],
            CFG["dtype"],
        )
        self.num_heads, self.num_kv_groups, self.head_dim = nh, nkv, hd
        self.d_out = nh * hd
        # q_proj is 2× head_dim (query + gate) to match checkpoint shape
        self.q_proj = nn.Linear(d, self.d_out * 2, bias=False, dtype=dt, device="meta")
        self.k_proj = nn.Linear(d, nkv * hd, bias=False, dtype=dt, device="meta")
        self.v_proj = nn.Linear(d, nkv * hd, bias=False, dtype=dt, device="meta")
        self.o_proj = nn.Linear(self.d_out, d, bias=False, dtype=dt, device="meta")
        self.q_norm = RMSNorm(hd)  # per-head norm on queries
        self.k_norm = RMSNorm(hd)  # per-head norm on keys

    def forward(self, x, mask, cos, sin, start_pos, cache):
        """x: [B, L, emb_dim]; mask: boolean [1, 1, L, start_pos+L], True = attend.
        Appends this step's K/V to cache.kv[layer_idx]; returns [B, L, emb_dim]."""
        B, L, _ = x.shape
        # Split the doubled projection into per-head query and gate halves.
        qg = self.q_proj(x).view(B, L, self.num_heads, self.head_dim * 2)
        queries = qg[..., : self.head_dim].transpose(1, 2)
        gate = qg[..., self.head_dim :].reshape(B, L, self.d_out)
        keys = (
            self.k_proj(x).view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        )
        values = (
            self.v_proj(x).view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        )
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        # RoPE offset by this chunk's absolute start position.
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys = apply_rope(keys, cos, sin, offset=start_pos)
        # Grow the per-layer KV cache along the sequence dimension.
        prev_k, prev_v = cache.kv[self.layer_idx]
        if prev_k is not None:
            keys = torch.cat([prev_k, keys], dim=2)
            values = torch.cat([prev_v, values], dim=2)
        cache.kv[self.layer_idx] = (keys, values)
        ctx = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=mask,
            scale=self.head_dim**-0.5,
            enable_gqa=True,  # broadcasts nkv KV heads across nh query heads
        )
        # Sigmoid gate modulates the attention output channel-wise.
        ctx = ctx.transpose(1, 2).reshape(B, L, self.d_out) * torch.sigmoid(gate)
        return self.o_proj(ctx)
def causal_conv1d_update(x, state, weight):
    """Single-step depthwise causal convolution with a rolling window.

    x: [B, C, L] new frames; state: [B, C, K] window, updated IN PLACE to the
    last K frames; weight: [C, 1, K] depthwise kernel. Returns silu(conv) in
    x's dtype. NOTE(review): not referenced by the current forward paths,
    which inline the same logic — kept for API parity.
    """
    window = torch.cat([state, x.to(weight.dtype)], dim=-1)
    state.copy_(window[:, :, -state.shape[-1] :])
    conv_out = F.conv1d(window, weight, padding=0, groups=x.shape[1])
    return F.silu(conv_out).to(x.dtype)
class GatedDeltaNet(nn.Module):
    """Linear-attention block: a short depthwise causal conv over the fused
    Q/K/V projection, a gated delta-rule recurrence, then a SiLU-gated RMSNorm
    and output projection back to the model dimension."""

    def __init__(self):
        super().__init__()
        Hv, Hk = CFG["linear_num_value_heads"], CFG["linear_num_key_heads"]
        Dk, Dv = CFG["linear_key_head_dim"], CFG["linear_value_head_dim"]
        ks, d = CFG["linear_conv_kernel_dim"], CFG["emb_dim"]
        self.num_v_heads, self.num_k_heads = Hv, Hk
        self.head_k_dim, self.head_v_dim = Dk, Dv
        self.key_dim = Dk * Hk
        self.value_dim = Dv * Hv
        self.conv_kernel_size = ks
        # Fused channel layout: [q (key_dim) | k (key_dim) | v (value_dim)].
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            self.conv_dim,
            self.conv_dim,
            bias=False,
            kernel_size=ks,
            groups=self.conv_dim,  # depthwise: one kernel per channel
            padding=ks - 1,  # causal left padding
        )
        self.dt_bias = nn.Parameter(torch.ones(Hv))  # softplus shift for decay
        self.A_log = nn.Parameter(torch.log(torch.empty(Hv).uniform_(0, 16)))
        self.norm = Qwen3_5RMSNormGated(Dv, eps=CFG["rms_norm_eps"])
        self.out_proj = nn.Linear(self.value_dim, d, bias=False)
        self.in_proj_qkv = nn.Linear(d, self.conv_dim, bias=False)
        self.in_proj_z = nn.Linear(d, self.value_dim, bias=False)  # output gate
        self.in_proj_b = nn.Linear(d, Hv, bias=False)  # beta (write strength)
        self.in_proj_a = nn.Linear(d, Hv, bias=False)  # decay input
        self.to(dtype=CFG["dtype"])

    def forward(self, hidden_states, layer_idx, cache):
        """hidden_states: [B, L, emb_dim]. Updates cache.lin[layer_idx] =
        (conv window [B, conv_dim, K], recurrent state); returns [B, L, emb_dim].
        NOTE(review): the decode branch (L == 1) dereferences conv_state, so a
        prefill call must have populated the cache first — confirm in callers."""
        B, L, _ = hidden_states.shape
        conv_state, rec_state = cache.lin[layer_idx]
        qkv_raw = self.in_proj_qkv(hidden_states)  # [B, L, conv_dim]
        qkv_raw = qkv_raw.transpose(1, 2)  # [B, conv_dim, L]
        is_decode = L == 1
        K = self.conv_kernel_size
        if is_decode:
            # conv_state: [B, conv_dim, K]
            # Slide the window by one and convolve only the newest frame.
            state = torch.cat([conv_state[:, :, 1:], qkv_raw], dim=-1)
            mixed_qkv = F.silu(
                F.conv1d(state, self.conv1d.weight, padding=0, groups=self.conv_dim)
            ).transpose(1, 2)
            new_conv_state = state
        else:
            # Prefill: left-padded causal conv over the full sequence; keep the
            # last K raw frames as the next decode window.
            mixed_qkv = F.silu(
                F.conv1d(
                    qkv_raw,
                    self.conv1d.weight,
                    padding=K - 1,
                    groups=self.conv_dim,
                )[:, :, :L]
            ).transpose(1, 2)
            new_conv_state = F.pad(qkv_raw, (max(0, K - L), 0))[:, :, -K:]
        z = self.in_proj_z(hidden_states).reshape(B, L, -1, self.head_v_dim)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)
        q, k, v = torch.split(
            mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
        )
        q = q.reshape(B, L, -1, self.head_k_dim)
        k = k.reshape(B, L, -1, self.head_k_dim)
        v = v.reshape(B, L, -1, self.head_v_dim)
        beta = b.sigmoid()  # write strength in (0, 1)
        # Log decay g < 0: data-dependent softplus scaled by exp(A_log).
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Broadcast key heads across value heads when there are fewer of them.
        ratio = self.num_v_heads // self.num_k_heads
        if ratio > 1:
            q = q.repeat_interleave(ratio, dim=2)
            k = k.repeat_interleave(ratio, dim=2)
        fn = (
            torch_recurrent_gated_delta_rule
            if is_decode
            else torch_chunk_gated_delta_rule
        )
        core_out, new_rec_state = fn(q, k, v, g=g, beta=beta, initial_state=rec_state)
        cache.lin[layer_idx] = (new_conv_state, new_rec_state)
        # Gated RMSNorm applied per value head, then project back to emb_dim.
        core_out = self.norm(
            core_out.reshape(-1, self.head_v_dim), z.reshape(-1, self.head_v_dim)
        ).reshape(B, L, -1)
        return self.out_proj(core_out)
class FeedForward(nn.Module):
    """SwiGLU MLP: down_proj( silu(gate_proj(x)) * up_proj(x) )."""

    def __init__(self):
        super().__init__()
        dim, hidden, dtype = CFG["emb_dim"], CFG["hidden_dim"], CFG["dtype"]
        proj_kwargs = dict(bias=False, dtype=dtype, device="meta")
        self.gate_proj = nn.Linear(dim, hidden, **proj_kwargs)
        self.up_proj = nn.Linear(dim, hidden, **proj_kwargs)
        self.down_proj = nn.Linear(hidden, dim, **proj_kwargs)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class TransformerBlock(nn.Module):
    """One decoder layer: pre-norm residual attention (full softmax or linear
    DeltaNet, chosen by layer_type) followed by a pre-norm residual MLP."""

    def __init__(self, layer_type, layer_idx):
        super().__init__()
        self.layer_type = layer_type
        self.layer_idx = layer_idx
        if layer_type == "full_attention":
            self.self_attn = GroupedQueryAttention(layer_idx)
        else:
            self.linear_attn = GatedDeltaNet()
        self.mlp = FeedForward()
        eps = CFG["rms_norm_eps"]
        self.input_layernorm = RMSNorm(CFG["emb_dim"], eps)
        self.post_attention_layernorm = RMSNorm(CFG["emb_dim"], eps)

    def forward(self, x, mask, cos, sin, start_pos, cache):
        normed = self.input_layernorm(x)
        if self.layer_type == "full_attention":
            attn_out = self.self_attn(normed, mask, cos, sin, start_pos, cache)
        else:
            attn_out = self.linear_attn(normed, self.layer_idx, cache)
        residual = x + attn_out
        return residual + self.mlp(self.post_attention_layernorm(residual))
class Qwen3_5Model(nn.Module):
    """Qwen3.5 decoder stack: token embedding, mixed linear/full-attention
    layers, final RMSNorm and LM head. Weight tensors are created on the meta
    device and materialized later by load_weights_into_model."""

    def __init__(self):
        super().__init__()
        dtype = CFG["dtype"]
        self.embed_tokens = nn.Embedding(
            CFG["vocab_size"], CFG["emb_dim"], dtype=dtype, device="meta"
        )
        self.layers = nn.ModuleList(
            TransformerBlock(kind, idx)
            for idx, kind in enumerate(CFG["layer_types"])
        )
        self.norm = RMSNorm(CFG["emb_dim"], CFG["rms_norm_eps"])
        self.lm_head = nn.Linear(
            CFG["emb_dim"],
            CFG["vocab_size"],
            bias=False,
            dtype=dtype,
            device="meta",
        )
        # RoPE tables are precomputed once for the whole context window.
        cos, sin = compute_rope_params(
            head_dim=CFG["head_dim"],
            theta_base=CFG["rope_base"],
            context_length=CFG["context_length"],
            partial_rotary_factor=CFG["partial_rotary_factor"],
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _create_mask(self, device, start_pos, seq_len):
        """Boolean causal mask [1, 1, seq_len, start_pos + seq_len]; True = attend."""
        total = start_pos + seq_len
        q_pos = torch.arange(start_pos, total, device=device)[:, None]
        k_pos = torch.arange(total, device=device)[None, :]
        return (k_pos <= q_pos)[None, None, :, :]

    def forward(self, input_ids, cache):
        """input_ids: [B, L] token ids; cache tracks the absolute position."""
        seq_len = input_ids.shape[1]
        hidden = self.embed_tokens(input_ids)
        start_pos = cache.position
        cache.position += seq_len
        mask = self._create_mask(hidden.device, start_pos, seq_len)
        for layer in self.layers:
            hidden = layer(hidden, mask, self.cos, self.sin, start_pos, cache)
        return self.lm_head(self.norm(hidden))
def load_weights_into_model(model, model_dir: Path):
    """Load every *.safetensors shard in model_dir into `model`.

    Checkpoint keys are matched after stripping the HF wrapper prefixes;
    tensors whose stripped key does not name a parameter of `model` are
    skipped (e.g. a multimodal checkpoint's visual encoder). If the checkpoint
    provides no lm_head.weight, the embedding matrix is reused (weight tying).

    Bug fix: the original unconditionally overwrote lm_head.weight with the
    embedding even when the checkpoint shipped its own head; setdefault only
    ties the weights when the head is actually absent.
    """
    state_dict = {}

    def strip_prefix(key):
        # HF stores weights under "model." (multimodal: "model.language_model.");
        # our module tree is rooted one level down.
        for pfx in ("model.language_model.", "model."):
            if key.startswith(pfx):
                return key[len(pfx):]
        return key

    for f in sorted(model_dir.glob("*.safetensors")):
        with safe_open(f, framework="pt", device="cpu") as sf:
            for k in sf.keys():
                key = strip_prefix(k)
                try:
                    param = model.get_parameter(key)
                except AttributeError:
                    continue  # skip visual encoder / unused keys
                t = sf.get_tensor(k)
                assert param.shape == t.shape, f"shape mismatch: {key}"
                state_dict[key] = t
    # share embed_tokens ↔ lm_head ONLY if lm_head is absent
    state_dict.setdefault("lm_head.weight", state_dict["embed_tokens.weight"])
    model.load_state_dict(state_dict, assign=True)
class Qwen3_5Tokenizer:
    """Minimal ChatML tokenizer wrapper around a `tokenizers` vocabulary.

    Special tokens are routed straight to their ids; everything else goes
    through the underlying BPE encoder. `thinking` controls whether the
    assistant turn opens a <think> block or closes it immediately.
    """

    _SPECIALS = [
        "<|endoftext|>",
        "<|im_start|>",
        "<|im_end|>",
        "<|object_ref_start|>",
        "<|object_ref_end|>",
        "<|box_start|>",
        "<|box_end|>",
        "<|quad_start|>",
        "<|quad_end|>",
        "<|vision_start|>",
        "<|vision_end|>",
        "<|vision_pad|>",
        "<|image_pad|>",
        "<|video_pad|>",
        "<think>",
        "</think>",
    ]
    # Captures special tokens so str.split keeps them as their own parts.
    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

    def __init__(self, tok_file, thinking=True):
        self._tok = Tokenizer.from_file(str(tok_file))
        self._sid = {
            t: self._tok.token_to_id(t)
            for t in self._SPECIALS
            if self._tok.token_to_id(t) is not None
        }
        # Prefer <|im_end|> as EOS, falling back to <|endoftext|>. Lazy lookup
        # (the original's dict.get evaluated the fallback eagerly and raised
        # KeyError even when <|im_end|> was present).
        if "<|im_end|>" in self._sid:
            self.eos_token_id = self._sid["<|im_end|>"]
        else:
            self.eos_token_id = self._sid["<|endoftext|>"]
        self.thinking = thinking

    def encode(self, text):
        """Wrap `text` in the ChatML user/assistant template and return ids."""
        wrapped = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
        wrapped += "<think>\n" if self.thinking else "<think>\n\n</think>\n\n"
        ids = []
        for part in filter(None, self._SPLIT_RE.split(wrapped)):
            if part in self._sid:
                ids.append(self._sid[part])  # special token -> single id
            else:
                ids.extend(self._tok.encode(part).ids)
        return ids

    def decode(self, ids):
        """Decode ids back to text, keeping special tokens visible."""
        return self._tok.decode(ids, skip_special_tokens=False)
def generate(model, token_ids, max_new_tokens, eos_id, device):
    """Greedy decoding generator.

    Prefills the prompt once, then yields one [B, 1] token tensor per step
    until max_new_tokens is reached or every sequence emitted eos_id.
    """
    model.eval()
    cache = KVCache()
    with torch.no_grad():
        # Prefill the whole prompt in one forward pass.
        logits = model(token_ids.to(device), cache)
        tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        yield tok
        # Decode one token at a time against the growing cache.
        for _ in range(max_new_tokens - 1):
            if eos_id is not None and bool((tok == eos_id).all()):
                break
            logits = model(tok, cache)
            tok = logits[:, -1].argmax(dim=-1, keepdim=True)
            yield tok
def main():
    """Download the checkpoint, load it on CPU, and stream a greedy completion
    for the prompt given on the command line (or a default one)."""
    if len(sys.argv) > 1:
        prompt = " ".join(sys.argv[1:])
    else:
        prompt = "Give me a short introduction to large language models."
    print("Downloading model...")
    local_dir = Path(REPO_ID).parts[-1]
    repo_dir = Path(snapshot_download(repo_id=REPO_ID, local_dir=str(local_dir)))
    print("Loading model...")
    model = Qwen3_5Model()
    load_weights_into_model(model, repo_dir)
    device = torch.device("cpu")
    model.to(device)
    tokenizer = Qwen3_5Tokenizer(repo_dir / "tokenizer.json", thinking=True)
    print(f"\nPrompt: {prompt}\n" + "─" * 60)
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    token_stream = generate(
        model, input_ids, MAX_NEW_TOKENS, tokenizer.eos_token_id, device
    )
    for tok in token_stream:
        print(tokenizer.decode(tok.squeeze(0).tolist()), end="", flush=True)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment