gemma3n
from dataclasses import dataclass
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
from safetensors import safe_open
from tokenizers import Tokenizer
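
# Single-file, eager-mode reimplementation of the Gemma 3n text decoder.
# The hyperparameters below appear to match the text config of
# google/gemma-3n-E4B-it: 35 decoder layers in a repeating pattern of four
# sliding-window layers followed by one full-attention layer, activation
# sparsity in the first 10 layers only, and KV sharing in the last 15 layers.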
@dataclass
class config:
    activation_sparsity_pattern = [0.95] * 10 + [0.0] * 25
    altup_active_idx = 0
    altup_coef_clip = 120.0
    altup_correct_scale = True
    altup_lr_multiplier = 1.0
    altup_num_inputs = 4
    final_logit_softcapping = 30.0
    head_dim = 256
    hidden_size = 2048
    hidden_size_per_layer_input = 256
    intermediate_size = 16384
    laurel_rank = 64
    layer_types = (["sliding_attention"] * 4 + ["full_attention"]) * 7
    max_position_embeddings = 32768
    num_attention_heads = 8
    num_hidden_layers = 35
    num_key_value_heads = 2
    num_kv_shared_layers = 15
    rms_norm_eps = 1e-06
    rope_local_base_freq = 10000.0
    rope_theta = 1000000.0
    sliding_window = 512
    vocab_size = 262400
    vocab_size_per_layer_input = 262144
    _attn_implementation = "eager"
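
# RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight, computed in fp32 and cast
# back to the input dtype. With with_scale=False the weight is a constant 1.0
# buffer (used below for the value-head norm in attention).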
class Gemma3nRMSNorm(nn.Module):
    def __init__(self, dim: int, with_scale: bool = True):
        super().__init__()
        self.eps = config.rms_norm_eps
        self.with_scale = with_scale
        if self.with_scale:
            self.weight = nn.Parameter(torch.ones(dim, dtype=torch.bfloat16))
        else:
            self.register_buffer(
                "weight", torch.tensor(1.0, dtype=torch.bfloat16), persistent=False
            )

    def _norm(self, x):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()) * self.weight.float()
        return output.type_as(x)
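
# Token embeddings are scaled by sqrt(hidden_size) (and the per-layer
# embeddings by sqrt(hidden_size_per_layer_input)). The large tables are
# created on the "meta" device and materialized later by load_model via
# load_state_dict(..., assign=True).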
class Gemma3nTextScaledWordEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, padding_idx, embed_scale):
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            dtype=torch.bfloat16,
            device="meta",
        )
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
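
# LAUREL: a learned low-rank residual (hidden_size -> laurel_rank ->
# hidden_size), RMS-normed and added back onto its input.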
class Gemma3nTextLaurelBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_left = nn.Linear(
            config.hidden_size, config.laurel_rank, dtype=torch.bfloat16, bias=False
        )
        self.linear_right = nn.Linear(
            config.laurel_rank, config.hidden_size, dtype=torch.bfloat16, bias=False
        )
        self.post_laurel_norm = Gemma3nRMSNorm(config.hidden_size)

    def forward(self, hidden_states):
        laurel_hidden_states = self.linear_left(hidden_states)
        laurel_hidden_states = self.linear_right(laurel_hidden_states)
        normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
        return hidden_states + normed_laurel_hidden_states
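
# Gated MLP (gate/up/down) with tanh-approximated GELU. For layers with a
# nonzero activation_sparsity target p, _gaussian_topk keeps roughly the top
# (1 - p) fraction of gate activations per row: the cutoff is
# mean + std * Phi^{-1}(p), and everything below it is zeroed by
# relu(inputs - cutoff). For p = 0.95 that is about 1.645 std above the mean.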
class Gemma3nTextMLP(nn.Module):
    def __init__(self, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.up_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.down_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
        self.act_fn = nn.GELU("tanh")

    def forward(self, hidden_states):
        gate_proj = self.gate_proj(hidden_states)
        if self.activation_sparsity > 0.0:
            gate_proj = self._gaussian_topk(gate_proj)
        activations = self.act_fn(gate_proj)
        up_proj = self.up_proj(hidden_states)
        down_proj = self.down_proj(activations * up_proj)
        return down_proj

    def _gaussian_topk(self, inputs):
        target_sparsity_tensor = torch.tensor(
            self.activation_sparsity, dtype=torch.float32, device=inputs.device
        )
        normal_dist = torch.distributions.normal.Normal(0, 1)
        std_multiplier = normal_dist.icdf(target_sparsity_tensor)
        std_multiplier = std_multiplier.type(inputs.dtype)
        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = inputs_mean + inputs_std * std_multiplier
        return nn.functional.relu(inputs - cutoff_x)
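
# AltUp maintains altup_num_inputs (4) parallel copies of the hidden state.
# predict() mixes the copies with coefficients derived from a small router and
# adds the result as a residual; correct() computes the "innovation"
# (activated output minus the active prediction) and feeds it back into every
# copy with learned, clipped coefficients; scale_corrected_output() rescales
# the active copy before the per-layer-input pathway in the decoder layer.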
class Gemma3nTextAltUp(nn.Module):
    def __init__(self):
        super().__init__()
        self.correct_output_scale = nn.Parameter(
            torch.zeros(config.hidden_size, dtype=torch.bfloat16)
        )
        self.correction_coefs = nn.Linear(
            config.altup_num_inputs,
            config.altup_num_inputs,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.prediction_coefs = nn.Linear(
            config.altup_num_inputs,
            config.altup_num_inputs**2,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.modality_router = nn.Linear(
            config.hidden_size,
            config.altup_num_inputs,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.router_norm = Gemma3nRMSNorm(config.hidden_size)
        self.register_buffer(
            "router_input_scale",
            torch.tensor(config.hidden_size**-1.0, dtype=torch.bfloat16),
            persistent=False,
        )

    def compute_router_modalities(self, x):
        router_inputs = self.router_norm(x) * self.router_input_scale
        routed = self.modality_router(router_inputs)
        return torch.tanh(routed.float()).type_as(x)

    def predict(self, hidden_states):
        modalities = self.compute_router_modalities(
            hidden_states[config.altup_active_idx]
        )
        all_coefs = (
            self.prediction_coefs(modalities)
            .reshape(
                *modalities.shape[:-1], config.altup_num_inputs, config.altup_num_inputs
            )
            .permute(0, 1, 3, 2)
        )
        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
        predictions = predictions.permute(3, 0, 1, 2)
        predictions += hidden_states
        return predictions.contiguous().type_as(hidden_states)

    def correct(self, predictions, activated):
        modalities = self.compute_router_modalities(activated)
        innovation = activated - predictions[config.altup_active_idx]
        innovation = innovation.repeat(config.altup_num_inputs, 1, 1, 1)
        if config.altup_coef_clip is not None:
            self.correction_coefs.weight.data.clamp_(
                -config.altup_coef_clip, config.altup_coef_clip
            )
        all_coefs = self.correction_coefs(modalities) + 1.0
        all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions
        return corrected.contiguous().type_as(activated)

    def scale_corrected_output(self, corrected):
        return (
            corrected.type_as(self.correct_output_scale) * self.correct_output_scale
        ).type_as(corrected)
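
# Rotary position embeddings. Two tables are built: a "local" one
# (base 10_000) for sliding-window layers and a "global" one (base 1_000_000)
# for full-attention layers. apply_rotary_pos_emb expects
# (batch, seq, heads, head_dim) inputs, hence unsqueeze_dim=2.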
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=2):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)


class Gemma3nTextRotaryEmbedding(nn.Module):
    def __init__(self, is_local=False):
        super().__init__()
        base = config.rope_local_base_freq if is_local else config.rope_theta
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, config.head_dim, 2, dtype=torch.int64).to(
                    dtype=torch.float
                )
                / config.head_dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
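
# Attention: q and k get a per-head RMSNorm plus RoPE, v gets a scale-free
# RMSNorm. SDPA runs with scale=1.0 (no 1/sqrt(head_dim)), which appears to
# follow the reference Gemma 3n implementation, and enable_gqa (recent
# PyTorch) broadcasts the 2 KV heads over the 8 query heads. The last
# num_kv_shared_layers layers compute no K/V of their own: they reuse the
# current-step K/V of the nearest earlier layer of the same attention type
# (via layer_cache) and that layer's past K/V from the KVCache.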
class Gemma3nTextAttention(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None
        self.head_dim = config.head_dim
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
            device="meta",
        )
        self.q_norm = Gemma3nRMSNorm(dim=config.head_dim)
        self.k_norm = Gemma3nRMSNorm(dim=config.head_dim)
        self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, with_scale=False)
        first_kv_shared_layer_idx = (
            config.num_hidden_layers - config.num_kv_shared_layers
        )
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
        layer_type = config.layer_types[layer_idx]
        self.kv_shared_layer_index = (
            first_kv_shared_layer_idx
            - 1
            - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
            if self.is_kv_shared_layer
            else None
        )

    def forward(
        self,
        hidden_states,
        mask,
        position_embeddings,
        cache=None,
        layer_cache=None,
    ):
        bsz, q_len, _ = hidden_states.shape
        hidden_shape = (bsz, q_len, -1, config.head_dim)
        cos, sin = position_embeddings
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_rotary_pos_emb(query_states, cos, sin)
        query_states = query_states.transpose(1, 2)
        if self.is_kv_shared_layer:
            key_states, value_states = layer_cache[self.kv_shared_layer_index]
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin)
            key_states = key_states.transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)
            if layer_cache is not None:
                layer_cache[self.layer_idx] = (key_states, value_states)
        if cache is not None:
            past_keys, past_values = cache.get(
                self.kv_shared_layer_index
                if self.is_kv_shared_layer
                else self.layer_idx
            )
            if past_keys is not None:
                key_states = torch.cat([past_keys, key_states], dim=2)
                value_states = torch.cat([past_values, value_states], dim=2)
            if not self.is_kv_shared_layer:
                cache.update(self.layer_idx, (key_states, value_states))
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=mask,  # or is_causal=q_len > 1,
            scale=1.0,
            enable_gqa=True,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output
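
# Decoder layer flow: AltUp predict -> RMSNorm -> attention and LAUREL on the
# active copy -> merge (attn + laurel) / sqrt(2) -> MLP with pre/post norms ->
# AltUp correct -> the corrected active copy is gated against this layer's
# per-layer input embedding and the result is added to the other copies.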
class Gemma3nTextDecoderLayer(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = nn.GELU("tanh")
        self.mlp = Gemma3nTextMLP(layer_idx=layer_idx)
        self.attention_type = config.layer_types[layer_idx]
        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size)
        self.altup = Gemma3nTextAltUp()
        self.laurel = Gemma3nTextLaurelBlock()
        self.self_attn = Gemma3nTextAttention(layer_idx)
        self.per_layer_input_gate = nn.Linear(
            self.hidden_size,
            self.hidden_size_per_layer_input,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.per_layer_projection = nn.Linear(
            self.hidden_size_per_layer_input,
            self.hidden_size,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size)

    def forward(
        self,
        hidden_states,
        mask_global,
        mask_local,
        position_embeddings_global,
        position_embeddings_local,
        per_layer_input,
        cache=None,
        layer_cache=None,
    ):
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[config.altup_active_idx]
        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)
        if self.self_attn.is_sliding:
            mask = mask_local
            position_embeddings = position_embeddings_local
        else:
            mask = mask_global
            position_embeddings = position_embeddings_global
        attn = self.self_attn(
            hidden_states=active_prediction_normed,
            mask=mask,
            position_embeddings=position_embeddings,
            cache=cache,
            layer_cache=layer_cache,
        )
        attn = self.post_attention_layernorm(attn)
        attn_gated = active_prediction + attn
        attn_laurel = (attn_gated + laurel_output) / (2**0.5)
        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
        corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
        first_prediction = corrected_predictions[config.altup_active_idx]
        first_prediction_clone = first_prediction.clone()
        if config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        corrected_predictions[1:] += first_prediction
        return corrected_predictions
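
# Full text model: the input embedding is expanded into 4 AltUp copies
# (projected and magnitude-matched to copy 0), run through all 35 decoder
# layers, then un-projected and averaged before the final norm. _create_masks
# returns boolean (1, 1, q, k) masks: causal for full attention, causal and
# within the 512-token window for sliding attention. epsilon_tensor uses
# torch.finfo().min (a large negative number), which appears to mirror the
# reference code, so the torch.maximum is effectively a no-op guard.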
class Gemma3nTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.padding_idx = 0
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=config.hidden_size**0.5,
        )
        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=config.hidden_size_per_layer_input**0.5,
        )
        self.per_layer_model_projection = nn.Linear(
            self.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            dtype=torch.bfloat16,
            bias=False,
        )
        self.per_layer_projection_norm = Gemma3nRMSNorm(
            config.hidden_size_per_layer_input,
        )
        self.layers = nn.ModuleList(
            [
                Gemma3nTextDecoderLayer(layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma3nRMSNorm(config.hidden_size)
        self.altup_projections = nn.ModuleList(
            [
                nn.Linear(
                    self.hidden_size, self.hidden_size, dtype=torch.bfloat16, bias=False
                )
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.altup_unembed_projections = nn.ModuleList(
            [
                nn.Linear(
                    self.hidden_size, self.hidden_size, dtype=torch.bfloat16, bias=False
                )
                for _ in range(1, config.altup_num_inputs)
            ]
        )
        self.register_buffer(
            "per_layer_projection_scale",
            torch.tensor(self.hidden_size**-0.5),
            persistent=False,
        )
        self.register_buffer(
            "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
        )
        self.rotary_emb = Gemma3nTextRotaryEmbedding()
        self.rotary_emb_local = Gemma3nTextRotaryEmbedding(is_local=True)

    def get_per_layer_inputs(self, input_ids: torch.LongTensor):
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None):
        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection *= self.per_layer_projection_scale.type(
            inputs_embeds.dtype
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
        if per_layer_inputs is None:
            return per_layer_projection
        if per_layer_projection.shape != per_layer_inputs.shape:
            per_layer_inputs = per_layer_inputs[..., : config.num_hidden_layers, :]
        return (
            per_layer_projection + per_layer_inputs
        ) * self.per_layer_input_scale.type(inputs_embeds.dtype)

    def forward(
        self,
        input_ids,
        cache=None,
    ):
        _, seq_len = input_ids.shape
        if cache.position is None:
            cache.position = torch.arange(seq_len, device=input_ids.device)
        else:
            cache.position = cache.position[-1:] + 1
        position_ids = cache.position.unsqueeze(0)
        mask_global, mask_local = self._create_masks(
            device=input_ids.device,
            pos_start=cache.position[0],
            pos_end=cache.position[-1] + 1,
        )
        inputs_embeds = self.embed_tokens(input_ids)
        per_layer_inputs = self.get_per_layer_inputs(input_ids)
        per_layer_inputs = self.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )
        hidden_states_0 = inputs_embeds
        position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
        target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
        epsilon_tensor = torch.tensor(torch.finfo().min)
        temp_hidden_states = [hidden_states_0]
        for i in range(1, config.altup_num_inputs):
            altup_proj = self.altup_projections[i - 1](hidden_states_0)
            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(temp_hidden_states, dim=0)
        layer_cache = {}
        for i, decoder_layer in enumerate(self.layers):
            per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
            layer_outputs = decoder_layer(
                hidden_states,
                mask_global=mask_global,
                mask_local=mask_local,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                per_layer_input=per_layer_input,
                cache=cache,
                layer_cache=layer_cache,
            )
            hidden_states = layer_outputs
        if cache:
            cache.swap()
        target_magnitude = (
            torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        )
        temp_hidden_states = [hidden_states[0]]
        for i in range(1, config.altup_num_inputs):
            altup_unemb_proj = self.altup_unembed_projections[i - 1](hidden_states[i])
            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)
        hidden_states = torch.stack(temp_hidden_states)
        hidden_states = torch.mean(hidden_states, dim=0)
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _create_masks(self, device, pos_start, pos_end):
        q_indices = torch.arange(pos_start, pos_end, device=device)[:, None]
        k_indices = torch.arange(pos_end, device=device)[None, :]
        attend_global = k_indices <= q_indices
        is_in_window = k_indices > q_indices - config.sliding_window
        attend_local = attend_global & is_in_window
        return attend_global[None, None, :, :], attend_local[None, None, :, :]
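
# Causal LM head with final logit soft-capping: logits are squashed into
# (-30, 30) via cap * tanh(logits / cap), cap = final_logit_softcapping.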
class Gemma3nForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = Gemma3nTextModel()
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            dtype=torch.bfloat16,
            device="meta",
        )
        self.final_logit_softcapping = config.final_logit_softcapping

    def forward(
        self,
        input_ids,
        cache=None,
    ):
        hidden_states = self.language_model(input_ids=input_ids, cache=cache)
        logits = self.lm_head(hidden_states)
        logits = logits / self.final_logit_softcapping
        logits = torch.tanh(logits)
        logits = logits * self.final_logit_softcapping
        return logits
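
# Double-buffered KV cache: update() writes into next_cache while get() still
# returns the previous step's entries; swap() runs once per forward pass.
# This ordering lets the KV-shared layers read their source layer's *past* K/V
# from here while taking its *current-step* K/V from layer_cache. position
# tracks absolute token positions across decoding steps.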
class KVCache:
    def __init__(self):
        from collections import defaultdict

        self.cache = defaultdict(lambda: (None, None))
        self.next_cache = {}
        self.position = None

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.next_cache[layer_idx] = value

    def swap(self):
        self.cache, self.next_cache = self.next_cache, {}
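
# Only the "model.language_model.*" tensors are loaded from the safetensors
# shards (vision and audio towers are skipped); lm_head is tied to the input
# embedding table, so no separate output-projection weight is required.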
def load_model(model: nn.Module, path: Path):
    state_dict = {}
    for file in path.glob("*.safetensors"):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                if not weight_name.startswith("model.language_model."):
                    continue
                loaded_tensor = f.get_tensor(weight_name)
                weight_name = weight_name.replace(
                    "model.language_model.", "language_model."
                )
                param = model.get_parameter(weight_name)
                assert param.dtype == loaded_tensor.dtype
                state_dict[weight_name] = loaded_tensor
    state_dict["lm_head.weight"] = state_dict["language_model.embed_tokens.weight"]
    model.load_state_dict(state_dict, assign=True)
    model.lm_head.weight.data = model.language_model.embed_tokens.weight.data
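
# Thin wrapper around the Hugging Face `tokenizers` tokenizer.json, plus a
# minimal Gemma chat format: one user turn followed by the start of a model
# turn. Gemma has no dedicated system role, so the "helpful assistant" line is
# simply prepended inside the user turn; the <bos> token is presumably added
# by the tokenizer's own post-processor.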
class GemmaTokenizer:
    def __init__(self, tok_file):
        self._tok = Tokenizer.from_file(str(tok_file))

    def encode(self, text: str) -> list[int]:
        return self._tok.encode(text).ids

    def decode(self, ids: list[int]) -> str:
        return self._tok.decode(ids, skip_special_tokens=False)


def apply_chat_template(user_text):
    return f"<start_of_turn>user\nYou are a helpful assistant.\n\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
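
# Demo: download the checkpoint, load the text weights, then greedily decode.
# The full prompt is fed once to prefill the KV cache; afterwards one token at
# a time is fed back in (positions are derived from cache.position), stopping
# at <end_of_turn> or after max_new_tokens tokens.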
model = Gemma3nForCausalLM()
# hf download google/gemma-3n-E4B-it --local-dir gemma-3n-E4B-it
load_model(model, Path("gemma-3n-E4B-it"))
tokenizer = GemmaTokenizer(Path("gemma-3n-E4B-it") / "tokenizer.json")

prompt = "Answer to the Ultimate Question of Life, the Universe, and Everything is"
prompt = apply_chat_template(prompt)
input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
print(tokenizer.decode(input_ids[0].tolist()), end="", flush=True)

max_new_tokens = 30
cache = KVCache()
model.eval()
eos_token_id = tokenizer.encode("<end_of_turn>")[-1]
with torch.no_grad():
    next_id = input_ids
    for _ in range(max_new_tokens):
        outputs = model(next_id, cache)
        next_id = torch.argmax(outputs[:, -1, :], dim=-1, keepdim=True)
        print(tokenizer.decode(next_id[0].tolist()), end="", flush=True)
        if next_id[0].tolist()[0] == eos_token_id:
            break
print("\n")