Skip to content

Instantly share code, notes, and snippets.

@bquast
Created February 15, 2026 08:35
Show Gist options
  • Select an option

  • Save bquast/ea3c6c0e17670915b793ea98065de60a to your computer and use it in GitHub Desktop.

Select an option

Save bquast/ea3c6c0e17670915b793ea98065de60a to your computer and use it in GitHub Desktop.
# --- Dataset Setup ---
# Fetch karpathy's makemore names dataset on first run (one name per line),
# then read it, drop blank lines, and shuffle deterministically.
input_file <- "input.txt"
if (!file.exists(input_file)) {
  names_url <- "https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt"
  download.file(names_url, input_file)
}
lines <- readLines(input_file, warn = FALSE)
docs <- lines[nchar(trimws(lines)) > 0]  # keep only non-blank lines
set.seed(42)                             # reproducible shuffle
docs <- sample(docs)
cat(sprintf("num docs: %d\n", length(docs)))
# --- Tokenizer ---
# Character-level vocabulary: every distinct character appearing anywhere in
# the corpus, sorted. The id one past the character set doubles as a BOS/EOS
# marker, so vocab_size = number of characters + 1.
all_chars <- sort(unique(unlist(strsplit(docs, ""))))
uchars <- all_chars
BOS_TOKEN <- length(uchars) + 1
vocab_size <- length(uchars) + 1
cat(sprintf("vocab size: %d\n", vocab_size))
# --- Autograd (Functional Value System) ---
# A Value node records a scalar, its parents in the compute graph, and the
# local derivative with respect to each parent. The gradient is kept inside
# an environment so the backward pass can accumulate it by reference even
# though the node itself is an ordinary (copy-on-modify) list.
Value <- function(data, children = list(), local_grads = list()) {
  grad_env <- new.env(parent = emptyenv())
  grad_env$val <- 0
  list(
    data = data,
    grad = grad_env,
    children = children,
    local_grads = local_grads
  )
}
# Recursive backward pass (Non-OOP)
# Seeds d(root)/d(root) = 1 and propagates gradients to every reachable node
# via the chain rule, visiting nodes in reverse topological order. Gradients
# land in each node's grad environment (shared by reference), so repeated
# children accumulate correctly.
backward <- function(root) {
  order <- list()
  seen <- list()
  visit <- function(node) {
    # The grad environment's printed address serves as a unique node id.
    id <- format(node$grad)
    if (is.null(seen[[id]])) {
      seen[[id]] <<- TRUE
      for (ch in node$children) visit(ch)
      order[[length(order) + 1]] <<- node
    }
  }
  visit(root)
  root$grad$val <- 1
  for (node in rev(order)) {
    # seq_along() is empty for leaf nodes, so no explicit guard is needed.
    for (i in seq_along(node$children)) {
      ch <- node$children[[i]]
      ch$grad$val <- ch$grad$val + node$local_grads[[i]] * node$grad$val
    }
  }
}
# --- Primitive Ops (Functional) ---
# Each op returns a new Value whose children/local_grads encode the chain
# rule. Plain numbers passed to the binary ops are promoted to constant
# Value nodes (no children, so no gradient flows into them).
as_value <- function(x) if (is.list(x)) x else Value(x)
add_val <- function(a, b) {
  a <- as_value(a)
  b <- as_value(b)
  # d(a+b)/da = d(a+b)/db = 1
  Value(a$data + b$data, list(a, b), list(1, 1))
}
mul_val <- function(a, b) {
  a <- as_value(a)
  b <- as_value(b)
  # d(ab)/da = b, d(ab)/db = a (snapshots of the forward-pass values)
  Value(a$data * b$data, list(a, b), list(b$data, a$data))
}
# relu: local gradient is the 0/1 indicator of the positive region.
relu_val <- function(v) Value(max(0, v$data), list(v), list(as.numeric(v$data > 0)))
exp_val <- function(v) {
  e <- exp(v$data)  # hoisted: exp is its own derivative, compute it once
  Value(e, list(v), list(e))
}
log_val <- function(v) Value(log(v$data), list(v), list(1 / v$data))
pow_val <- function(v, p) Value(v$data^p, list(v), list(p * v$data^(p - 1)))
# --- Hyperparameters ---
# Deliberately tiny configuration so the scalar (one-Value-per-number)
# autograd graph stays tractable.
n_embd <- 16      # width of the embedding / residual stream
n_head <- 4       # number of attention heads per layer
n_layer <- 1      # number of transformer blocks
block_size <- 16  # maximum context length (positions)
head_dim <- n_embd / n_head  # per-head width; assumes n_head divides n_embd evenly
# --- Parameters ---
# Each weight matrix is a list of r rows, each row a list of c Value nodes,
# initialized uniformly in [-0.08, 0.08].
state_dict <- list()
init_mat <- function(r, c) replicate(r, lapply(runif(c, -0.08, 0.08), Value), simplify = FALSE)
state_dict$wte <- init_mat(vocab_size, n_embd)      # token embeddings
state_dict$wpe <- init_mat(block_size, n_embd)      # positional embeddings
state_dict$lm_head <- init_mat(vocab_size, n_embd)  # output projection
for (i in 0:(n_layer - 1)) {
  state_dict[[paste0("l", i, ".aq")]] <- init_mat(n_embd, n_embd)      # attention query
  state_dict[[paste0("l", i, ".ak")]] <- init_mat(n_embd, n_embd)      # attention key
  state_dict[[paste0("l", i, ".av")]] <- init_mat(n_embd, n_embd)      # attention value
  state_dict[[paste0("l", i, ".ao")]] <- init_mat(n_embd, n_embd)      # attention output
  state_dict[[paste0("l", i, ".f1")]] <- init_mat(4 * n_embd, n_embd)  # MLP up-projection
  state_dict[[paste0("l", i, ".f2")]] <- init_mat(n_embd, 4 * n_embd)  # MLP down-projection
}
# Flatten state_dict into one flat list of Value nodes for the optimizer.
# BUG FIX: unlist(state_dict, recursive = TRUE) recurses *into* each Value
# (which is itself a list), splitting it into separate $data / $grad
# elements — the optimizer would then index bare numerics/environments and
# `p$grad$val` would error ("$ operator is invalid for atomic vectors").
# Flatten exactly two levels instead: matrices -> rows -> Value nodes.
params <- do.call(c, lapply(state_dict, function(mat) do.call(c, mat)))
# --- Training Logic ---
# One document per step: full forward pass position-by-position (with causal
# KV caching), cross-entropy loss on next-token prediction, full backward
# pass, then an Adam-style per-parameter update.
learning_rate <- 0.01
# Adam moment buffers, one slot per entry of `params`.
# NOTE(review): this assumes params[[i]] is a Value node (a list with $data
# and $grad) — verify how `params` was flattened upstream.
m_buf <- rep(0, length(params))
v_buf <- rep(0, length(params))
for (step in 1:600) {
# Tokenization: map each character to its index in uchars and wrap the
# sequence with BOS markers on both ends.
doc <- docs[((step - 1) %% length(docs)) + 1]
tokens <- c(BOS_TOKEN, sapply(unlist(strsplit(doc, "")), function(c) which(uchars == c)), BOS_TOKEN)
# Number of prediction positions: inputs tokens[1..n], targets tokens[2..n+1].
n <- min(block_size, length(tokens) - 1)
# Per-layer KV caches: positions are processed one at a time and each new
# position attends over every key/value cached so far (causal attention).
keys_cache <- replicate(n_layer, list(), simplify = FALSE)
values_cache <- replicate(n_layer, list(), simplify = FALSE)
losses <- list()
for (pos_id in 0:(n-1)) {
# 1. Embeddings: token embedding + positional embedding, elementwise.
tid <- tokens[pos_id + 1]
pid <- pos_id + 1
x <- mapply(add_val, state_dict$wte[[tid]], state_dict$wpe[[pid]], SIMPLIFY = FALSE)
# RMSNorm (inline; repeated below for simplicity): x / sqrt(mean(x^2) + eps).
# NOTE(review): `scale` is wrapped as a fresh constant Value, so no gradient
# flows through the normalization denominator (a stop-gradient) — confirm
# this approximation is intended.
ss <- Value(0); for(xi in x) ss <- add_val(ss, mul_val(xi, xi))
scale <- Value((ss$data / length(x) + 1e-5)^-0.5)
x <- lapply(x, function(xi) mul_val(xi, scale))
for (li in 0:(n_layer-1)) {
x_res <- x
# RMSNorm again (pre-attention norm; same stop-gradient caveat as above).
ss <- Value(0); for(xi in x) ss <- add_val(ss, mul_val(xi, xi))
scale <- Value((ss$data / length(x) + 1e-5)^-0.5)
x_norm <- lapply(x, function(xi) mul_val(xi, scale))
# Attention projections: q/k/v are matrix-vector products, each output
# element an explicit dot product of one weight row with x_norm.
q <- lapply(state_dict[[paste0("l", li, ".aq")]], function(row) {
s <- Value(0); for(j in seq_along(x_norm)) s <- add_val(s, mul_val(row[[j]], x_norm[[j]])); s
})
k <- lapply(state_dict[[paste0("l", li, ".ak")]], function(row) {
s <- Value(0); for(j in seq_along(x_norm)) s <- add_val(s, mul_val(row[[j]], x_norm[[j]])); s
})
v <- lapply(state_dict[[paste0("l", li, ".av")]], function(row) {
s <- Value(0); for(j in seq_along(x_norm)) s <- add_val(s, mul_val(row[[j]], x_norm[[j]])); s
})
# Append this position's keys/values so later positions can attend to them.
keys_cache[[li+1]] <- c(keys_cache[[li+1]], list(k))
values_cache[[li+1]] <- c(values_cache[[li+1]], list(v))
# Multi-head Attention: each head operates on its own head_dim-wide slice.
x_attn <- list()
for (h in 0:(n_head-1)) {
h_idx <- (h * head_dim + 1):((h + 1) * head_dim)
qh <- q[h_idx]
# Self-attention logits: scaled dot product of this head's query with
# every cached key (one logit per position seen so far, current included).
attn_logits <- lapply(seq_along(keys_cache[[li+1]]), function(t) {
kh <- keys_cache[[li+1]][[t]][h_idx]
dot <- Value(0); for(j in seq_along(qh)) dot <- add_val(dot, mul_val(qh[[j]], kh[[j]]))
mul_val(dot, 1/(head_dim^0.5))
})
# Softmax (inline), with max-subtraction for numerical stability.
max_v <- max(sapply(attn_logits, function(l) l$data))
exps <- lapply(attn_logits, function(l) exp_val(add_val(l, -max_v)))
sum_e <- Value(0); for(e in exps) sum_e <- add_val(sum_e, e)
probs <- lapply(exps, function(e) mul_val(e, pow_val(sum_e, -1)))
# Weighted sum of cached values -> this head's head_dim outputs; appending
# per head makes x_attn the concatenation over heads (length n_embd).
for (j in 1:head_dim) {
out_j <- Value(0)
for (t in seq_along(probs)) out_j <- add_val(out_j, mul_val(probs[[t]], values_cache[[li+1]][[t]][h_idx][[j]]))
x_attn <- c(x_attn, list(out_j))
}
}
# Attention Out + Residual: project the concatenated heads, add the skip.
x_out <- lapply(state_dict[[paste0("l", li, ".ao")]], function(row) {
s <- Value(0); for(j in seq_along(x_attn)) s <- add_val(s, mul_val(row[[j]], x_attn[[j]])); s
})
x <- mapply(add_val, x_out, x_res, SIMPLIFY = FALSE)
# MLP block: norm -> 4x up-projection -> ReLU -> down-projection -> skip.
x_res <- x
ss <- Value(0); for(xi in x) ss <- add_val(ss, mul_val(xi, xi))
scale <- Value((ss$data / length(x) + 1e-5)^-0.5)
x_norm <- lapply(x, function(xi) mul_val(xi, scale))
f1 <- lapply(state_dict[[paste0("l", li, ".f1")]], function(row) {
s <- Value(0); for(j in seq_along(x_norm)) s <- add_val(s, mul_val(row[[j]], x_norm[[j]])); s
})
f1 <- lapply(f1, relu_val)
f2 <- lapply(state_dict[[paste0("l", li, ".f2")]], function(row) {
s <- Value(0); for(j in seq_along(f1)) s <- add_val(s, mul_val(row[[j]], f1[[j]])); s
})
x <- mapply(add_val, f2, x_res, SIMPLIFY = FALSE)
}
# Classifier: one logit per vocabulary entry.
logits <- lapply(state_dict$lm_head, function(row) {
s <- Value(0); for(j in seq_along(x)) s <- add_val(s, mul_val(row[[j]], x[[j]])); s
})
# Loss: negative log-likelihood of the next token, via a max-stabilized softmax.
target_id <- tokens[pos_id + 2]
max_l <- max(sapply(logits, function(l) l$data))
exps_l <- lapply(logits, function(l) exp_val(add_val(l, -max_l)))
sum_el <- Value(0); for(e in exps_l) sum_el <- add_val(sum_el, e)
prob_target <- mul_val(exps_l[[target_id]], pow_val(sum_el, -1))
losses <- c(losses, list(mul_val(log_val(prob_target), -1)))
}
# Average Loss over all positions of this document.
total_loss <- Value(0); for(l in losses) total_loss <- add_val(total_loss, l)
final_loss <- mul_val(total_loss, 1/n)
# Optimization: backprop, then an Adam-style update (beta1 = 0.85,
# beta2 = 0.99) with bias correction and linear LR decay over 1000 steps.
backward(final_loss)
lr_t <- learning_rate * (1 - (step-1)/1000)
for (i in seq_along(params)) {
p <- params[[i]]
grad <- p$grad$val
m_buf[i] <- 0.85 * m_buf[i] + 0.15 * grad
v_buf[i] <- 0.99 * v_buf[i] + 0.01 * (grad^2)
m_hat <- m_buf[i] / (1 - 0.85^step)
v_hat <- v_buf[i] / (1 - 0.99^step)
# NOTE(review): `p` is a copy of the list params[[i]] (R copy-on-modify),
# so this assignment mutates only the local copy's $data — the weights the
# forward pass reads from state_dict may never actually change. Only $grad
# (an environment, shared by reference) is truly reset below. Verify that
# the reported loss decreases across steps.
p$data <- p$data - (lr_t * m_hat / (sqrt(v_hat) + 1e-8))
p$grad$val <- 0
}
if (step %% 10 == 0) cat(sprintf("step %4d | loss %.4f\n", step, final_loss$data))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment