Skip to content

Instantly share code, notes, and snippets.

@macleginn
Created March 6, 2026 12:27
Show Gist options
  • Select an option

  • Save macleginn/5b11098473132001482cdf42510ed153 to your computer and use it in GitHub Desktop.

Select an option

Save macleginn/5b11098473132001482cdf42510ed153 to your computer and use it in GitHub Desktop.
"""
Conceptual pseudocode explaining how sliding-window attention works
using loops. This is NOT meant to be efficient or usable in a real
model. It simply illustrates what PyTorch operations like:

    - unfold
    - unsqueeze
    - squeeze
    - batched matrix multiplication

are doing conceptually.

We assume:

    Q : (seq_len, q_dim)
    K : (seq_len, q_dim)
    V : (seq_len, v_dim)

and a sliding window of radius w, i.e. each token sees w neighbours on
each side plus itself (2*w + 1 positions in total).

For simplicity this example shows the bidirectional case
(attend to i-w ... i+w). The causal case is identical except
the window becomes [i-w ... i].
"""
import math
# ------------------------------------------------------------
# Configuration
# ------------------------------------------------------------
seq_len = 10  # number of tokens in the sequence
q_dim = 4     # size of each query/key vector
v_dim = 6     # size of each value vector
w = 2         # window radius: neighbours visible on each side of a token
local_len = 2 * w + 1  # window width: w left + the token itself + w right

# Score written into masked (padding) slots so they get zero attention.
NEG_INF = float("-inf")

# Dummy tensors (lists used here for conceptual clarity)
Q = [[0] * q_dim for _ in range(seq_len)]
K = [[0] * q_dim for _ in range(seq_len)]
V = [[0] * v_dim for _ in range(seq_len)]

# ------------------------------------------------------------
# Step 1: Pad K and V so edge tokens can still form full windows
# ------------------------------------------------------------
"""
Padding conceptually extends the sequence:

    original K:
        k0 k1 k2 k3 k4 ...

    padded K:
        pad pad k0 k1 k2 k3 k4 pad pad
"""
pad_key = [0] * q_dim
pad_val = [0] * v_dim
# Every pad slot refers to the same pad row object; that is fine here
# because nothing below mutates the rows of K_padded / V_padded.
K_padded = [pad_key] * w + K + [pad_key] * w
V_padded = [pad_val] * w + V + [pad_val] * w

# ------------------------------------------------------------
# Step 2: "unfold" – extract sliding windows of keys and values
# ------------------------------------------------------------
"""
PyTorch code:

    K_padded.unfold(-2, local_len, 1)

Conceptually builds a tensor:

    local_K[i] = keys in window around token i

Example (w=2):

    i=0 -> [pad, pad, k0, k1, k2]
    i=1 -> [pad, k0, k1, k2, k3]
    i=2 -> [k0, k1, k2, k3, k4]
"""
local_K = []
local_V = []
for i in range(seq_len):
    window_keys = []
    window_vals = []
    for r in range(local_len):
        # Token i sits at padded index i + w, so its window starts at
        # padded index i and slot r maps to padded index i + r.
        j = i + r  # index in padded sequence
        window_keys.append(K_padded[j])
        window_vals.append(V_padded[j])
    local_K.append(window_keys)
    local_V.append(window_vals)

# Shapes conceptually:
# local_K : (seq_len, local_len, q_dim)
# local_V : (seq_len, local_len, v_dim)

# ------------------------------------------------------------
# Step 3: Compute dot-product attention scores
# ------------------------------------------------------------
"""
PyTorch equivalent (before scaling):

    scores = (Q.unsqueeze(-2) @ local_K.transpose(-2,-1)).squeeze(-2)

Conceptually:

    scores[i][r] = dot(Q[i], local_K[i][r]) / sqrt(q_dim)
"""
scale = math.sqrt(q_dim)  # loop-invariant, so computed once
scores = [[0] * local_len for _ in range(seq_len)]
for i in range(seq_len):
    for r in range(local_len):
        score = 0
        for d in range(q_dim):
            score += Q[i][d] * local_K[i][r][d]
        scores[i][r] = score / scale

# ------------------------------------------------------------
# Step 4: Mask positions outside the real sequence
# ------------------------------------------------------------
"""
Some window entries correspond to padded tokens.

Example:

    query i=0, w=2
    window = [pad, pad, k0, k1, k2]

We must mask the pads so they cannot receive attention.
"""
for i in range(seq_len):
    for r in range(local_len):
        relative_offset = r - w  # -w ... +w around the query token
        key_position = i + relative_offset  # index in the unpadded sequence
        if key_position < 0 or key_position >= seq_len:
            scores[i][r] = NEG_INF

# ------------------------------------------------------------
# Step 5: Softmax over local window
# ------------------------------------------------------------
"""
Normalize attention weights within each window.

The row maximum is subtracted before exponentiating. This leaves the
softmax result unchanged mathematically but keeps math.exp() from
overflowing when scores are large.
"""
attention = [[0] * local_len for _ in range(seq_len)]
for i in range(seq_len):
    # The centre slot (r == w) is the token itself and is never masked,
    # so at least one finite score always exists in the row.
    row_max = max(s for s in scores[i] if s != NEG_INF)

    # compute normalization constant
    denom = 0
    for r in range(local_len):
        if scores[i][r] != NEG_INF:
            denom += math.exp(scores[i][r] - row_max)

    # compute probabilities
    for r in range(local_len):
        if scores[i][r] == NEG_INF:
            attention[i][r] = 0
        else:
            attention[i][r] = math.exp(scores[i][r] - row_max) / denom

# ------------------------------------------------------------
# Step 6: Weighted sum of values
# ------------------------------------------------------------
"""
PyTorch equivalent:

    output = (attn.unsqueeze(-2) @ local_V).squeeze(-2)

Conceptually:

    output[i] = sum_r attention[i][r] * local_V[i][r]
"""
output = [[0] * v_dim for _ in range(seq_len)]
for i in range(seq_len):
    for r in range(local_len):
        weight = attention[i][r]
        for d in range(v_dim):
            output[i][d] += weight * local_V[i][r][d]

# ------------------------------------------------------------
# Final result
# ------------------------------------------------------------
"""
output shape:

    (seq_len, v_dim)

Each token representation is a weighted sum of values
from its sliding window.
"""
print("Sliding-window attention computed conceptually.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment