Skip to content

Instantly share code, notes, and snippets.

@mushtaq
Created February 14, 2026 13:00
Show Gist options
  • Select an option

  • Save mushtaq/c4439e3e76d4ef5981dcefba6831fca5 to your computer and use it in GitHub Desktop.

Select an option

Save mushtaq/c4439e3e76d4ef5981dcefba6831fca5 to your computer and use it in GitHub Desktop.
//Scala port of MicroGPT gist by @karpathy
import java.nio.file.{Files, Paths}
import scala.collection.mutable
import scala.jdk.CollectionConverters.*
extension (r: util.Random)
  /** Weighted random selection, port of Python's `random.choices` (k=1).
    *
    * Walks the cumulative weight sum and returns the first element whose
    * cumulative weight reaches a uniformly drawn threshold. The threshold is
    * scaled by the total weight, so weights need NOT sum to 1 (the original
    * implementation silently assumed normalized weights).
    *
    * @param population candidate elements (must be non-empty)
    * @param weights    non-negative relative weights, one per element
    * @return one element, drawn with probability proportional to its weight
    */
  def choices[T](population: Seq[T], weights: Seq[Double]): T =
    val threshold = r.nextDouble() * weights.sum // scale: supports unnormalized weights
    var cumulative = 0.0
    population.zip(weights).find { (_, w) =>
      cumulative += w
      cumulative >= threshold
    }.map(_._1).getOrElse(population.last) // fallback guards float round-off at the top end
// Implicit widening: lets plain Double (and Int) literals appear in Value arithmetic,
// e.g. `x / 2.0` or `ms + 1e-5` — each conversion wraps the number as a constant graph node
given Conversion[Double, Value] = Value(_)
// Let there be Autograd, to recursively apply the chain rule through a computation graph.
// NOTE: deliberately NOT a case class — reference equality is what makes the visited
// set in backward() treat each graph node as distinct.
class Value(
    var data: Double, // scalar value of this node, computed during the forward pass
    val children: Seq[Value] = Seq(), // graph inputs this node was computed from
    val grads: Seq[Double] = Seq() // local derivative of this node w.r.t. each child
):
  // Derivative of the loss w.r.t. this node; populated by backward()
  var grad: Double = 0.0

  // Each primitive op records its inputs and local derivatives for the backward pass.
  def +(other: Value): Value = Value(data + other.data, Seq(this, other), Seq(1.0, 1.0))
  def *(other: Value): Value = Value(data * other.data, Seq(this, other), Seq(other.data, data))
  def ^(other: Double): Value = Value(math.pow(data, other), Seq(this), Seq(other * math.pow(data, other - 1)))
  def log: Value = Value(math.log(data), Seq(this), Seq(1.0 / data))
  def exp: Value = Value(math.exp(data), Seq(this), Seq(math.exp(data)))
  def relu: Value = Value(math.max(0, data), Seq(this), Seq(if data > 0 then 1.0 else 0.0))
  // Derived ops, composed from the primitives above
  def unary_- : Value = this * Value(-1)
  def -(other: Value): Value = this + (-other)
  def /(other: Value): Value = this * (other ^ -1)

  /** Backpropagate from this node: depth-first topological sort of the graph,
    * then push gradients from each node into its children via the chain rule.
    */
  def backward(): Unit =
    val seen = mutable.Set[Value]()
    val ordering = mutable.ListBuffer[Value]() // output-first (reverse topological) order
    def visit(node: Value): Unit =
      if seen.add(node) then // Set.add is false when the node was already visited
        node.children.foreach(visit)
        ordering.prepend(node)
    visit(this)
    grad = 1.0 // seed: d(this)/d(this) = 1
    ordering.foreach: node =>
      node.children.lazyZip(node.grads).foreach: (child, localGrad) =>
        child.grad += localGrad * node.grad // += because a node may feed several ops
/** Scala port of @karpathy's MicroGPT: trains a tiny character-level GPT
  * (1 transformer layer, 4 heads, scalar autograd via `Value`) on the makemore
  * names dataset, then samples new hallucinated names from it.
  */
@main def runMicroGPT(): Unit =
  val random = util.Random(42) // Let there be order among chaos: one seeded RNG for the whole run
  // Let there be an input dataset `docs`: list[str] of documents (a dataset of names),
  // downloaded once and cached next to the program as input.txt
  val inputPath = Paths.get("input.txt")
  val namesURI = java.net.URI("https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt")
  if !Files.exists(inputPath) then
    util.Using.resource(namesURI.toURL.openStream()): stream =>
      Files.copy(stream, inputPath)
  val docs = Files.readAllLines(inputPath).asScala.map(_.trim).filter(_.nonEmpty).toSeq // list[str] of documents
  // FIX: shuffle with the seeded `random` instance — `util.Random.shuffle` used the
  // global, unseeded RNG, which silently broke run-to-run reproducibility.
  val shuffledDocs = random.shuffle(docs)
  println(s"num docs: ${shuffledDocs.size}")
  // Let there be a Tokenizer to translate strings to discrete symbols and back
  val uChars = shuffledDocs.mkString.toSet.toList.sorted // unique characters become token ids 0..n-1
  val BOS = uChars.size // token id for the special Beginning of Sequence (BOS) token
  val vocabSize = uChars.size + 1 // total number of unique tokens, +1 is for BOS
  println(s"vocab size: $vocabSize")
  // Initialize the parameters, to store the knowledge of the model.
  val nEmbd = 16 // embedding dimension
  val nHead = 4 // number of attention heads
  val nLayer = 1 // number of layers
  val blockSize = 16 // maximum sequence length
  val headDim = nEmbd / nHead // dimension of each head
  // Gaussian init with a small std, so initial logits stay near zero
  def matrix(nOut: Int, nIn: Int, std: Double = 0.08): Array[Array[Value]] =
    Array.fill(nOut)(Array.fill(nIn)(Value(random.nextGaussian() * std)))
  val stateDict = mutable.Map[String, Array[Array[Value]]]()
  stateDict("wte") = matrix(vocabSize, nEmbd) // token embeddings
  stateDict("wpe") = matrix(blockSize, nEmbd) // position embeddings
  stateDict("lm_head") = matrix(vocabSize, nEmbd) // final projection to vocab logits
  for i <- 0 until nLayer do
    stateDict(s"layer$i.attn_wq") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.attn_wk") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.attn_wv") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.attn_wo") = matrix(nEmbd, nEmbd)
    stateDict(s"layer$i.mlp_fc1") = matrix(4 * nEmbd, nEmbd)
    stateDict(s"layer$i.mlp_fc2") = matrix(nEmbd, 4 * nEmbd)
  val params = stateDict.values.flatten.flatten.toSeq // flatten params into a single list[Value]
  println(s"num params: ${params.size}")
  // Define the model architecture: a stateless function mapping token sequence and parameters
  // to logits over what comes next. Follows GPT-2 with minor differences:
  // layernorm -> rmsNorm, no biases, GeLU -> ReLU.
  // Matrix-vector product: one dot product per output row
  def linear(x: Seq[Value], w: Array[Array[Value]]): Seq[Value] =
    w.map(wo => wo.zip(x).map(_ * _).reduce(_ + _)).toSeq
  // Softmax over Values; the (constant) max is subtracted for numerical stability
  def softMax(logits: Seq[Value]): Seq[Value] =
    val maxVal = logits.map(_.data).max
    val exps = logits.map(v => (v - maxVal).exp)
    val total = exps.reduce(_ + _)
    exps.map(_ / total)
  // RMSNorm without a learned gain: x / sqrt(mean(x^2) + eps)
  def rmsNorm(x: Seq[Value]): Seq[Value] =
    val ms = x.map(xi => xi * xi).reduce(_ + _) / x.size.toDouble
    val scale = (ms + 1e-5) ^ -0.5
    x.map(_ * scale)
  // Forward one token at position posId. `keys`/`values` are per-layer KV caches that grow
  // one entry per call — attending over everything cached gives causal attention for free.
  def gpt(tokenId: Int, posId: Int, keys: Array[mutable.ListBuffer[Seq[Value]]], values: Array[mutable.ListBuffer[Seq[Value]]]): Seq[Value] =
    val tokEmbRow = stateDict("wte")(tokenId).toSeq // token embedding
    val posEmbRow = stateDict("wpe")(posId).toSeq // position embedding
    var x = tokEmbRow.zip(posEmbRow).map(_ + _) // joint token and position embedding
    x = rmsNorm(x)
    for li <- 0 until nLayer do
      // 1) Multi-head attention block
      val xResidual = x
      x = rmsNorm(x)
      val q = linear(x, stateDict(s"layer$li.attn_wq"))
      val k = linear(x, stateDict(s"layer$li.attn_wk"))
      val v = linear(x, stateDict(s"layer$li.attn_wv"))
      keys(li).append(k)
      values(li).append(v)
      val xAttn = mutable.ListBuffer[Value]()
      for h <- 0 until nHead do
        val hs = h * headDim // this head's slice of the embedding
        val qH = q.slice(hs, hs + headDim)
        val kH = keys(li).map(_.slice(hs, hs + headDim))
        val vH = values(li).map(_.slice(hs, hs + headDim))
        // scaled dot-product attention over all cached (past + current) positions
        val attnLogits = kH.map(kt => qH.zip(kt).map(_ * _).reduce(_ + _) / math.sqrt(headDim)).toSeq
        val attnWeights = softMax(attnLogits)
        val headOut = (0 until headDim).map: j =>
          attnWeights.zip(vH).map:
            (att, vh) => att * vh(j)
          .reduce(_ + _)
        xAttn.appendAll(headOut)
      x = linear(xAttn.toSeq, stateDict(s"layer$li.attn_wo"))
      x = x.zip(xResidual).map(_ + _) // residual connection
      // 2) MLP block
      val xResidual2 = x
      x = rmsNorm(x)
      x = linear(x, stateDict(s"layer$li.mlp_fc1"))
      x = x.map(_.relu)
      x = linear(x, stateDict(s"layer$li.mlp_fc2"))
      x = x.zip(xResidual2).map(_ + _) // residual connection
    linear(x, stateDict("lm_head"))
  // Let there be Adam, the blessed optimizer and its buffers
  val learningRate = 0.01
  val beta1 = 0.85
  val beta2 = 0.99
  val epsAdam = 1e-8
  val m = Array.fill(params.size)(0.0) // first moment buffer
  val v = Array.fill(params.size)(0.0) // second moment buffer
  // Training loop: one document per step
  val numSteps = 1000 // number of training steps
  for step <- 0 until numSteps do
    // Take single document, tokenize it, surround it with BOS special token on both sides
    val doc = shuffledDocs(step % shuffledDocs.size)
    val tokens = BOS +: doc.map(ch => uChars.indexOf(ch)) :+ BOS
    val n = math.min(blockSize, tokens.length - 1) // number of prediction targets this step
    // Forward the token sequence through the model, building up the computation graph to the loss.
    val keys = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    val values = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    val losses = mutable.ListBuffer[Value]()
    for posId <- 0 until n do
      val tokenId = tokens(posId)
      val targetId = tokens(posId + 1)
      val logits = gpt(tokenId, posId, keys, values)
      val probs = softMax(logits)
      val lossT = -probs(targetId).log // cross-entropy at this position
      losses.append(lossT)
    val loss = losses.reduce(_ + _) * (1.0 / n) // final average loss over the document sequence. May yours be low.
    // Backward the loss, calculating the gradients with respect to all model parameters.
    loss.backward()
    // Adam optimizer update with bias correction and linear learning-rate decay
    val lrT = learningRate * (1 - step.toDouble / numSteps)
    for (p, i) <- params.zipWithIndex do
      m(i) = beta1 * m(i) + (1 - beta1) * p.grad
      v(i) = beta2 * v(i) + (1 - beta2) * (p.grad * p.grad)
      val mHat = m(i) / (1 - math.pow(beta1, step + 1))
      val vHat = v(i) / (1 - math.pow(beta2, step + 1))
      p.data -= lrT * mHat / (math.sqrt(vHat) + epsAdam)
      p.grad = 0.0 // zero the gradient for the next step
    if (step + 1) % 10 == 0 then
      println(f"step ${step+1}%4d / $numSteps | loss ${loss.data}%.4f")
  // Inference: may the model babble back to us
  println("\n--- inference (new, hallucinated names) ---")
  val temperature = 0.5 // in (0, 1], control the "creativity" of generated text, low to high
  for sampleIdx <- 0 until 20 do
    // fresh KV caches per sample
    val keys = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    val values = Array.fill(nLayer)(mutable.ListBuffer[Seq[Value]]())
    var tokenId = BOS
    val sample = mutable.StringBuilder()
    var continue = true
    var posId = 0
    while continue && posId < blockSize do
      val logits = gpt(tokenId, posId, keys, values)
      val probsSeq = softMax(logits.map(l => l / temperature)).map(_.data)
      tokenId = random.choices(0 until vocabSize, probsSeq) // sample next token id
      if tokenId == BOS then
        continue = false // a second BOS marks the end of the name
      else
        sample.append(uChars(tokenId))
      posId += 1
    println(f"sample ${sampleIdx+1}%2d: ${sample.toString}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment