@dhbrojas
Last active July 27, 2025 21:17
import math
from typing import Callable, List, Literal, Protocol, Tuple

import torch
from torch import Tensor, nn
from torch.nn import Linear, Module
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.nn.functional import scaled_dot_product_attention, silu
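# NOTE: the aliases and protocols below are assumptions added so this file is
# self-contained; the gist references MaskFn, Norm, Attention, FeedForward and
# AttentionMask without defining them here, so they presumably live elsewhere
# in the author's codebase.

# Signature of a flex-attention mask_mod callable: (b, h, q_idx, kv_idx) -> bool tensor.
MaskFn = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


class Norm(Protocol):
    """Any normalization layer mapping (..., HiddenDim) -> (..., HiddenDim)."""

    def __call__(self, x: Tensor) -> Tensor: ...


class FeedForward(Protocol):
    """Any position-wise feed-forward layer mapping (B, T, C) -> (B, T, C)."""

    def __call__(self, x: Tensor) -> Tensor: ...


class Attention(Protocol):
    """Any attention layer with the same call signature as MultiHeadAttention.forward."""

    def __call__(
        self,
        x: Tensor,
        *,
        cache: "KVCache | None" = None,
        causal: bool = False,
        attention_mask: "AttentionMask | None" = None,
        flex_attention_mask: "FlexAttentionMask | None" = None,
    ) -> "Tuple[Tensor, KVCache | None]": ...


class AttentionMask:
    """Assumed thin wrapper exposing a dense attention mask tensor via `.value`."""

    def __init__(self, value: Tensor):
        self.value = value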
def compute_frequencies(
*,
dim: int,
theta: int,
device: torch.device | None = None,
dtype: torch.dtype | None = torch.float32,
) -> Tensor:
"""
A mapping `() -> (dim // 2)`
"""
return 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))
def expand_frequencies(
frequencies: Tensor, *, max_sequence_length: int, dtype: torch.dtype = torch.float32
) -> Tensor:
"""
A mapping `(dim // 2) -> (len, dim // 2)` that expand the frequencies up to the
maximum sequence length of the model.
"""
return torch.outer(torch.arange(max_sequence_length, device=frequencies.device), frequencies).to(dtype)
def rescale_frequencies(
frequencies: Tensor,
*,
original_max_sequence_length: int,
factor: float,
low_frequency_factor: float,
high_frequency_factor: float,
) -> Tensor:
"""
A mapping `(len, dim // 2) -> (len, dim // 2)` that rescales the frequencies for models
trained on shorter sequences.
"""
low_frequency_wavelength = original_max_sequence_length / low_frequency_factor
high_frequency_wavelength = original_max_sequence_length / high_frequency_factor
wavelength = 2 * math.pi / frequencies
frequencies = torch.where(
wavelength > low_frequency_wavelength, frequencies / factor, frequencies
)
smooth_factor = (
original_max_sequence_length / wavelength - low_frequency_factor
) / (high_frequency_factor - low_frequency_factor)
smoothed_frequencies = (
1 - smooth_factor
) * frequencies / factor + smooth_factor * frequencies
is_medium_frequency = ~(wavelength < high_frequency_wavelength) * ~(
wavelength > low_frequency_wavelength
)
frequencies = torch.where(is_medium_frequency, smoothed_frequencies, frequencies)
return frequencies
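# Illustrative sketch (not part of the original gist): the parameter values below
# match the published Llama 3.1 RoPE-scaling configuration and are shown only as an
# example of how `rescale_frequencies` composes with `compute_frequencies`.
def _example_rescaled_frequencies() -> Tensor:
    base = compute_frequencies(dim=128, theta=500_000)  # (64,)
    return rescale_frequencies(
        base,
        original_max_sequence_length=8192,
        factor=8.0,
        low_frequency_factor=1.0,
        high_frequency_factor=4.0,
    )  # still (64,), ready to be expanded with `expand_frequencies`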
def rotate(x: Tensor, frequencies: Tensor) -> Tensor:
"""
A mapping `(*, length), (length, dim // 2) -> (*, length, dim)` that
rotates features of based on their position in the sequence.
"""
assert x.dim() >= 2, f"{x.dim()} >= 2"
x_length, x_dim = x.shape[-2:]
assert frequencies.dim() == 2, f"{frequencies.dim()} == 2"
f_length, f_half_dim = frequencies.shape
assert f_length == x_length, f"{f_length} == {x_length}"
assert x_dim // 2 == f_half_dim, f"{x_dim} // 2 == {f_half_dim}"
def rotate_half(x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
cos, sin = (
torch.cos(frequencies).unsqueeze(0).unsqueeze(1),
torch.sin(frequencies).unsqueeze(0).unsqueeze(1),
)
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
x = (x * cos) + (rotate_half(x) * sin)
return x
class RotaryPositionalEncoding(Module):
def __init__(self, *, frequencies: Tensor):
super().__init__()
self.frequencies: Tensor
self.register_buffer(
"frequencies",
frequencies,
persistent=False,
)
def forward(
self, q: Tensor, k: Tensor, *, offset: int, length: int
) -> Tuple[Tensor, Tensor]:
with torch.autocast(device_type=q.device.type, enabled=False):
return (
rotate(q, self.frequencies[offset : offset + length]),
rotate(k, self.frequencies[offset : offset + length]),
)
@staticmethod
def new(
*,
head_dim: int,
theta: int,
max_sequence_length: int,
device: torch.device | None = None,
) -> "RotaryPositionalEncoding":
return RotaryPositionalEncoding(
frequencies=expand_frequencies(
compute_frequencies(dim=head_dim, theta=theta, device=device),
max_sequence_length=max_sequence_length,
),
)
@staticmethod
def new_rescaled(
*,
head_dim: int,
theta: int,
max_sequence_length: int,
original_max_sequence_length: int,
factor: float,
low_frequency_factor: float,
high_frequency_factor: float,
device: torch.device | None = None,
) -> "RotaryPositionalEncoding":
return RotaryPositionalEncoding(
frequencies=expand_frequencies(
rescale_frequencies(
compute_frequencies(dim=head_dim, theta=theta, device=device),
original_max_sequence_length=original_max_sequence_length,
factor=factor,
low_frequency_factor=low_frequency_factor,
high_frequency_factor=high_frequency_factor,
),
max_sequence_length=max_sequence_length,
),
)
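# Illustrative sketch (not part of the original gist): build a RoPE module and apply
# it to random query/key tensors laid out as (B, NumHeads, T, HeadDim), the layout
# MultiHeadAttention below uses when it calls the position encoder.
def _example_rotary_positional_encoding() -> Tuple[Tensor, Tensor]:
    rope = RotaryPositionalEncoding.new(head_dim=64, theta=10_000, max_sequence_length=256)
    q = torch.randn(2, 8, 16, 64)  # (B, NumQueryHeads, T, HeadDim)
    k = torch.randn(2, 2, 16, 64)  # (B, NumKeyValueHeads, T, HeadDim)
    # `offset` is the number of already-cached positions (0 during prefill).
    return rope(q, k, offset=0, length=16)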
class RMSNorm(Module):
def __init__(
self,
*,
hidden_dim: int,
epsilon: float = 1e-6,
device: torch.device | None = None,
):
super().__init__()
self.epsilon = epsilon
self.weight = nn.Parameter(torch.ones(hidden_dim, device=device))
def forward(self, x: Tensor):
dtype = x.dtype
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon)
return x.type(dtype) * self.weight
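# Illustrative sketch (not part of the original gist): RMSNorm normalizes each
# feature vector by its root-mean-square, then applies a learned per-channel scale.
def _example_rms_norm() -> Tensor:
    norm = RMSNorm(hidden_dim=512)
    return norm(torch.randn(2, 16, 512))  # (B, T, HiddenDim) -> (B, T, HiddenDim)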
class GatedFeedForward(Module):
"""
Gated Feed-Forward Network as used in LLaMA 3.
ff(x) = W₂(act(W₁x) ⊙ W₃x)
where ⊙ represents element-wise multiplication
"""
def __init__(
self,
hidden_dim: int,
feedforward_dim: int,
activation_fn: Callable[[Tensor], Tensor] = silu,
device: torch.device | None = None,
):
super().__init__()
self.hidden_dim = hidden_dim
self.feedforward_dim = feedforward_dim
self.activation_fn = activation_fn
# Projection matrices
self.w1 = Linear(hidden_dim, feedforward_dim, bias=False, device=device)
self.w2 = Linear(hidden_dim, feedforward_dim, bias=False, device=device)
self.wo = Linear(feedforward_dim, hidden_dim, bias=False, device=device)
def forward(self, x: Tensor) -> Tensor:
assert x.dim() == 3, f"{x.dim()} != 3"
# Split into two paths: non-linear and linear.
# (B, T, C) -> (B, T, F) where F is the feedforward dimension
nonlinear = self.activation_fn(self.w1(x))
linear = self.w2(x)
# Combine paths with element-wise multiplication (gating mechanism)
# (B, T, F) -> (B, T, F)
x = nonlinear * linear
# Project back to original dimension
# (B, T, F) -> (B, T, C)
x = self.wo(x)
return x
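# Illustrative sketch (not part of the original gist): a 512 -> 2048 -> 512 gated MLP.
def _example_gated_feedforward() -> Tensor:
    ff = GatedFeedForward(hidden_dim=512, feedforward_dim=2048)
    return ff(torch.randn(2, 16, 512))  # (B, T, C) -> (B, T, C)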
class KVCache:
"""
A tensor of shape (Batch, NumHeads, Time, HeadDim) of pre-computed key-value pairs.
"""
def __init__(self, keys: Tensor, values: Tensor):
self.keys = keys
self.values = values
def push(self, keys: Tensor, values: Tensor) -> Tuple[Tensor, Tensor, "KVCache"]:
k = torch.cat([self.keys, keys], dim=2)
v = torch.cat([self.values, values], dim=2)
return k, v, KVCache(k, v)
@staticmethod
def empty(
*,
batch: int,
num_key_value_heads: int,
key_value_head_dim: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> "KVCache":
return KVCache(
keys=torch.zeros(
batch,
num_key_value_heads,
0,
key_value_head_dim,
device=device,
dtype=dtype,
),
values=torch.zeros(
batch,
num_key_value_heads,
0,
key_value_head_dim,
device=device,
dtype=dtype,
),
)
@property
def shape(self) -> Tuple[int, int, int, int]:
return self.keys.shape # type: ignore
@property
def batch(self) -> int:
return self.keys.shape[0]
@property
def num_key_value_heads(self) -> int:
return self.keys.shape[1]
@property
def time(self) -> int:
return self.keys.shape[2]
@property
def key_value_head_dim(self) -> int:
return self.keys.shape[3]
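# Illustrative sketch (not part of the original gist): start from an empty cache and
# append the keys/values of one decoded token per step. `push` returns the full
# key/value tensors to attend over plus a new cache holding them.
def _example_kv_cache() -> KVCache:
    cache = KVCache.empty(batch=2, num_key_value_heads=2, key_value_head_dim=64)
    new_k = torch.randn(2, 2, 1, 64)  # (B, NumKeyValueHeads, T=1, HeadDim)
    new_v = torch.randn(2, 2, 1, 64)
    k, v, cache = cache.push(new_k, new_v)
    assert cache.time == 1 and k.shape == (2, 2, 1, 64)
    return cache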
class PerLayerKVCache:
def __init__(self, caches: List[KVCache]):
self.caches = caches
@staticmethod
def empty(
*,
num_layers: int,
batch: int,
num_key_value_heads: int,
key_value_head_dim: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> "PerLayerKVCache":
assert num_layers > 0, f"expect {num_layers} > 0"
return PerLayerKVCache(
caches=[
KVCache.empty(
batch=batch,
num_key_value_heads=num_key_value_heads,
key_value_head_dim=key_value_head_dim,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
]
)
@property
def key_value_head_dim(self) -> int:
return self.caches[0].key_value_head_dim
@property
def time(self) -> int:
return self.caches[0].time
@property
def batch(self) -> int:
return self.caches[0].batch
@property
def num_key_value_heads(self) -> int:
return self.caches[0].num_key_value_heads
@property
def num_layers(self) -> int:
return len(self.caches)
def len(self) -> int:
return len(self.caches)
@staticmethod
    def size_bytes(
        num_key_value_heads: int,
        key_value_head_dim: int,
        num_layers: int,
        sequence_length: int,
        dtype: torch.dtype = torch.bfloat16,
    ) -> int:
        """
        Size in bytes of a KV cache with `num_layers` layers, `sequence_length` tokens
        and the given key/value head layout; the factor of 2 accounts for storing both
        keys and values.
        """
return num_layers * (
sequence_length * num_key_value_heads * key_value_head_dim * dtype.itemsize * 2
)
def __getitem__(self, index: int) -> KVCache:
return self.caches[index]
def __setitem__(self, index: int, value: KVCache) -> None:
self.caches[index] = value
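# Illustrative sketch (not part of the original gist): the Llama-3-8B-like layout
# below (32 layers, 8 KV heads of dim 128) needs exactly 1 GiB of bf16 KV cache for
# an 8192-token sequence: 32 * 8192 * 8 * 128 * 2 bytes * 2 (keys and values).
def _example_kv_cache_size() -> int:
    return PerLayerKVCache.size_bytes(
        num_key_value_heads=8,
        key_value_head_dim=128,
        num_layers=32,
        sequence_length=8192,
        dtype=torch.bfloat16,
    )  # 1_073_741_824 bytes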
def causal_flex_attention_mask(document: Tensor) -> MaskFn:
def mask(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
return (q_idx >= kv_idx) & (document[b, q_idx] == document[b, kv_idx])
return mask
def full_flex_attention_mask(document: Tensor) -> MaskFn:
def mask(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
return document[b, q_idx] == document[b, kv_idx]
return mask
class FlexAttentionMask:
"""
A wrapper around a Torch BlockMask with convenient utilities.
"""
def __init__(
self,
B: int | None,
H: int | None,
Q: int,
KV: int,
fn: MaskFn,
*,
device: torch.device,
):
self.value = create_block_mask(
mask_mod=fn, B=B, H=H, Q_LEN=Q, KV_LEN=KV, device=str(device)
)
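# Illustrative sketch (not part of the original gist): build a block mask for two
# documents packed into one batch row, so tokens attend causally only within their
# own document. Flex attention requires a recent PyTorch (2.5+).
def _example_flex_attention_mask(device: torch.device) -> FlexAttentionMask:
    # `document[b, t]` is the id of the document that token t of batch row b belongs to:
    # 128 tokens per row, the first 64 in document 0 and the rest in document 1.
    document = (torch.arange(128, device=device) >= 64).long().unsqueeze(0)
    return FlexAttentionMask(
        B=1, H=None, Q=128, KV=128, fn=causal_flex_attention_mask(document), device=device
    )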
class PositionEncoder(Protocol):
def __call__(self, q: Tensor, k: Tensor, *, offset: int, length: int) -> Tuple[Tensor, Tensor]: ...
class MultiHeadAttention(nn.Module):
def __init__(
self,
*,
hidden_dim: int,
num_query_heads: int,
num_key_value_heads: int,
head_dim: int,
position_encoder: PositionEncoder | None = None,
device: torch.device | None = None,
query_norm: Norm | None = None,
key_norm: Norm | None = None,
):
super().__init__()
assert hidden_dim // num_query_heads == head_dim, f"{hidden_dim} // {num_query_heads} != {head_dim}"
assert num_query_heads % num_key_value_heads == 0, f"{num_query_heads} % {num_key_value_heads} != 0"
assert hidden_dim % num_query_heads == 0, f"{hidden_dim} % {num_query_heads} != 0"
self.hidden_dim = hidden_dim
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.position_encoder = position_encoder
# Q, K, V projection matrices
self.wq = nn.Linear(hidden_dim, num_query_heads * head_dim, bias=False, device=device)
self.wk = nn.Linear(hidden_dim, num_key_value_heads * head_dim, bias=False, device=device)
self.wv = nn.Linear(hidden_dim, num_key_value_heads * head_dim, bias=False, device=device)
# Output projection matrix
self.wo = nn.Linear(num_query_heads * head_dim, hidden_dim, bias=False, device=device)
# Q, K normalization
self.wq_norm = query_norm
self.wk_norm = key_norm
def forward(
self,
x: Tensor,
ex: Tensor | None = None,
*,
cache: KVCache | None = None,
causal: bool = False,
attention_mask: AttentionMask | None = None,
flex_attention_mask: FlexAttentionMask | None = None,
) -> Tuple[Tensor, KVCache | None]:
# (B, T, HiddenDim)
assert x.dim() == 3
B, T, _ = x.shape
assert not causal or (not attention_mask and not flex_attention_mask)
        # If a cache is set, we are most likely doing autoregressive decoding,
        # where T == 1 (predicting the next token) and CT + T is the actual
        # sequence length seen so far.
if cache is not None:
CT = cache.time
else:
CT = 0
# (B, T, HiddenDim) -> (B, T, HeadDim * NumQueryHeads)
q = self.wq(x)
        # (B, T, HiddenDim) -> (B, T, HeadDim * NumKeyValueHeads)
        # When `ex` is given (cross-attention), keys and values are projected from it
        # instead of `x`; the code below assumes `ex` has the same sequence length as `x`.
        if ex is not None:
            k = self.wk(ex)
            v = self.wv(ex)
        else:
            k = self.wk(x)
            v = self.wv(x)
        # (B, T, NumQueryHeads * HeadDim) -> (B, T, NumQueryHeads, HeadDim)
        q = q.view(B, T, self.num_query_heads, self.head_dim)
        # (B, T, NumKeyValueHeads * HeadDim) -> (B, T, NumKeyValueHeads, HeadDim)
        k = k.view(B, T, self.num_key_value_heads, self.head_dim)
        v = v.view(B, T, self.num_key_value_heads, self.head_dim)
# (B, T, NumQueryHeads, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
q = q.transpose(1, 2)
# (B, T, NumKeyValueHeads, HeadDim) -> (B, NumKeyValueHeads, T, HeadDim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if self.wq_norm:
q = self.wq_norm(q)
if self.wk_norm:
k = self.wk_norm(k)
# Encode relative/absolute positional information before the attention computation.
# (B, NumQueryHeads, T, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
# (B, NumKeyValueHeads, T, HeadDim) -> (B, NumKeyValueHeads, T, HeadDim)
if self.position_encoder is not None:
q, k = self.position_encoder(q, k, offset=CT, length=T)
if cache is not None:
k, v, cache = cache.push(k, v)
if flex_attention_mask is not None:
            # (B, NumQueryHeads, T, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
x = flex_attention(
q,
k,
v,
scale=1.0 / (self.head_dim**0.5),
enable_gqa=self.num_query_heads > self.num_key_value_heads,
block_mask=flex_attention_mask.value,
) # type: ignore
else:
            # (B, NumQueryHeads, T, HeadDim) -> (B, NumQueryHeads, T, HeadDim)
x = scaled_dot_product_attention(
q,
k,
v,
is_causal=causal,
scale=1.0 / (self.head_dim**0.5),
enable_gqa=self.num_query_heads > self.num_key_value_heads,
attn_mask=attention_mask.value if attention_mask is not None else None,
)
        # (B, NumQueryHeads, T, HeadDim) -> (B, T, NumQueryHeads * HeadDim) == (B, T, HiddenDim)
        x = x.transpose(1, 2).contiguous().view(B, T, self.hidden_dim)
# (B, T, HiddenDim) -> (B, T, HiddenDim)
x = self.wo(x)
return x, cache
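# Illustrative sketch (not part of the original gist): grouped-query self-attention
# with RoPE, 8 query heads sharing 2 key/value heads, over a 16-token sequence.
def _example_multi_head_attention() -> Tensor:
    rope = RotaryPositionalEncoding.new(head_dim=64, theta=10_000, max_sequence_length=256)
    attention = MultiHeadAttention(
        hidden_dim=512,
        num_query_heads=8,
        num_key_value_heads=2,
        head_dim=64,
        position_encoder=rope,
    )
    x = torch.randn(2, 16, 512)  # (B, T, HiddenDim)
    out, _ = attention(x, causal=True)
    return out  # (B, T, HiddenDim)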
class TransformerBlock(Module):
def __init__(
self,
*,
attention: Attention,
attention_norm: Norm,
feedforward: FeedForward,
feedforward_norm: Norm,
norm_placement: Literal["pre"] | Literal["post"] = "pre",
norm_scale: float = 1.0,
):
super().__init__()
self.attention = attention
self.attention_norm = attention_norm
self.feedforward = feedforward
self.feedforward_norm = feedforward_norm
self.norm_placement = norm_placement
self.norm_scale = norm_scale
def forward(
self,
x: Tensor,
*,
cache: KVCache | None = None,
causal: bool = False,
attention_mask: AttentionMask | None = None,
flex_attention_mask: FlexAttentionMask | None = None,
):
# Attention
residual = x
if self.norm_placement == "pre":
x = self.attention_norm(x) * self.norm_scale
x, cache = self.attention(
x,
cache=cache,
causal=causal,
attention_mask=attention_mask,
flex_attention_mask=flex_attention_mask,
)
if self.norm_placement == "post":
x = self.attention_norm(x) * self.norm_scale
x = x + residual
# Feedforward
residual = x
if self.norm_placement == "pre":
x = self.feedforward_norm(x) * self.norm_scale
x = self.feedforward(x)
if self.norm_placement == "post":
x = self.feedforward_norm(x) * self.norm_scale
x = x + residual
return x, cache
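# Illustrative sketch (not part of the original gist): wire the pieces above into a
# single pre-norm transformer block and run a causal forward pass.
def _example_transformer_block() -> Tensor:
    rope = RotaryPositionalEncoding.new(head_dim=64, theta=10_000, max_sequence_length=256)
    block = TransformerBlock(
        attention=MultiHeadAttention(
            hidden_dim=512,
            num_query_heads=8,
            num_key_value_heads=2,
            head_dim=64,
            position_encoder=rope,
        ),
        attention_norm=RMSNorm(hidden_dim=512),
        feedforward=GatedFeedForward(hidden_dim=512, feedforward_dim=2048),
        feedforward_norm=RMSNorm(hidden_dim=512),
        norm_placement="pre",
    )
    x = torch.randn(2, 16, 512)
    out, _ = block(x, causal=True)
    return out  # (B, T, HiddenDim)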