Created
February 14, 2026 13:00
-
-
Save mushtaq/c4439e3e76d4ef5981dcefba6831fca5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| //Scala port of MicroGPT gist by @karpathy | |
| import java.nio.file.{Files, Paths} | |
| import scala.collection.mutable | |
| import scala.jdk.CollectionConverters.* | |
extension (r: util.Random)
  /** Weighted random selection of a single element — a port of Python's
    * `random.choices(population, weights)[0]`.
    *
    * Draws one uniform sample in [0, 1) and returns the first element whose
    * cumulative weight reaches it; falls back to the last element if floating
    * point rounding leaves the cumulative total just short of the sample.
    */
  def choices[T](population: Seq[T], weights: Seq[Double]): T =
    val target = r.nextDouble()
    // running partial sums of the weights, one per element
    val cumulative = weights.scanLeft(0.0)(_ + _).drop(1)
    population.iterator
      .zip(cumulative.iterator)
      .collectFirst { case (item, cumWeight) if cumWeight >= target => item }
      .getOrElse(population.last)
// Implicitly lift a raw Double into a graph node, so mixed expressions such as
// `value + 1e-5` or `value * (1.0 / n)` below build autograd Values automatically.
// NOTE(review): using a given Conversion emits a feature warning in Scala 3 unless
// `scala.language.implicitConversions` is enabled — confirm the build settings.
given Conversion[Double, Value] = Value(_)
// Let there be Autograd, to recursively apply the chain rule through a computation graph
/** A scalar node in the autograd computation graph.
  *
  * Every arithmetic operation returns a new `Value` that records its operand nodes
  * (`children`) and the local derivatives with respect to them (`grads`);
  * `backward()` then propagates d(loss)/d(node) through the whole graph by the
  * chain rule. Reference equality/hashing is intentional: two nodes holding the
  * same number are still distinct graph nodes.
  */
class Value(
    var data: Double, // scalar value of this node calculated during forward pass
    val children: Seq[Value] = Seq(), // children of this node in the computation graph
    val grads: Seq[Double] = Seq() // local derivative of this node w.r.t. its children
):
  var grad: Double = 0.0 // derivative of the loss w.r.t. this node, calculated in backward pass

  def +(other: Value): Value = Value(data + other.data, Seq(this, other), Seq(1.0, 1.0))
  def *(other: Value): Value = Value(data * other.data, Seq(this, other), Seq(other.data, data))
  // d(x^c)/dx = c * x^(c-1); NaN for negative base with fractional exponent, as in math.pow
  def ^(other: Double): Value = Value(math.pow(data, other), Seq(this), Seq(other * math.pow(data, other - 1)))
  def log: Value = Value(math.log(data), Seq(this), Seq(1.0 / data))
  def exp: Value = Value(math.exp(data), Seq(this), Seq(math.exp(data)))
  def relu: Value = Value(math.max(0, data), Seq(this), Seq(if data > 0 then 1.0 else 0.0))
  def unary_- : Value = this * Value(-1)
  def -(other: Value): Value = this + (-other)
  def /(other: Value): Value = this * (other ^ -1)

  /** Backpropagation: seeds this node's `grad` with 1.0 and accumulates gradients
    * into every node reachable through `children`, visiting nodes in reverse
    * topological order so each node's grad is complete before it is propagated.
    * Gradients accumulate across calls; callers must zero them between steps.
    */
  def backward(): Unit =
    // FIX: the topological sort is now an iterative post-order DFS with an explicit
    // stack; the original recursive buildTopo overflowed the JVM call stack on deep
    // graphs (e.g. the long op chains produced by a per-document loss).
    var topo = List[Value]()
    val visited = mutable.Set[Value]()
    val stack = mutable.Stack[(Value, Boolean)]((this, false))
    while stack.nonEmpty do
      val (node, childrenDone) = stack.pop()
      if childrenDone then
        topo = node :: topo // post-order: prepend only after the whole subtree is done
      else if !visited.contains(node) then
        visited.add(node)
        stack.push((node, true)) // revisit marker: emit node after its children
        node.children.foreach(child => stack.push((child, false)))
    grad = 1.0 // d(loss)/d(loss) = 1
    for v <- topo do
      for (child, local_grad) <- v.children.zip(v.grads) do
        child.grad += local_grad * v.grad
/** Train a character-level micro-GPT on the makemore `names.txt` dataset and then
  * sample new, hallucinated names from it.
  *
  * Pipeline: download/read data -> tokenize -> init params -> define model ->
  * train with Adam -> sample. Follows GPT-2, blessed among the GPTs, with minor
  * differences: layernorm -> rmsNorm, no biases, GeLU -> ReLU.
  */
@main def runMicroGPT(): Unit =
  val random = util.Random(42) // Let there be order among chaos

  // Let there be an input dataset `docs`: one document (name) per line.
  // Downloads names.txt on first run, then reuses the local copy.
  val inputPath = Paths.get("input.txt")
  val namesURI = java.net.URI("https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt")
  if !Files.exists(inputPath) then
    util.Using.resource(namesURI.toURL.openStream()): stream =>
      Files.copy(stream, inputPath)
  val docs = Files.readAllLines(inputPath).asScala.map(_.trim).filter(_.nonEmpty).toSeq // list[str] of documents
  // FIX: was `util.Random.shuffle(docs)` — the unseeded global singleton — which
  // defeated the fixed seed above; shuffle with the seeded `random` for reproducibility.
  val shuffledDocs = random.shuffle(docs)
  println(s"num docs: ${shuffledDocs.size}")

  // Let there be a Tokenizer to translate strings to discrete symbols and back
  val uChars = shuffledDocs.mkString.toSet.toList.sorted // unique characters in the dataset become token ids 0..n-1
  val BOS = uChars.size // token id for the special Beginning of Sequence (BOS) token
  val vocabSize = uChars.size + 1 // total number of unique tokens, +1 is for BOS
  println(s"vocab size: $vocabSize")

  // Model hyperparameters.
  val nEmbd = 16 // embedding dimension
  val nHead = 4 // number of attention heads
  val nLayer = 1 // number of layers
  val blockSize = 16 // maximum sequence length
  val headDim = nEmbd / nHead // dimension of each head

  // Initialize the parameters, to store the knowledge of the model.
  // A `matrix(nOut, nIn)` is nOut rows of nIn Gaussian-initialized Values.
  def matrix(nOut: Int, nIn: Int, std: Double = 0.08): Array[Array[Value]] =
    Array.fill(nOut)(Array.fill(nIn)(Value(random.nextGaussian() * std)))
  val stateDict = mutable.Map[String, Array[Array[Value]]]()
  stateDict("wte") = matrix(vocabSize, nEmbd) // token embeddings
  stateDict("wpe") = matrix(blockSize, nEmbd) // position embeddings
  stateDict("lm_head") = matrix(vocabSize, nEmbd) // final projection to vocab logits
  for i <- 0 until nLayer do
    stateDict(s"layer$i.attn_wq") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.attn_wk") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.attn_wv") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.attn_wo") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.mlp_fc1") = matrix(4 * nEmbd, nEmbd)
    stateDict(s"layer$i.mlp_fc2") = matrix(nEmbd, 4 * nEmbd)
  val params = stateDict.values.flatten.flatten.toSeq // flatten params into a single list[Value]
  println(s"num params: ${params.size}")

  // Define the model architecture: a stateless function mapping token sequence and
  // parameters to logits over what comes next.
  // Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsNorm, no biases, GeLU -> ReLU
  def linear(x: Seq[Value], w: Array[Array[Value]]): Seq[Value] =
    w.map(wo => wo.zip(x).map(_ * _).reduce(_ + _)).toSeq
  def softMax(logits: Seq[Value]): Seq[Value] =
    val maxVal = logits.map(_.data).max // subtract max for numerical stability
    val exps = logits.map(v => (v - maxVal).exp)
    val total = exps.reduce(_ + _)
    exps.map(_ / total)
  def rmsNorm(x: Seq[Value]): Seq[Value] =
    val ms = x.map(xi => xi * xi).reduce(_ + _) / x.size.toDouble
    val scale = (ms + 1e-5) ^ -0.5
    x.map(_ * scale)
  // One forward step: consumes one token at position posId, appends this step's
  // per-layer keys/values to the KV caches, and returns logits for the next token.
  def gpt(tokenId: Int, posId: Int, keys: Array[mutable.ListBuffer[Seq[Value]]], values: Array[mutable.ListBuffer[Seq[Value]]]): Seq[Value] =
    val tokEmbRow = stateDict("wte")(tokenId).toSeq // token embedding
    val posEmbRow = stateDict("wpe")(posId).toSeq // position embedding
    var x = tokEmbRow.zip(posEmbRow).map(_ + _) // joint token and position embedding
    x = rmsNorm(x)
    for li <- 0 until nLayer do
      // 1) Multi-head attention block
      val xResidual = x
      x = rmsNorm(x)
      val q = linear(x, stateDict(s"layer$li.attn_wq"))
      val k = linear(x, stateDict(s"layer$li.attn_wk"))
      val v = linear(x, stateDict(s"layer$li.attn_wv"))
      keys(li).append(k)
      values(li).append(v)
      val xAttn = mutable.ListBuffer[Value]()
      for h <- 0 until nHead do
        val hs = h * headDim
        val qH = q.slice(hs, hs + headDim)
        val kH = keys(li).map(_.slice(hs, hs + headDim))
        val vH = values(li).map(_.slice(hs, hs + headDim))
        // scaled dot-product attention over all cached positions (causal by construction)
        val attnLogits = kH.map(kt => qH.zip(kt).map(_ * _).reduce(_ + _) / math.sqrt(headDim)).toSeq
        val attnWeights = softMax(attnLogits)
        val headOut = (0 until headDim).map { j =>
          attnWeights.zip(vH).map((att, vh) => att * vh(j)).reduce(_ + _)
        }
        xAttn.appendAll(headOut)
      x = linear(xAttn.toSeq, stateDict(s"layer$li.attn_wo"))
      x = x.zip(xResidual).map(_ + _) // residual connection
      // 2) MLP block
      val xResidual2 = x
      x = rmsNorm(x)
      x = linear(x, stateDict(s"layer$li.mlp_fc1"))
      x = x.map(_.relu)
      x = linear(x, stateDict(s"layer$li.mlp_fc2"))
      x = x.zip(xResidual2).map(_ + _) // residual connection
    linear(x, stateDict("lm_head"))

  // Let there be Adam, the blessed optimizer and its buffers
  val learningRate = 0.01
  val beta1 = 0.85
  val beta2 = 0.99
  val epsAdam = 1e-8
  val m = Array.fill(params.size)(0.0) // first moment buffer
  val v = Array.fill(params.size)(0.0) // second moment buffer

  // Repeat in sequence: one document per training step.
  val numSteps = 1000 // number of training steps
  for step <- 0 until numSteps do
    // Take single document, tokenize it, surround it with BOS special token on both sides
    val doc = shuffledDocs(step % shuffledDocs.size)
    val tokens = BOS +: doc.map(ch => uChars.indexOf(ch)) :+ BOS
    val n = math.min(blockSize, tokens.length - 1)
    // Forward the token sequence through the model, building up the computation graph all the way to the loss.
    val keys = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    val values = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    val losses = mutable.ListBuffer[Value]()
    for posId <- 0 until n do
      val tokenId = tokens(posId)
      val targetId = tokens(posId + 1)
      val logits = gpt(tokenId, posId, keys, values)
      val probs = softMax(logits)
      val lossT = -probs(targetId).log // cross-entropy at this position
      losses.append(lossT)
    val loss = losses.reduce(_ + _) * (1.0 / n) // final average loss over the document sequence. May yours be low.
    // Backward the loss, calculating the gradients with respect to all model parameters.
    loss.backward()
    // Adam optimizer update: update the model parameters based on the corresponding gradients.
    val lrT = learningRate * (1 - step.toDouble / numSteps) // linear learning rate decay
    for (p, i) <- params.zipWithIndex do
      m(i) = beta1 * m(i) + (1 - beta1) * p.grad
      v(i) = beta2 * v(i) + (1 - beta2) * (p.grad * p.grad)
      val mHat = m(i) / (1 - math.pow(beta1, step + 1)) // bias-corrected first moment
      val vHat = v(i) / (1 - math.pow(beta2, step + 1)) // bias-corrected second moment
      p.data -= lrT * mHat / (math.sqrt(vHat) + epsAdam)
      p.grad = 0.0 // reset gradient for the next step
    if (step + 1) % 10 == 0 then
      println(f"step ${step+1}%4d / $numSteps | loss ${loss.data}%.4f")

  // Inference: may the model babble back to us
  println("\n--- inference (new, hallucinated names) ---")
  val temperature = 0.5 // in (0, 1], control the "creativity" of generated text, low to high
  for sampleIdx <- 0 until 20 do
    val keys = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    val values = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    var tokenId = BOS
    val sample = mutable.StringBuilder()
    var continue = true
    var posId = 0
    while continue && posId < blockSize do
      val logits = gpt(tokenId, posId, keys, values)
      val probsSeq = softMax(logits.map(l => l / temperature)).map(_.data) // temperature-scaled distribution
      tokenId = random.choices(0 until vocabSize, probsSeq)
      if tokenId == BOS then
        continue = false // model chose to end the name
      else
        sample.append(uChars(tokenId))
        posId += 1
    println(f"sample ${sampleIdx+1}%2d: ${sample.toString}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment