NaNs in modula hello GPT tutorial
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Hello, GPT!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "In this notebook, we're going to build a transformer. In particular, we'll see how to define attention and residual blocks in Modula." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Getting the data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "First, let's download the Shakespeare dataset. The task will be to predict the next character." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "context = 64\n", | |
| "batch_size = 12\n", | |
| "\n", | |
| "from data.shakespeare import load_shakespeare\n", | |
| "\n", | |
| "data = load_shakespeare(context, batch_size)\n", | |
| "\n", | |
| "train_loader = data[\"train_loader\"]\n", | |
| "val_loader = data[\"val_loader\"]\n", | |
| "encode = data[\"encode\"]\n", | |
| "decode = data[\"decode\"]" | |
| ] | |
| }, | |
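| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a quick sketch of the tokenizer interface (an addition here, assuming `encode` maps a string to a list of character ids and `decode` inverts it), we can round-trip a short string:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: round-trip a short string through the character-level tokenizer.\n", | |
| "# Assumes encode: str -> list[int] and decode: list[int] -> str.\n", | |
| "tokens = encode(\"To be, or not to be\")\n", | |
| "print(\"Token ids:\", tokens[:10], \"...\")\n", | |
| "print(\"Round-trip:\", decode(tokens))" | |
| ] | |
| }, | |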
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's peek at an example to verify the data loaded correctly!\n", | |
| "\n", | |
| "> NOTE: Simply commenting the code below leads to NaNs!! " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# for inputs, targets in train_loader:\n", | |
| "# print(\"Input shape:\", inputs.shape)\n", | |
| "# print(\"Target shape:\", targets.shape)\n", | |
| "# print(\"First input sequence:\", inputs[0][:10], \"...\")\n", | |
| "# print(\"First target sequence:\", targets[0][:10], \"...\")\n", | |
| "# print(\"\\nDecoded input:\", decode(inputs[0]))\n", | |
| "# print(\"\\nDecoded target:\", decode(targets[0]))\n", | |
| "# break" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Defining the architecture\n", | |
| "\n", | |
| "Let's use a very small setting for our transformer so it is fast to train." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# transformer hyperparameters\n", | |
| "\n", | |
| "vocab_size = 65\n", | |
| "num_heads = 4\n", | |
| "d_embed = 128\n", | |
| "d_query = 32\n", | |
| "d_value = 32\n", | |
| "num_blocks = 4\n", | |
| "attention_scale = 1\n", | |
| "final_scale = 1\n", | |
| "\n", | |
| "# training hyperparameters\n", | |
| "\n", | |
| "lr = 0.1\n", | |
| "beta = 0.95\n", | |
| "steps = 2001\n", | |
| "log_interval = 10\n", | |
| "val_interval = 100\n", | |
| "val_iters = 20" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| " Next up, we'll define the *attention* module and *residual blocks*." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Attention in Modula\n", | |
| "\n", | |
| "In Modula, we'll define attention by stringing together several bond modules to do the parameterless computations. The roadmap is:\n", | |
| "* Map `(batch, token, d_embed)` into `(batch, head, token, d_query)` (and same for key and value) via `Linear` and `SplitIntoHeads`\n", | |
| "* Use Rotary Positional Embeddings (RoPE) on the query and the key via `Rope`\n", | |
| "* Map `query` and `key` into attention similarities of shape `(batch, head, token, token)` via `AttentionQK`\n", | |
| "* Use a causal mask and then softmax to create attention scores via `CausalMask` and `Softmax`\n", | |
| "* Use the attention scores to create output vectors via `ApplyAttentionScores`, then `MergeHeads` and `Linear`\n", | |
| "\n", | |
| "The main difference to a standard transformer is that `AttentionQK` uses $1/d_\\text{head}$ scaling instead of the standard $1/\\sqrt{d_\\text{head}}$. The reason for this is to provide Lipschitz guarantees for attention that are independent of $d_\\text{head}$. For more information on this, see Appendix B.6 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).\n", | |
| "\n", | |
| "And here's the implementation:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from modula.atom import Linear\n", | |
| "from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU\n", | |
| "\n", | |
| "def Attention(num_heads, d_embed, d_query, d_value, attention_scale):\n", | |
| " \"\"\"Multi-head attention\"\"\"\n", | |
| "\n", | |
| " # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.\n", | |
| " # Remember modules compose right-to-left, and the order is Linear(d_out, d_in)! And @ means compose.\n", | |
| " Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)\n", | |
| " K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)\n", | |
| " V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)\n", | |
| " W = Linear(d_embed, num_heads * d_value) @ MergeHeads()\n", | |
| "\n", | |
| " # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).\n", | |
| " AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)\n", | |
| "\n", | |
| " # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed.\n", | |
| " return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's check that the sensitivity is 1 at initialization." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CompositeModule\n", | |
| "...consists of 4 atoms and 10 bonds\n", | |
| "...smooth\n", | |
| "...input sensitivity is 1.0\n", | |
| "...contributes proportion 4 to feature learning of any supermodule\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Residual blocks in Modula\n", | |
| "\n", | |
| "To implement the rest of our transformer, the roadmap is:\n", | |
| "* Embed the input tokens\n", | |
| "* Apply residual blocks for attention and the MLP\n", | |
| "* Project out\n", | |
| "\n", | |
| "All that's left is to set up the residual blocks. In Modula, we define residual connections using a convex combination. If $L$ is the number of residual blocks, then we use a convex combination of the identity and the block to get $x \\mapsto \\frac{L-1}{L} \\cdot x + \\frac{1}{L} \\cdot \\textsf{block}(x)$. The purpose is to create a Lipschitz guarantee that is independent of the number of blocks. For more information, see Proposition 4 of [Scalable Optimization in the Modular Norm](https://arxiv.org/pdf/2405.14813).\n", | |
| "\n", | |
| "In short, these changes enable Lipschitz guarantees on our transformer even as we scale the width and the depth!" | |
| ] | |
| }, | |
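| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a quick sanity check (a short derivation added here, in the spirit of Proposition 4): if the block is 1-Lipschitz, i.e. $\|\textsf{block}(x) - \textsf{block}(y)\| \leq \|x - y\|$, then by the triangle inequality\n", | |
| "\n", | |
| "$$\n", | |
| "\left\| \tfrac{L-1}{L}(x - y) + \tfrac{1}{L}\big(\textsf{block}(x) - \textsf{block}(y)\big) \right\| \leq \tfrac{L-1}{L}\|x - y\| + \tfrac{1}{L}\|x - y\| = \|x - y\|,\n", | |
| "$$\n", | |
| "\n", | |
| "so each residual block is itself 1-Lipschitz, and composing any number of them keeps the bound independent of depth." | |
| ] | |
| }, | |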
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from modula.abstract import Identity\n", | |
| "from modula.atom import Embed\n", | |
| "\n", | |
| "def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):\n", | |
| " # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.\n", | |
| " embed = Embed(d_embed, vocab_size)\n", | |
| " embed.tare()\n", | |
| "\n", | |
| " # Let's create attention and MLP layers. \n", | |
| " att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)\n", | |
| " mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)\n", | |
| "\n", | |
| " # For our residual connections, L = 2*num_blocks because each block has two residual connections.\n", | |
| " att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att\n", | |
| " mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp\n", | |
| "\n", | |
| " # We can use powers of a module to compose it with itself many times!\n", | |
| " blocks = (mlp_block @ att_block) ** num_blocks\n", | |
| "\n", | |
| " # Set all transformer blocks to have mass 5 (by default).\n", | |
| " # So 5/7 of the change in the network output is due to the blocks,\n", | |
| " # and 2/7 of the change in output is due to the embedding and out projection.\n", | |
| " blocks.tare(absolute=blocks_mass)\n", | |
| "\n", | |
| " out = final_scale * Linear(vocab_size, d_embed)\n", | |
| "\n", | |
| " return out @ blocks @ embed" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "And finally we are ready to construct our GPT!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CompositeModule\n", | |
| "...consists of 26 atoms and 78 bonds\n", | |
| "...non-smooth\n", | |
| "...input sensitivity is 1.0\n", | |
| "...contributes proportion 7.0 to feature learning of any supermodule\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model = GPT(\n", | |
| " vocab_size=vocab_size,\n", | |
| " num_heads=num_heads,\n", | |
| " d_embed=d_embed,\n", | |
| " d_query=d_query,\n", | |
| " d_value=d_value,\n", | |
| " num_blocks=num_blocks,\n", | |
| " attention_scale=attention_scale,\n", | |
| " final_scale=final_scale,\n", | |
| ")\n", | |
| "\n", | |
| "model.jit()\n", | |
| "\n", | |
| "print(model)" | |
| ] | |
| }, | |
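| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Before defining the loss, here is a minimal sketch (added here, using a dummy batch of token ids so the data loaders are left untouched) that checks a forward pass returns logits of shape `(batch_size, context, vocab_size)`:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "# Sketch: forward pass on a dummy batch (token id 0 is valid since vocab_size = 65).\n", | |
| "_key = jax.random.PRNGKey(0)\n", | |
| "_w = model.initialize(_key)\n", | |
| "_dummy_inputs = jnp.zeros((batch_size, context), dtype=jnp.int32)\n", | |
| "_logits = model(_dummy_inputs, _w)\n", | |
| "print(\"Logits shape:\", _logits.shape)  # expect (batch_size, context, vocab_size)" | |
| ] | |
| }, | |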
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Loss function and training\n", | |
| "\n", | |
| "To train our transformer we'll use cross entropy loss, which we can compute by decomposing the softmax:\n", | |
| "\n", | |
| "$$\n", | |
| "-\\log(\\text{target probability}) = -\\log(\\text{softmax}(\\text{logits})_\\text{target}) = -\\text{logit}_\\text{target} + \\text{log\\,sum\\,exp}(\\text{logits})\n", | |
| "$$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "def cross_entropy_loss(w, inputs, targets):\n", | |
| " # We use the logsumexp trick for stable cross entropy\n", | |
| " logits = model(inputs, w) # shape is [batch, seq_len, vocab_size]\n", | |
| " batch_indices = jnp.arange(logits.shape[0])[:, None] # shape is [batch, 1]\n", | |
| " seq_indices = jnp.arange(logits.shape[1])[None, :] # shape is [1, seq_len]\n", | |
| " # This indexing selects out logits[b, s, targets[b, s]], which is the target logit\n", | |
| " losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1) # shape is [batch, seq_len]\n", | |
| " return losses.mean()\n", | |
| "\n", | |
| "loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))" | |
| ] | |
| }, | |
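| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Since this notebook is chasing NaNs, here is a small sketch (added here, again on a dummy batch so the loaders are untouched) that checks the loss and every gradient are finite at initialization:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: verify that the loss and all gradient tensors are finite at initialization.\n", | |
| "_key = jax.random.PRNGKey(0)\n", | |
| "_w = model.initialize(_key)\n", | |
| "_dummy_inputs = jnp.zeros((batch_size, context), dtype=jnp.int32)\n", | |
| "_dummy_targets = jnp.zeros((batch_size, context), dtype=jnp.int32)\n", | |
| "_loss, _grads = loss_and_grad(_w, _dummy_inputs, _dummy_targets)\n", | |
| "print(\"Loss is finite:\", bool(jnp.isfinite(_loss)))\n", | |
| "print(\"All grads finite:\", all(bool(jnp.all(jnp.isfinite(g))) for g in _grads))" | |
| ] | |
| }, | |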
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "And we're ready to train!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Step 0: loss 4.211885929107666\n", | |
| "--> val loss 4.184442043304443\n", | |
| "Step 10: loss 3.8278777599334717\n", | |
| "Step 20: loss 3.3490588665008545\n", | |
| "Step 30: loss 2.8786635398864746\n", | |
| "Step 40: loss 2.6745896339416504\n", | |
| "Step 50: loss 2.5865437984466553\n", | |
| "Step 60: loss 2.4899280071258545\n", | |
| "Step 70: loss 2.440074920654297\n", | |
| "Step 80: loss 2.4192352294921875\n", | |
| "Step 90: loss 2.3117661476135254\n", | |
| "Step 100: loss 2.2782480716705322\n", | |
| "--> val loss 2.475813150405884\n", | |
| "Step 110: loss 2.3092970848083496\n", | |
| "Step 120: loss 2.21390962600708\n", | |
| "Step 130: loss 2.2222862243652344\n", | |
| "Step 140: loss 2.185128927230835\n", | |
| "Step 150: loss 2.2528843879699707\n", | |
| "Step 160: loss 2.1431965827941895\n", | |
| "Step 170: loss 2.1925888061523438\n", | |
| "Step 180: loss 2.05013370513916\n", | |
| "Step 190: loss 2.1693406105041504\n", | |
| "Step 200: loss 2.1141762733459473\n", | |
| "--> val loss 2.4127883911132812\n", | |
| "Step 210: loss 2.1186375617980957\n", | |
| "Step 220: loss 2.0049819946289062\n", | |
| "Step 230: loss 2.041961431503296\n", | |
| "Step 240: loss 2.111858606338501\n", | |
| "Step 250: loss 2.063030958175659\n", | |
| "Step 260: loss 2.090083599090576\n", | |
| "Step 270: loss 2.0722665786743164\n", | |
| "Step 280: loss 2.0284886360168457\n", | |
| "Step 290: loss 2.1009130477905273\n", | |
| "Step 300: loss 2.0008883476257324\n", | |
| "--> val loss 2.291325092315674\n", | |
| "Step 310: loss 1.9710001945495605\n", | |
| "Step 320: loss 2.068572521209717\n", | |
| "Step 330: loss 2.0272347927093506\n", | |
| "Step 340: loss 1.9743661880493164\n", | |
| "Step 350: loss 1.9835456609725952\n", | |
| "Step 360: loss 2.019561290740967\n", | |
| "Step 370: loss 1.9405639171600342\n", | |
| "Step 380: loss 2.068528890609741\n", | |
| "Step 390: loss 2.02536940574646\n", | |
| "Step 400: loss 2.03733491897583\n", | |
| "--> val loss 2.2335243225097656\n", | |
| "Step 410: loss 2.0036675930023193\n", | |
| "Step 420: loss 1.9687211513519287\n", | |
| "Step 430: loss 2.123520851135254\n", | |
| "Step 440: loss 1.9799261093139648\n", | |
| "Step 450: loss 2.12276029586792\n", | |
| "Step 460: loss 1.9917802810668945\n", | |
| "Step 470: loss 2.0934362411499023\n", | |
| "Step 480: loss 1.9377912282943726\n", | |
| "Step 490: loss 1.9672636985778809\n", | |
| "Step 500: loss 1.9683382511138916\n", | |
| "--> val loss 2.19697904586792\n", | |
| "Step 510: loss 1.9429147243499756\n", | |
| "Step 520: loss 1.8921594619750977\n", | |
| "Step 530: loss 1.9347004890441895\n", | |
| "Step 540: loss 1.896270513534546\n", | |
| "Step 550: loss 1.9985806941986084\n", | |
| "Step 560: loss 1.9955267906188965\n", | |
| "Step 570: loss 1.8528060913085938\n", | |
| "Step 580: loss 1.9202473163604736\n", | |
| "Step 590: loss 2.0072526931762695\n", | |
| "Step 600: loss 1.9544605016708374\n", | |
| "--> val loss 2.307926654815674\n", | |
| "Step 610: loss 1.9368000030517578\n", | |
| "Step 620: loss 2.0403528213500977\n", | |
| "Step 630: loss 1.9623057842254639\n", | |
| "Step 640: loss 1.9277417659759521\n", | |
| "Step 650: loss 2.0852506160736084\n", | |
| "Step 660: loss 1.966705322265625\n", | |
| "Step 670: loss 1.9932379722595215\n", | |
| "Step 680: loss 1.971732497215271\n", | |
| "Step 690: loss 1.9563077688217163\n", | |
| "Step 700: loss 2.00211238861084\n", | |
| "--> val loss 2.1681482791900635\n", | |
| "Step 710: loss 2.0447287559509277\n", | |
| "Step 720: loss 2.0726253986358643\n", | |
| "Step 730: loss 2.017390251159668\n", | |
| "Step 740: loss 1.847938060760498\n", | |
| "Step 750: loss 2.018876552581787\n", | |
| "Step 760: loss 2.0033011436462402\n", | |
| "Step 770: loss 1.9202569723129272\n", | |
| "Step 780: loss nan\n", | |
| "Step 790: loss nan\n", | |
| "Step 800: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 810: loss nan\n", | |
| "Step 820: loss nan\n", | |
| "Step 830: loss nan\n", | |
| "Step 840: loss nan\n", | |
| "Step 850: loss nan\n", | |
| "Step 860: loss nan\n", | |
| "Step 870: loss nan\n", | |
| "Step 880: loss nan\n", | |
| "Step 890: loss nan\n", | |
| "Step 900: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 910: loss nan\n", | |
| "Step 920: loss nan\n", | |
| "Step 930: loss nan\n", | |
| "Step 940: loss nan\n", | |
| "Step 950: loss nan\n", | |
| "Step 960: loss nan\n", | |
| "Step 970: loss nan\n", | |
| "Step 980: loss nan\n", | |
| "Step 990: loss nan\n", | |
| "Step 1000: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1010: loss nan\n", | |
| "Step 1020: loss nan\n", | |
| "Step 1030: loss nan\n", | |
| "Step 1040: loss nan\n", | |
| "Step 1050: loss nan\n", | |
| "Step 1060: loss nan\n", | |
| "Step 1070: loss nan\n", | |
| "Step 1080: loss nan\n", | |
| "Step 1090: loss nan\n", | |
| "Step 1100: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1110: loss nan\n", | |
| "Step 1120: loss nan\n", | |
| "Step 1130: loss nan\n", | |
| "Step 1140: loss nan\n", | |
| "Step 1150: loss nan\n", | |
| "Step 1160: loss nan\n", | |
| "Step 1170: loss nan\n", | |
| "Step 1180: loss nan\n", | |
| "Step 1190: loss nan\n", | |
| "Step 1200: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1210: loss nan\n", | |
| "Step 1220: loss nan\n", | |
| "Step 1230: loss nan\n", | |
| "Step 1240: loss nan\n", | |
| "Step 1250: loss nan\n", | |
| "Step 1260: loss nan\n", | |
| "Step 1270: loss nan\n", | |
| "Step 1280: loss nan\n", | |
| "Step 1290: loss nan\n", | |
| "Step 1300: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1310: loss nan\n", | |
| "Step 1320: loss nan\n", | |
| "Step 1330: loss nan\n", | |
| "Step 1340: loss nan\n", | |
| "Step 1350: loss nan\n", | |
| "Step 1360: loss nan\n", | |
| "Step 1370: loss nan\n", | |
| "Step 1380: loss nan\n", | |
| "Step 1390: loss nan\n", | |
| "Step 1400: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1410: loss nan\n", | |
| "Step 1420: loss nan\n", | |
| "Step 1430: loss nan\n", | |
| "Step 1440: loss nan\n", | |
| "Step 1450: loss nan\n", | |
| "Step 1460: loss nan\n", | |
| "Step 1470: loss nan\n", | |
| "Step 1480: loss nan\n", | |
| "Step 1490: loss nan\n", | |
| "Step 1500: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1510: loss nan\n", | |
| "Step 1520: loss nan\n", | |
| "Step 1530: loss nan\n", | |
| "Step 1540: loss nan\n", | |
| "Step 1550: loss nan\n", | |
| "Step 1560: loss nan\n", | |
| "Step 1570: loss nan\n", | |
| "Step 1580: loss nan\n", | |
| "Step 1590: loss nan\n", | |
| "Step 1600: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1610: loss nan\n", | |
| "Step 1620: loss nan\n", | |
| "Step 1630: loss nan\n", | |
| "Step 1640: loss nan\n", | |
| "Step 1650: loss nan\n", | |
| "Step 1660: loss nan\n", | |
| "Step 1670: loss nan\n", | |
| "Step 1680: loss nan\n", | |
| "Step 1690: loss nan\n", | |
| "Step 1700: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1710: loss nan\n", | |
| "Step 1720: loss nan\n", | |
| "Step 1730: loss nan\n", | |
| "Step 1740: loss nan\n", | |
| "Step 1750: loss nan\n", | |
| "Step 1760: loss nan\n", | |
| "Step 1770: loss nan\n", | |
| "Step 1780: loss nan\n", | |
| "Step 1790: loss nan\n", | |
| "Step 1800: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1810: loss nan\n", | |
| "Step 1820: loss nan\n", | |
| "Step 1830: loss nan\n", | |
| "Step 1840: loss nan\n", | |
| "Step 1850: loss nan\n", | |
| "Step 1860: loss nan\n", | |
| "Step 1870: loss nan\n", | |
| "Step 1880: loss nan\n", | |
| "Step 1890: loss nan\n", | |
| "Step 1900: loss nan\n", | |
| "--> val loss nan\n", | |
| "Step 1910: loss nan\n", | |
| "Step 1920: loss nan\n", | |
| "Step 1930: loss nan\n", | |
| "Step 1940: loss nan\n", | |
| "Step 1950: loss nan\n", | |
| "Step 1960: loss nan\n", | |
| "Step 1970: loss nan\n", | |
| "Step 1980: loss nan\n", | |
| "Step 1990: loss nan\n", | |
| "Step 2000: loss nan\n", | |
| "--> val loss nan\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "key = jax.random.PRNGKey(0)\n", | |
| "w = model.initialize(key)\n", | |
| "\n", | |
| "step = 0\n", | |
| "momentum = [0 * weight for weight in w]\n", | |
| "lr_schedule = lambda step: lr * (steps - step) / steps\n", | |
| "for inputs, targets in train_loader:\n", | |
| " loss, grad_w = loss_and_grad(w, inputs, targets)\n", | |
| " momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]\n", | |
| " d_w = model.dualize(momentum)\n", | |
| " w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]\n", | |
| "\n", | |
| " if step % log_interval == 0:\n", | |
| " print(f\"Step {step}: loss {loss}\")\n", | |
| " \n", | |
| " if step % val_interval == 0:\n", | |
| " val_losses = []\n", | |
| " for val_inputs, val_targets in val_loader:\n", | |
| " loss, _ = loss_and_grad(w, val_inputs, val_targets)\n", | |
| " val_losses.append(loss)\n", | |
| " if len(val_losses) >= val_iters:\n", | |
| " break\n", | |
| " print(f\"--> val loss {sum(val_losses)/len(val_losses)}\")\n", | |
| "\n", | |
| " step += 1\n", | |
| "\n", | |
| " if step >= steps:\n", | |
| " break" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "mod", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.16" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |