{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hello, GPT!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we're going to build a transformer. In particular, we'll see how to define attention and residual blocks in Modula."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's download the Shakespeare dataset. The task will be to predict the next character."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"context = 64\n",
"batch_size = 12\n",
"\n",
"from data.shakespeare import load_shakespeare\n",
"\n",
"data = load_shakespeare(context, batch_size)\n",
"\n",
"train_loader = data[\"train_loader\"]\n",
"val_loader = data[\"val_loader\"]\n",
"encode = data[\"encode\"]\n",
"decode = data[\"decode\"]"
]
},
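{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `encode` and `decode` helpers map between text and integer token ids. Here's a hypothetical round-trip sketch (the exact ids depend on the dataset's character vocabulary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical usage: encode a string to character-level token ids and decode back.\n",
"ids = encode(\"hello\")\n",
"print(ids)\n",
"print(decode(ids))  # should round-trip back to \"hello\""
]
},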
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's peek at an example to verify the data loaded correctly!\n",
"\n",
"> NOTE: Simply commenting the code below leads to NaNs!! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for inputs, targets in train_loader:\n",
"# print(\"Input shape:\", inputs.shape)\n",
"# print(\"Target shape:\", targets.shape)\n",
"# print(\"First input sequence:\", inputs[0][:10], \"...\")\n",
"# print(\"First target sequence:\", targets[0][:10], \"...\")\n",
"# print(\"\\nDecoded input:\", decode(inputs[0]))\n",
"# print(\"\\nDecoded target:\", decode(targets[0]))\n",
"# break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining the architecture\n",
"\n",
"Let's use a very small setting for our transformer so it is fast to train."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# transformer hyperparameters\n",
"\n",
"vocab_size = 65\n",
"num_heads = 4\n",
"d_embed = 128\n",
"d_query = 32\n",
"d_value = 32\n",
"num_blocks = 4\n",
"attention_scale = 1\n",
"final_scale = 1\n",
"\n",
"# training hyperparameters\n",
"\n",
"lr = 0.1\n",
"beta = 0.95\n",
"steps = 2001\n",
"log_interval = 10\n",
"val_interval = 100\n",
"val_iters = 20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" Next up, we'll define the *attention* module and *residual blocks*."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Attention in Modula\n",
"\n",
"In Modula, we'll define attention by stringing together several bond modules to do the parameterless computations. The roadmap is:\n",
"* Map `(batch, token, d_embed)` into `(batch, head, token, d_query)` (and same for key and value) via `Linear` and `SplitIntoHeads`\n",
"* Use Rotary Positional Embeddings (RoPE) on the query and the key via `Rope`\n",
"* Map `query` and `key` into attention similarities of shape `(batch, head, token, token)` via `AttentionQK`\n",
"* Use a causal mask and then softmax to create attention scores via `CausalMask` and `Softmax`\n",
"* Use the attention scores to create output vectors via `ApplyAttentionScores`, then `MergeHeads` and `Linear`\n",
"\n",
"The main difference to a standard transformer is that `AttentionQK` uses $1/d_\\text{head}$ scaling instead of the standard $1/\\sqrt{d_\\text{head}}$. The reason for this is to provide Lipschitz guarantees for attention that are independent of $d_\\text{head}$. For more information on this, see Appendix B.6 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).\n",
"\n",
"And here's the implementation:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from modula.atom import Linear\n",
"from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU\n",
"\n",
"def Attention(num_heads, d_embed, d_query, d_value, attention_scale):\n",
" \"\"\"Multi-head attention\"\"\"\n",
"\n",
" # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.\n",
" # Remember modules compose right-to-left, and the order is Linear(d_out, d_in)! And @ means compose.\n",
" Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)\n",
" K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)\n",
" V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)\n",
" W = Linear(d_embed, num_heads * d_value) @ MergeHeads()\n",
"\n",
" # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).\n",
" AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)\n",
"\n",
" # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed.\n",
" return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)"
]
},
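{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for the $1/d_\\text{head}$ scaling, here is a small standalone sketch (not part of Modula) comparing the two scalings on a single pair of random unit-variance query and key vectors:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"d_head = 32\n",
"key_q, key_k = jax.random.split(jax.random.PRNGKey(0))\n",
"q = jax.random.normal(key_q, (d_head,))\n",
"k = jax.random.normal(key_k, (d_head,))\n",
"\n",
"# The standard 1/sqrt(d_head) scaling keeps the attention logit at unit scale;\n",
"# the 1/d_head scaling shrinks it by a further factor of sqrt(d_head),\n",
"# which is what makes the Lipschitz bound independent of d_head.\n",
"print(\"1/sqrt(d_head) scaling:\", q @ k / jnp.sqrt(d_head))\n",
"print(\"1/d_head scaling:\", q @ k / d_head)"
]
},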
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that the sensitivity is 1 at initialization."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CompositeModule\n",
"...consists of 4 atoms and 10 bonds\n",
"...smooth\n",
"...input sensitivity is 1.0\n",
"...contributes proportion 4 to feature learning of any supermodule\n"
]
}
],
"source": [
"print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Residual blocks in Modula\n",
"\n",
"To implement the rest of our transformer, the roadmap is:\n",
"* Embed the input tokens\n",
"* Apply residual blocks for attention and the MLP\n",
"* Project out\n",
"\n",
"All that's left is to set up the residual blocks. In Modula, we define residual connections using a convex combination. If $L$ is the number of residual blocks, then we use a convex combination of the identity and the block to get $x \\mapsto \\frac{L-1}{L} \\cdot x + \\frac{1}{L} \\cdot \\textsf{block}(x)$. The purpose is to create a Lipschitz guarantee that is independent of the number of blocks. For more information, see Proposition 4 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).\n",
"\n",
"In short, these changes enable Lipschitz guarantees on our transformer even as we scale the width and the depth!"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from modula.abstract import Identity\n",
"from modula.atom import Embed\n",
"\n",
"def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):\n",
" # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.\n",
" embed = Embed(d_embed, vocab_size)\n",
" embed.tare()\n",
"\n",
" # Let's create attention and MLP layers. \n",
" att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)\n",
" mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)\n",
"\n",
" # For our residual connections, L = 2*num_blocks because each block has two residual connections.\n",
" att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att\n",
" mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp\n",
"\n",
" # We can use powers of a module to compose it with itself many times!\n",
" blocks = (mlp_block @ att_block) ** num_blocks\n",
"\n",
" # Set all transformer blocks to have mass 5 (by default).\n",
" # So 5/7 of the change in the network output is due to the blocks,\n",
" # and 2/7 of the change in output is due to the embedding and out projection.\n",
" blocks.tare(absolute=blocks_mass)\n",
"\n",
" out = final_scale * Linear(vocab_size, d_embed)\n",
"\n",
" return out @ blocks @ embed"
]
},
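{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check on the convex-combination construction, a single residual block built from these modules should itself report input sensitivity 1, since the combination weights $\\frac{L-1}{L}$ and $\\frac{1}{L}$ sum to 1:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One attention residual block with L = 2 * num_blocks, as inside GPT above.\n",
"L = 2 * num_blocks\n",
"demo_block = (1 - 1/L) * Identity() + 1/L * Attention(num_heads, d_embed, d_query, d_value, attention_scale)\n",
"print(demo_block)"
]
},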
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And finally we are ready to construct our GPT!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CompositeModule\n",
"...consists of 26 atoms and 78 bonds\n",
"...non-smooth\n",
"...input sensitivity is 1.0\n",
"...contributes proportion 7.0 to feature learning of any supermodule\n"
]
}
],
"source": [
"model = GPT(\n",
" vocab_size=vocab_size,\n",
" num_heads=num_heads,\n",
" d_embed=d_embed,\n",
" d_query=d_query,\n",
" d_value=d_value,\n",
" num_blocks=num_blocks,\n",
" attention_scale=attention_scale,\n",
" final_scale=final_scale,\n",
")\n",
"\n",
"model.jit()\n",
"\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loss function and training\n",
"\n",
"To train our transformer we'll use cross entropy loss, which we can compute by decomposing the softmax:\n",
"\n",
"$$\n",
"-\\log(\\text{target probability}) = -\\log(\\text{softmax}(\\text{logits})_\\text{target}) = -\\text{logit}_\\text{target} + \\text{log\\,sum\\,exp}(\\text{logits})\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def cross_entropy_loss(w, inputs, targets):\n",
" # We use the logsumexp trick for stable cross entropy\n",
" logits = model(inputs, w) # shape is [batch, seq_len, vocab_size]\n",
" batch_indices = jnp.arange(logits.shape[0])[:, None] # shape is [batch, 1]\n",
" seq_indices = jnp.arange(logits.shape[1])[None, :] # shape is [1, seq_len]\n",
" # This indexing selects out logits[b, s, targets[b, s]], which is the target logit\n",
" losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1) # shape is [batch, seq_len]\n",
" return losses.mean()\n",
"\n",
"loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))"
]
},
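{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick numerical check of the decomposition (an optional sketch), the logsumexp form agrees with reading the target entry off `jax.nn.log_softmax` directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_logits = jnp.array([1.0, 2.0, 3.0])\n",
"target = 2\n",
"decomposed = -toy_logits[target] + jax.nn.logsumexp(toy_logits)\n",
"direct = -jax.nn.log_softmax(toy_logits)[target]\n",
"print(decomposed, direct)  # both print the same value (~0.4076)"
]
},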
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we're ready to train!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 0: loss 4.211885929107666\n",
"--> val loss 4.184442043304443\n",
"Step 10: loss 3.8278777599334717\n",
"Step 20: loss 3.3490588665008545\n",
"Step 30: loss 2.8786635398864746\n",
"Step 40: loss 2.6745896339416504\n",
"Step 50: loss 2.5865437984466553\n",
"Step 60: loss 2.4899280071258545\n",
"Step 70: loss 2.440074920654297\n",
"Step 80: loss 2.4192352294921875\n",
"Step 90: loss 2.3117661476135254\n",
"Step 100: loss 2.2782480716705322\n",
"--> val loss 2.475813150405884\n",
"Step 110: loss 2.3092970848083496\n",
"Step 120: loss 2.21390962600708\n",
"Step 130: loss 2.2222862243652344\n",
"Step 140: loss 2.185128927230835\n",
"Step 150: loss 2.2528843879699707\n",
"Step 160: loss 2.1431965827941895\n",
"Step 170: loss 2.1925888061523438\n",
"Step 180: loss 2.05013370513916\n",
"Step 190: loss 2.1693406105041504\n",
"Step 200: loss 2.1141762733459473\n",
"--> val loss 2.4127883911132812\n",
"Step 210: loss 2.1186375617980957\n",
"Step 220: loss 2.0049819946289062\n",
"Step 230: loss 2.041961431503296\n",
"Step 240: loss 2.111858606338501\n",
"Step 250: loss 2.063030958175659\n",
"Step 260: loss 2.090083599090576\n",
"Step 270: loss 2.0722665786743164\n",
"Step 280: loss 2.0284886360168457\n",
"Step 290: loss 2.1009130477905273\n",
"Step 300: loss 2.0008883476257324\n",
"--> val loss 2.291325092315674\n",
"Step 310: loss 1.9710001945495605\n",
"Step 320: loss 2.068572521209717\n",
"Step 330: loss 2.0272347927093506\n",
"Step 340: loss 1.9743661880493164\n",
"Step 350: loss 1.9835456609725952\n",
"Step 360: loss 2.019561290740967\n",
"Step 370: loss 1.9405639171600342\n",
"Step 380: loss 2.068528890609741\n",
"Step 390: loss 2.02536940574646\n",
"Step 400: loss 2.03733491897583\n",
"--> val loss 2.2335243225097656\n",
"Step 410: loss 2.0036675930023193\n",
"Step 420: loss 1.9687211513519287\n",
"Step 430: loss 2.123520851135254\n",
"Step 440: loss 1.9799261093139648\n",
"Step 450: loss 2.12276029586792\n",
"Step 460: loss 1.9917802810668945\n",
"Step 470: loss 2.0934362411499023\n",
"Step 480: loss 1.9377912282943726\n",
"Step 490: loss 1.9672636985778809\n",
"Step 500: loss 1.9683382511138916\n",
"--> val loss 2.19697904586792\n",
"Step 510: loss 1.9429147243499756\n",
"Step 520: loss 1.8921594619750977\n",
"Step 530: loss 1.9347004890441895\n",
"Step 540: loss 1.896270513534546\n",
"Step 550: loss 1.9985806941986084\n",
"Step 560: loss 1.9955267906188965\n",
"Step 570: loss 1.8528060913085938\n",
"Step 580: loss 1.9202473163604736\n",
"Step 590: loss 2.0072526931762695\n",
"Step 600: loss 1.9544605016708374\n",
"--> val loss 2.307926654815674\n",
"Step 610: loss 1.9368000030517578\n",
"Step 620: loss 2.0403528213500977\n",
"Step 630: loss 1.9623057842254639\n",
"Step 640: loss 1.9277417659759521\n",
"Step 650: loss 2.0852506160736084\n",
"Step 660: loss 1.966705322265625\n",
"Step 670: loss 1.9932379722595215\n",
"Step 680: loss 1.971732497215271\n",
"Step 690: loss 1.9563077688217163\n",
"Step 700: loss 2.00211238861084\n",
"--> val loss 2.1681482791900635\n",
"Step 710: loss 2.0447287559509277\n",
"Step 720: loss 2.0726253986358643\n",
"Step 730: loss 2.017390251159668\n",
"Step 740: loss 1.847938060760498\n",
"Step 750: loss 2.018876552581787\n",
"Step 760: loss 2.0033011436462402\n",
"Step 770: loss 1.9202569723129272\n",
"Step 780: loss nan\n",
"Step 790: loss nan\n",
"Step 800: loss nan\n",
"--> val loss nan\n",
"Step 810: loss nan\n",
"Step 820: loss nan\n",
"Step 830: loss nan\n",
"Step 840: loss nan\n",
"Step 850: loss nan\n",
"Step 860: loss nan\n",
"Step 870: loss nan\n",
"Step 880: loss nan\n",
"Step 890: loss nan\n",
"Step 900: loss nan\n",
"--> val loss nan\n",
"Step 910: loss nan\n",
"Step 920: loss nan\n",
"Step 930: loss nan\n",
"Step 940: loss nan\n",
"Step 950: loss nan\n",
"Step 960: loss nan\n",
"Step 970: loss nan\n",
"Step 980: loss nan\n",
"Step 990: loss nan\n",
"Step 1000: loss nan\n",
"--> val loss nan\n",
"Step 1010: loss nan\n",
"Step 1020: loss nan\n",
"Step 1030: loss nan\n",
"Step 1040: loss nan\n",
"Step 1050: loss nan\n",
"Step 1060: loss nan\n",
"Step 1070: loss nan\n",
"Step 1080: loss nan\n",
"Step 1090: loss nan\n",
"Step 1100: loss nan\n",
"--> val loss nan\n",
"Step 1110: loss nan\n",
"Step 1120: loss nan\n",
"Step 1130: loss nan\n",
"Step 1140: loss nan\n",
"Step 1150: loss nan\n",
"Step 1160: loss nan\n",
"Step 1170: loss nan\n",
"Step 1180: loss nan\n",
"Step 1190: loss nan\n",
"Step 1200: loss nan\n",
"--> val loss nan\n",
"Step 1210: loss nan\n",
"Step 1220: loss nan\n",
"Step 1230: loss nan\n",
"Step 1240: loss nan\n",
"Step 1250: loss nan\n",
"Step 1260: loss nan\n",
"Step 1270: loss nan\n",
"Step 1280: loss nan\n",
"Step 1290: loss nan\n",
"Step 1300: loss nan\n",
"--> val loss nan\n",
"Step 1310: loss nan\n",
"Step 1320: loss nan\n",
"Step 1330: loss nan\n",
"Step 1340: loss nan\n",
"Step 1350: loss nan\n",
"Step 1360: loss nan\n",
"Step 1370: loss nan\n",
"Step 1380: loss nan\n",
"Step 1390: loss nan\n",
"Step 1400: loss nan\n",
"--> val loss nan\n",
"Step 1410: loss nan\n",
"Step 1420: loss nan\n",
"Step 1430: loss nan\n",
"Step 1440: loss nan\n",
"Step 1450: loss nan\n",
"Step 1460: loss nan\n",
"Step 1470: loss nan\n",
"Step 1480: loss nan\n",
"Step 1490: loss nan\n",
"Step 1500: loss nan\n",
"--> val loss nan\n",
"Step 1510: loss nan\n",
"Step 1520: loss nan\n",
"Step 1530: loss nan\n",
"Step 1540: loss nan\n",
"Step 1550: loss nan\n",
"Step 1560: loss nan\n",
"Step 1570: loss nan\n",
"Step 1580: loss nan\n",
"Step 1590: loss nan\n",
"Step 1600: loss nan\n",
"--> val loss nan\n",
"Step 1610: loss nan\n",
"Step 1620: loss nan\n",
"Step 1630: loss nan\n",
"Step 1640: loss nan\n",
"Step 1650: loss nan\n",
"Step 1660: loss nan\n",
"Step 1670: loss nan\n",
"Step 1680: loss nan\n",
"Step 1690: loss nan\n",
"Step 1700: loss nan\n",
"--> val loss nan\n",
"Step 1710: loss nan\n",
"Step 1720: loss nan\n",
"Step 1730: loss nan\n",
"Step 1740: loss nan\n",
"Step 1750: loss nan\n",
"Step 1760: loss nan\n",
"Step 1770: loss nan\n",
"Step 1780: loss nan\n",
"Step 1790: loss nan\n",
"Step 1800: loss nan\n",
"--> val loss nan\n",
"Step 1810: loss nan\n",
"Step 1820: loss nan\n",
"Step 1830: loss nan\n",
"Step 1840: loss nan\n",
"Step 1850: loss nan\n",
"Step 1860: loss nan\n",
"Step 1870: loss nan\n",
"Step 1880: loss nan\n",
"Step 1890: loss nan\n",
"Step 1900: loss nan\n",
"--> val loss nan\n",
"Step 1910: loss nan\n",
"Step 1920: loss nan\n",
"Step 1930: loss nan\n",
"Step 1940: loss nan\n",
"Step 1950: loss nan\n",
"Step 1960: loss nan\n",
"Step 1970: loss nan\n",
"Step 1980: loss nan\n",
"Step 1990: loss nan\n",
"Step 2000: loss nan\n",
"--> val loss nan\n"
]
}
],
"source": [
"key = jax.random.PRNGKey(0)\n",
"w = model.initialize(key)\n",
"\n",
"step = 0\n",
"momentum = [0 * weight for weight in w]\n",
"lr_schedule = lambda step: lr * (steps - step) / steps\n",
"for inputs, targets in train_loader:\n",
" loss, grad_w = loss_and_grad(w, inputs, targets)\n",
" momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]\n",
" d_w = model.dualize(momentum)\n",
" w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]\n",
"\n",
" if step % log_interval == 0:\n",
" print(f\"Step {step}: loss {loss}\")\n",
" \n",
" if step % val_interval == 0:\n",
" val_losses = []\n",
" for val_inputs, val_targets in val_loader:\n",
" loss, _ = loss_and_grad(w, val_inputs, val_targets)\n",
" val_losses.append(loss)\n",
" if len(val_losses) >= val_iters:\n",
" break\n",
" print(f\"--> val loss {sum(val_losses)/len(val_losses)}\")\n",
"\n",
" step += 1\n",
"\n",
" if step >= steps:\n",
" break"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mod",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 4
}