import math
from typing import Callable, List, Literal, Protocol, Tuple

import torch
from torch import Tensor, nn
from torch.nn import Linear, Module

# flex_attention and the enable_gqa flag need a recent PyTorch (2.5+).
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.nn.functional import scaled_dot_product_attention, silu

def compute_frequencies(
    *,
    dim: int,
    theta: int,
    device: torch.device | None = None,
    dtype: torch.dtype | None = torch.float32,
) -> Tensor:
    """
    A mapping `() -> (dim // 2)`.
    """
    return 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))

def expand_frequencies(
    frequencies: Tensor, *, max_sequence_length: int, dtype: torch.dtype = torch.float32
) -> Tensor:
    """
    A mapping `(dim // 2) -> (len, dim // 2)` that expands the frequencies up to the
    maximum sequence length of the model.
    """
    return torch.outer(torch.arange(max_sequence_length, device=frequencies.device), frequencies).to(dtype)

def rescale_frequencies(
    frequencies: Tensor,
    *,
    original_max_sequence_length: int,
    factor: float,
    low_frequency_factor: float,
    high_frequency_factor: float,
) -> Tensor:
    """
    A mapping `(dim // 2) -> (dim // 2)` that rescales the base frequencies for models
    trained on shorter sequences (Llama-3-style scaling: low frequencies are divided by
    `factor`, high frequencies are left unchanged, and the band in between is smoothly
    interpolated).
    """
    low_frequency_wavelength = original_max_sequence_length / low_frequency_factor
    high_frequency_wavelength = original_max_sequence_length / high_frequency_factor
    wavelength = 2 * math.pi / frequencies
    frequencies = torch.where(
        wavelength > low_frequency_wavelength, frequencies / factor, frequencies
    )
    smooth_factor = (
        original_max_sequence_length / wavelength - low_frequency_factor
    ) / (high_frequency_factor - low_frequency_factor)
    smoothed_frequencies = (
        1 - smooth_factor
    ) * frequencies / factor + smooth_factor * frequencies
    is_medium_frequency = (wavelength <= low_frequency_wavelength) & (
        wavelength >= high_frequency_wavelength
    )
    frequencies = torch.where(is_medium_frequency, smoothed_frequencies, frequencies)
    return frequencies

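# Illustrative sketch (not part of the original gist): hyperparameters in the spirit of
# Llama 3.1's published RoPE scaling; treat the exact values as assumptions.
def _example_rescaled_frequencies() -> Tensor:
    base = compute_frequencies(dim=128, theta=500_000)
    # Low frequencies (wavelength > 8192 / 1.0) are divided by 8, high frequencies
    # (wavelength < 8192 / 4.0) are kept, and the band in between is interpolated.
    return rescale_frequencies(
        base,
        original_max_sequence_length=8192,
        factor=8.0,
        low_frequency_factor=1.0,
        high_frequency_factor=4.0,
    )
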
def rotate(x: Tensor, frequencies: Tensor) -> Tensor:
    """
    A mapping `(*, length, dim), (length, dim // 2) -> (*, length, dim)` that
    rotates features based on their position in the sequence.
    """
    assert x.dim() >= 2, f"{x.dim()} >= 2"
    x_length, x_dim = x.shape[-2:]
    assert frequencies.dim() == 2, f"{frequencies.dim()} == 2"
    f_length, f_half_dim = frequencies.shape
    assert f_length == x_length, f"{f_length} == {x_length}"
    assert x_dim // 2 == f_half_dim, f"{x_dim} // 2 == {f_half_dim}"

    def rotate_half(x: Tensor) -> Tensor:
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    cos, sin = (
        torch.cos(frequencies).unsqueeze(0).unsqueeze(1),
        torch.sin(frequencies).unsqueeze(0).unsqueeze(1),
    )
    cos = torch.cat((cos, cos), dim=-1)
    sin = torch.cat((sin, sin), dim=-1)
    x = (x * cos) + (rotate_half(x) * sin)
    return x

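# Illustrative sketch (not in the original gist): RoPE applies a pure rotation per
# feature pair, so `rotate` preserves the norm of every position's feature vector.
def _example_rotation_preserves_norm() -> None:
    frequencies = expand_frequencies(
        compute_frequencies(dim=8, theta=10_000), max_sequence_length=4
    )
    x = torch.randn(1, 2, 4, 8)  # (Batch, NumHeads, Time, HeadDim)
    rotated = rotate(x, frequencies)
    torch.testing.assert_close(rotated.norm(dim=-1), x.norm(dim=-1))
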
class RotaryPositionalEncoding(Module):
    def __init__(self, *, frequencies: Tensor):
        super().__init__()
        self.frequencies: Tensor
        self.register_buffer(
            "frequencies",
            frequencies,
            persistent=False,
        )

    def forward(
        self, q: Tensor, k: Tensor, *, offset: int, length: int
    ) -> Tuple[Tensor, Tensor]:
        with torch.autocast(device_type=q.device.type, enabled=False):
            return (
                rotate(q, self.frequencies[offset : offset + length]),
                rotate(k, self.frequencies[offset : offset + length]),
            )

    @staticmethod
    def new(
        *,
        head_dim: int,
        theta: int,
        max_sequence_length: int,
        device: torch.device | None = None,
    ) -> "RotaryPositionalEncoding":
        return RotaryPositionalEncoding(
            frequencies=expand_frequencies(
                compute_frequencies(dim=head_dim, theta=theta, device=device),
                max_sequence_length=max_sequence_length,
            ),
        )

    @staticmethod
    def new_rescaled(
        *,
        head_dim: int,
        theta: int,
        max_sequence_length: int,
        original_max_sequence_length: int,
        factor: float,
        low_frequency_factor: float,
        high_frequency_factor: float,
        device: torch.device | None = None,
    ) -> "RotaryPositionalEncoding":
        return RotaryPositionalEncoding(
            frequencies=expand_frequencies(
                rescale_frequencies(
                    compute_frequencies(dim=head_dim, theta=theta, device=device),
                    original_max_sequence_length=original_max_sequence_length,
                    factor=factor,
                    low_frequency_factor=low_frequency_factor,
                    high_frequency_factor=high_frequency_factor,
                ),
                max_sequence_length=max_sequence_length,
            ),
        )

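# Illustrative usage sketch (not in the original gist): build the RoPE tables once, then
# rotate query/key heads at a given offset into the sequence.
def _example_rotary_positional_encoding() -> None:
    rope = RotaryPositionalEncoding.new(head_dim=16, theta=10_000, max_sequence_length=64)
    q = torch.randn(1, 2, 8, 16)  # (Batch, NumQueryHeads, Time, HeadDim)
    k = torch.randn(1, 2, 8, 16)
    q_rotated, k_rotated = rope(q, k, offset=0, length=8)
    assert q_rotated.shape == q.shape and k_rotated.shape == k.shape
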
class RMSNorm(Module):
    def __init__(
        self,
        *,
        hidden_dim: int,
        epsilon: float = 1e-6,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.epsilon = epsilon
        self.weight = nn.Parameter(torch.ones(hidden_dim, device=device))

    def forward(self, x: Tensor):
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)
        return x.type(dtype) * self.weight

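# Illustrative sketch (not in the original gist): with the default unit weight, RMSNorm
# rescales every feature vector to (approximately) unit root-mean-square.
def _example_rmsnorm_unit_rms() -> None:
    norm = RMSNorm(hidden_dim=32)
    x = torch.randn(2, 5, 32)
    rms = norm(x).pow(2).mean(-1).sqrt()
    torch.testing.assert_close(rms, torch.ones_like(rms), rtol=1e-3, atol=1e-3)
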
class GatedFeedForward(Module):
    """
    Gated feed-forward network as used in LLaMA 3:

        ff(x) = wo(act(w1·x) ⊙ w2·x)

    where ⊙ is element-wise multiplication.
    """

    def __init__(
        self,
        hidden_dim: int,
        feedforward_dim: int,
        activation_fn: Callable[[Tensor], Tensor] = silu,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.feedforward_dim = feedforward_dim
        self.activation_fn = activation_fn
        # Projection matrices
        self.w1 = Linear(hidden_dim, feedforward_dim, bias=False, device=device)
        self.w2 = Linear(hidden_dim, feedforward_dim, bias=False, device=device)
        self.wo = Linear(feedforward_dim, hidden_dim, bias=False, device=device)

    def forward(self, x: Tensor) -> Tensor:
        assert x.dim() == 3, f"{x.dim()} != 3"
        # Split into two paths: non-linear and linear.
        # (B, T, C) -> (B, T, F) where F is the feedforward dimension
        nonlinear = self.activation_fn(self.w1(x))
        linear = self.w2(x)
        # Combine paths with element-wise multiplication (gating mechanism)
        # (B, T, F) -> (B, T, F)
        x = nonlinear * linear
        # Project back to original dimension
        # (B, T, F) -> (B, T, C)
        x = self.wo(x)
        return x

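# Illustrative sketch (not in the original gist): the gate expands to the feed-forward
# width and `wo` projects back, so the module preserves the hidden dimension.
def _example_gated_feedforward() -> None:
    ff = GatedFeedForward(hidden_dim=64, feedforward_dim=256)
    x = torch.randn(2, 10, 64)
    assert ff(x).shape == (2, 10, 64)
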
class KVCache:
    """
    A pair of tensors of shape (Batch, NumKeyValueHeads, Time, HeadDim) holding
    previously computed keys and values.
    """

    def __init__(self, keys: Tensor, values: Tensor):
        self.keys = keys
        self.values = values

    def push(self, keys: Tensor, values: Tensor) -> Tuple[Tensor, Tensor, "KVCache"]:
        k = torch.cat([self.keys, keys], dim=2)
        v = torch.cat([self.values, values], dim=2)
        return k, v, KVCache(k, v)

    @staticmethod
    def empty(
        *,
        batch: int,
        num_key_value_heads: int,
        key_value_head_dim: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> "KVCache":
        return KVCache(
            keys=torch.zeros(
                batch,
                num_key_value_heads,
                0,
                key_value_head_dim,
                device=device,
                dtype=dtype,
            ),
            values=torch.zeros(
                batch,
                num_key_value_heads,
                0,
                key_value_head_dim,
                device=device,
                dtype=dtype,
            ),
        )

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return self.keys.shape  # type: ignore

    @property
    def batch(self) -> int:
        return self.keys.shape[0]

    @property
    def num_key_value_heads(self) -> int:
        return self.keys.shape[1]

    @property
    def time(self) -> int:
        return self.keys.shape[2]

    @property
    def key_value_head_dim(self) -> int:
        return self.keys.shape[3]

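# Illustrative sketch (not in the original gist): during autoregressive decoding the
# cache starts empty and grows by one time step per `push`.
def _example_kv_cache_decoding() -> None:
    cache = KVCache.empty(batch=1, num_key_value_heads=2, key_value_head_dim=16)
    for _ in range(3):
        new_keys = torch.randn(1, 2, 1, 16)
        new_values = torch.randn(1, 2, 1, 16)
        keys, values, cache = cache.push(new_keys, new_values)
    assert cache.time == 3 and keys.shape == (1, 2, 3, 16)
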
class PerLayerKVCache:
    def __init__(self, caches: List[KVCache]):
        self.caches = caches

    @staticmethod
    def empty(
        *,
        num_layers: int,
        batch: int,
        num_key_value_heads: int,
        key_value_head_dim: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> "PerLayerKVCache":
        assert num_layers > 0, f"expect {num_layers} > 0"
        return PerLayerKVCache(
            caches=[
                KVCache.empty(
                    batch=batch,
                    num_key_value_heads=num_key_value_heads,
                    key_value_head_dim=key_value_head_dim,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        )

    @property
    def key_value_head_dim(self) -> int:
        return self.caches[0].key_value_head_dim

    @property
    def time(self) -> int:
        return self.caches[0].time

    @property
    def batch(self) -> int:
        return self.caches[0].batch

    @property
    def num_key_value_heads(self) -> int:
        return self.caches[0].num_key_value_heads

    @property
    def num_layers(self) -> int:
        return len(self.caches)

    def __len__(self) -> int:
        return len(self.caches)

    @staticmethod
    def size_bytes(
        num_key_value_heads: int,
        key_value_head_dim: int,
        num_layers: int,
        sequence_length: int,
        dtype: torch.dtype = torch.bfloat16,
    ) -> int:
        """Size in bytes of a `num_layers`-layer KV cache holding `sequence_length` tokens."""
        return num_layers * (
            sequence_length * num_key_value_heads * key_value_head_dim * dtype.itemsize * 2
        )

    def __getitem__(self, index: int) -> KVCache:
        return self.caches[index]

    def __setitem__(self, index: int, value: KVCache) -> None:
        self.caches[index] = value

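# Illustrative sketch (not in the original gist): a hypothetical Llama-3-8B-like layout
# (32 layers, 8 KV heads of dim 128) holding 8192 bf16 tokens needs 1 GiB per sequence.
def _example_kv_cache_size() -> int:
    return PerLayerKVCache.size_bytes(
        num_key_value_heads=8,
        key_value_head_dim=128,
        num_layers=32,
        sequence_length=8192,
    )  # 1_073_741_824 bytes
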
# `MaskFn` is not defined in this gist; the alias below is an assumed stand-in matching
# torch flex attention's mask_mod signature: (batch, head, q_idx, kv_idx) -> bool tensor.
MaskFn = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


def causal_flex_attention_mask(document: Tensor) -> MaskFn:
    def mask(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
        return (q_idx >= kv_idx) & (document[b, q_idx] == document[b, kv_idx])

    return mask


def full_flex_attention_mask(document: Tensor) -> MaskFn:
    def mask(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
        return document[b, q_idx] == document[b, kv_idx]

    return mask

class FlexAttentionMask:
    """
    A wrapper around a torch BlockMask with convenient utilities.
    """

    def __init__(
        self,
        B: int | None,
        H: int | None,
        Q: int,
        KV: int,
        fn: MaskFn,
        *,
        device: torch.device,
    ):
        self.value = create_block_mask(
            mask_mod=fn, B=B, H=H, Q_LEN=Q, KV_LEN=KV, device=str(device)
        )

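# Illustrative sketch (not in the original gist): a block mask for one row containing two
# packed documents, so tokens attend causally and never across document boundaries.
# Assumes a PyTorch build where create_block_mask accepts the target device ("cpu" here).
def _example_packed_document_mask() -> FlexAttentionMask:
    document = torch.zeros(1, 256, dtype=torch.long)
    document[:, 128:] = 1  # first 128 tokens are document 0, the rest document 1
    return FlexAttentionMask(
        B=1,
        H=None,
        Q=256,
        KV=256,
        fn=causal_flex_attention_mask(document),
        device=torch.device("cpu"),
    )
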
class PositionEncoder(Protocol):
    def __call__(self, q: Tensor, k: Tensor, *, offset: int, length: int) -> Tuple[Tensor, Tensor]: ...

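# `Norm` and `AttentionMask` are referenced below but are not defined in this gist.
# The Protocol stubs that follow are assumed minimal interfaces (a normalization module
# such as RMSNorm, and any wrapper exposing a dense mask tensor via `.value`), not the
# original definitions.
class Norm(Protocol):
    def __call__(self, x: Tensor) -> Tensor: ...


class AttentionMask(Protocol):
    value: Tensor
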
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        *,
        hidden_dim: int,
        num_query_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        position_encoder: PositionEncoder | None = None,
        device: torch.device | None = None,
        query_norm: Norm | None = None,
        key_norm: Norm | None = None,
    ):
        super().__init__()
        assert hidden_dim // num_query_heads == head_dim, f"{hidden_dim} // {num_query_heads} != {head_dim}"
        assert num_query_heads % num_key_value_heads == 0, f"{num_query_heads} % {num_key_value_heads} != 0"
        assert hidden_dim % num_query_heads == 0, f"{hidden_dim} % {num_query_heads} != 0"
        self.hidden_dim = hidden_dim
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.position_encoder = position_encoder
        # Q, K, V projection matrices
        self.wq = nn.Linear(hidden_dim, num_query_heads * head_dim, bias=False, device=device)
        self.wk = nn.Linear(hidden_dim, num_key_value_heads * head_dim, bias=False, device=device)
        self.wv = nn.Linear(hidden_dim, num_key_value_heads * head_dim, bias=False, device=device)
        # Output projection matrix
        self.wo = nn.Linear(num_query_heads * head_dim, hidden_dim, bias=False, device=device)
        # Q, K normalization
        self.wq_norm = query_norm
        self.wk_norm = key_norm

    def forward(
        self,
        x: Tensor,
        ex: Tensor | None = None,
        *,
        cache: KVCache | None = None,
        causal: bool = False,
        attention_mask: AttentionMask | None = None,
        flex_attention_mask: FlexAttentionMask | None = None,
    ) -> Tuple[Tensor, KVCache | None]:
        # (B, T, HiddenDim)
        assert x.dim() == 3
        B, T, _ = x.shape
        assert not causal or (not attention_mask and not flex_attention_mask)
        # If cache is set, most likely we are doing autoregressive decoding
        # where T == 1 (predicting the next token) and CT + T is the actual
        # sequence length.
        if cache is not None:
            CT = cache.time
        else:
            CT = 0
        # (B, T, HiddenDim) -> (B, T, NumQueryHeads * HeadDim)
        q = self.wq(x)
        # Keys/values come from `ex` (cross-attention) when it is given, else from `x`.
        # (B, S, HiddenDim) -> (B, S, NumKeyValueHeads * HeadDim)
        if ex is not None:
            S = ex.shape[1]
            k = self.wk(ex)
            v = self.wv(ex)
        else:
            S = T
            k = self.wk(x)
            v = self.wv(x)
        # (B, T, NumQueryHeads * HeadDim) -> (B, T, NumQueryHeads, HeadDim)
        q = q.view(B, T, self.num_query_heads, self.head_dim)
        # (B, S, NumKeyValueHeads * HeadDim) -> (B, S, NumKeyValueHeads, HeadDim)
        k = k.view(B, S, self.num_key_value_heads, self.head_dim)
        v = v.view(B, S, self.num_key_value_heads, self.head_dim)
        # (B, T, NumQueryHeads, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
        q = q.transpose(1, 2)
        # (B, S, NumKeyValueHeads, HeadDim) -> (B, NumKeyValueHeads, S, HeadDim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if self.wq_norm:
            q = self.wq_norm(q)
        if self.wk_norm:
            k = self.wk_norm(k)
        # Encode relative/absolute positional information before the attention computation.
        # (B, NumQueryHeads, T, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
        # (B, NumKeyValueHeads, S, HeadDim) -> (B, NumKeyValueHeads, S, HeadDim)
        if self.position_encoder is not None:
            q, k = self.position_encoder(q, k, offset=CT, length=T)
        if cache is not None:
            k, v, cache = cache.push(k, v)
        if flex_attention_mask is not None:
            # (B, NumQueryHeads, T, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
            x = flex_attention(
                q,
                k,
                v,
                scale=1.0 / (self.head_dim**0.5),
                enable_gqa=self.num_query_heads > self.num_key_value_heads,
                block_mask=flex_attention_mask.value,
            )  # type: ignore
        else:
            # (B, NumQueryHeads, T, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
            x = scaled_dot_product_attention(
                q,
                k,
                v,
                is_causal=causal,
                scale=1.0 / (self.head_dim**0.5),
                enable_gqa=self.num_query_heads > self.num_key_value_heads,
                attn_mask=attention_mask.value if attention_mask is not None else None,
            )
        # (B, NumQueryHeads, T, HeadDim) -> (B, T, NumQueryHeads * HeadDim) == (B, T, HiddenDim)
        x = x.transpose(1, 2).contiguous().view(B, T, self.hidden_dim)
        # (B, T, HiddenDim) -> (B, T, HiddenDim)
        x = self.wo(x)
        return x, cache

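# `Attention` and `FeedForward` are referenced by TransformerBlock but not defined in
# this gist; the Protocol stubs below are assumed interfaces matching MultiHeadAttention
# and GatedFeedForward, not the original definitions.
class Attention(Protocol):
    def __call__(
        self,
        x: Tensor,
        *,
        cache: KVCache | None = None,
        causal: bool = False,
        attention_mask: AttentionMask | None = None,
        flex_attention_mask: FlexAttentionMask | None = None,
    ) -> Tuple[Tensor, KVCache | None]: ...


class FeedForward(Protocol):
    def __call__(self, x: Tensor) -> Tensor: ...
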
class TransformerBlock(Module):
    def __init__(
        self,
        *,
        attention: Attention,
        attention_norm: Norm,
        feedforward: FeedForward,
        feedforward_norm: Norm,
        norm_placement: Literal["pre"] | Literal["post"] = "pre",
        norm_scale: float = 1.0,
    ):
        super().__init__()
        self.attention = attention
        self.attention_norm = attention_norm
        self.feedforward = feedforward
        self.feedforward_norm = feedforward_norm
        self.norm_placement = norm_placement
        self.norm_scale = norm_scale

    def forward(
        self,
        x: Tensor,
        *,
        cache: KVCache | None = None,
        causal: bool = False,
        attention_mask: AttentionMask | None = None,
        flex_attention_mask: FlexAttentionMask | None = None,
    ):
        # Attention
        residual = x
        if self.norm_placement == "pre":
            x = self.attention_norm(x) * self.norm_scale
        x, cache = self.attention(
            x,
            cache=cache,
            causal=causal,
            attention_mask=attention_mask,
            flex_attention_mask=flex_attention_mask,
        )
        if self.norm_placement == "post":
            x = self.attention_norm(x) * self.norm_scale
        x = x + residual
        # Feedforward
        residual = x
        if self.norm_placement == "pre":
            x = self.feedforward_norm(x) * self.norm_scale
        x = self.feedforward(x)
        if self.norm_placement == "post":
            x = self.feedforward_norm(x) * self.norm_scale
        x = x + residual
        return x, cache

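# Illustrative end-to-end sketch (not in the original gist): assemble a pre-norm block
# from the pieces above, run a prefill pass, then decode one extra token against the
# cache. All dimensions are small, arbitrary choices.
def _example_transformer_block() -> None:
    hidden_dim, num_query_heads, num_key_value_heads, head_dim = 64, 4, 2, 16
    rope = RotaryPositionalEncoding.new(head_dim=head_dim, theta=10_000, max_sequence_length=128)
    block = TransformerBlock(
        attention=MultiHeadAttention(
            hidden_dim=hidden_dim,
            num_query_heads=num_query_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            position_encoder=rope,
        ),
        attention_norm=RMSNorm(hidden_dim=hidden_dim),
        feedforward=GatedFeedForward(hidden_dim, 4 * hidden_dim),
        feedforward_norm=RMSNorm(hidden_dim=hidden_dim),
    )
    cache = KVCache.empty(
        batch=2, num_key_value_heads=num_key_value_heads, key_value_head_dim=head_dim
    )
    # Prefill: process the prompt causally and fill the cache.
    prompt = torch.randn(2, 10, hidden_dim)
    y, cache = block(prompt, cache=cache, causal=True)
    # Decode: a single new token attends to the 10 cached positions plus itself.
    next_token = torch.randn(2, 1, hidden_dim)
    y, cache = block(next_token, cache=cache)
    assert y.shape == (2, 1, hidden_dim) and cache is not None and cache.time == 11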