Skip to content

Instantly share code, notes, and snippets.

@lexpank
Created March 17, 2026 22:48
Show Gist options
  • Select an option

  • Save lexpank/b4074152238b21a49948acf0417c9e6f to your computer and use it in GitHub Desktop.

Select an option

Save lexpank/b4074152238b21a49948acf0417c9e6f to your computer and use it in GitHub Desktop.
CUDA C microgpt
// Nobody asked for microgpt in CUDA C, which is exactly why it had to happen.
// It is no longer particularly micro.
// It is, however, stupidly fast.
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define N_LAYER 1
#define N_EMBD 16
#define BLOCK_SIZE 16
#define N_HEAD 4
#define HEAD_DIM (N_EMBD / N_HEAD)
#define MLP_HIDDEN (4 * N_EMBD)
#define INIT_STD 0.08f
#define RMS_EPS 1e-5f
#define LEARNING_RATE 0.01f
#define BETA1 0.85f
#define BETA2 0.99f
#define ADAM_EPS 1e-8f
#define CHECK_CUDA(call) \
do { \
cudaError_t err__ = (call); \
if (err__ != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \
cudaGetErrorString(err__)); \
exit(1); \
} \
} while (0)
#define CHECK_CUBLAS(call) \
do { \
cublasStatus_t status__ = (call); \
if (status__ != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error at %s:%d: %d\n", __FILE__, __LINE__, \
(int)status__); \
exit(1); \
} \
} while (0)
typedef struct {
char **items;
int count;
int capacity;
} StringArray;
typedef struct {
StringArray docs;
int vocab_size;
int bos_id;
int char_to_id[256];
unsigned char id_to_char[256];
} Dataset;
typedef struct {
const char *name;
int rows;
int cols;
float *data;
float *grad;
float *m;
float *v;
} Param;
typedef struct {
float *x_in;
float *attn_norm;
float *q;
float *k;
float *v;
float *attn_logits;
float *attn_weights;
float *attn_out;
float *attn_proj;
float *mlp_in;
float *mlp_norm;
float *fc1;
float *relu;
float *fc2;
float *x_out;
} LayerCache;
typedef struct {
int token_id;
int target_id;
float *embed_in;
float *x0;
float *logits;
float *probs;
LayerCache layers[N_LAYER];
} PositionCache;
typedef struct {
PositionCache pos[BLOCK_SIZE];
float *keys[N_LAYER];
float *values[N_LAYER];
float *grad_keys[N_LAYER];
float *grad_values[N_LAYER];
float *grad_k_accum[N_LAYER][BLOCK_SIZE];
float *grad_v_accum[N_LAYER][BLOCK_SIZE];
} TrainCache;
typedef struct {
float *grad_logits;
float *grad_x;
float *grad_x_in;
float *grad_mlp_in;
float *grad_relu;
float *grad_fc1;
float *grad_mlp_norm;
float *grad_norm_tmp;
float *grad_attn_out;
float *grad_q;
float *grad_attn_norm;
float *grad_embed_in;
} Scratch;
typedef struct {
Param wte;
Param wpe;
Param lm_head;
Param attn_wq[N_LAYER];
Param attn_wk[N_LAYER];
Param attn_wv[N_LAYER];
Param attn_wo[N_LAYER];
Param mlp_fc1[N_LAYER];
Param mlp_fc2[N_LAYER];
} Model;
typedef struct {
int num_steps;
int num_samples;
float temperature;
const char *input_path;
} Options;
typedef struct {
float *base;
size_t count;
size_t offset;
} FloatArena;
static cublasHandle_t g_cublas = NULL;
static __global__ void vec_add_kernel(const float *a, const float *b, float *out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
out[idx] = a[idx] + b[idx];
}
}
static __global__ void vec_add_inplace_kernel(float *dst, const float *src, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] += src[idx];
}
}
static __global__ void relu_forward_kernel(const float *x, float *y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
y[idx] = x[idx] > 0.0f ? x[idx] : 0.0f;
}
}
static __global__ void relu_backward_kernel(const float *x, const float *grad_y, float *grad_x, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
grad_x[idx] = x[idx] > 0.0f ? grad_y[idx] : 0.0f;
}
}
static __global__ void rmsnorm_forward_kernel(const float *x, float *y, int n) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
float mean_square = 0.0f;
for (int i = 0; i < n; ++i) {
mean_square += x[i] * x[i];
}
mean_square /= (float)n;
float scale = rsqrtf(mean_square + RMS_EPS);
for (int i = 0; i < n; ++i) {
y[i] = x[i] * scale;
}
}
}
static __global__ void rmsnorm_backward_kernel(const float *x, const float *grad_y, float *grad_x, int n) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
float mean_square = 0.0f;
float dot = 0.0f;
for (int i = 0; i < n; ++i) {
mean_square += x[i] * x[i];
dot += grad_y[i] * x[i];
}
mean_square /= (float)n;
float scale = rsqrtf(mean_square + RMS_EPS);
float scale_cubed = scale * scale * scale;
for (int i = 0; i < n; ++i) {
grad_x[i] = grad_y[i] * scale - x[i] * dot * scale_cubed / (float)n;
}
}
}
static __global__ void scaled_softmax_kernel(const float *logits, float *probs, float scale, int n) {
__shared__ float scratch[256];
int tid = threadIdx.x;
float local_max = -1e30f;
for (int i = tid; i < n; i += blockDim.x) {
float v = logits[i] * scale;
if (v > local_max) {
local_max = v;
}
}
scratch[tid] = local_max;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride && scratch[tid + stride] > scratch[tid]) {
scratch[tid] = scratch[tid + stride];
}
__syncthreads();
}
float max_val = scratch[0];
float local_sum = 0.0f;
for (int i = tid; i < n; i += blockDim.x) {
float e = expf(logits[i] * scale - max_val);
probs[i] = e;
local_sum += e;
}
scratch[tid] = local_sum;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
scratch[tid] += scratch[tid + stride];
}
__syncthreads();
}
float inv_sum = 1.0f / scratch[0];
for (int i = tid; i < n; i += blockDim.x) {
probs[i] *= inv_sum;
}
}
static __global__ void cross_entropy_backward_kernel(const float *probs, float *grad_logits, int target_id, float loss_scale, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float grad = probs[idx] * loss_scale;
if (idx == target_id) {
grad -= loss_scale;
}
grad_logits[idx] = grad;
}
}
static __global__ void adam_update_kernel(
float *data,
float *grad,
float *m,
float *v,
float lr_t,
float beta1_corr,
float beta2_corr,
int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float g = grad[idx];
float m_new = BETA1 * m[idx] + (1.0f - BETA1) * g;
float v_new = BETA2 * v[idx] + (1.0f - BETA2) * g * g;
m[idx] = m_new;
v[idx] = v_new;
float m_hat = m_new / beta1_corr;
float v_hat = v_new / beta2_corr;
data[idx] -= lr_t * m_hat / (sqrtf(v_hat) + ADAM_EPS);
grad[idx] = 0.0f;
}
}
static __global__ void row_add_kernel(float *row_grad, const float *grad, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
row_grad[idx] += grad[idx];
}
}
static __global__ void attention_forward_kernel(
const float *q,
const float *keys,
const float *values,
float *attn_logits,
float *attn_weights,
float *attn_out,
int pos_id
) {
int h = threadIdx.x;
if (h >= N_HEAD) {
return;
}
int hs = h * HEAD_DIM;
float logits_local[BLOCK_SIZE];
float weights_local[BLOCK_SIZE];
float max_val = -1e30f;
float sum = 0.0f;
float scale = 1.0f / sqrtf((float)HEAD_DIM);
for (int t = 0; t <= pos_id; ++t) {
float dot = 0.0f;
const float *k_t = keys + (size_t)t * N_EMBD + hs;
for (int j = 0; j < HEAD_DIM; ++j) {
dot += q[hs + j] * k_t[j];
}
logits_local[t] = dot * scale;
if (logits_local[t] > max_val) {
max_val = logits_local[t];
}
}
for (int t = 0; t <= pos_id; ++t) {
float w = expf(logits_local[t] - max_val);
weights_local[t] = w;
sum += w;
}
float inv_sum = 1.0f / sum;
for (int t = 0; t <= pos_id; ++t) {
float w = weights_local[t] * inv_sum;
attn_logits[h * BLOCK_SIZE + t] = logits_local[t];
attn_weights[h * BLOCK_SIZE + t] = w;
weights_local[t] = w;
}
for (int j = 0; j < HEAD_DIM; ++j) {
float out = 0.0f;
for (int t = 0; t <= pos_id; ++t) {
const float *v_t = values + (size_t)t * N_EMBD + hs;
out += weights_local[t] * v_t[j];
}
attn_out[hs + j] = out;
}
}
static __global__ void attention_backward_kernel(
const float *q,
const float *keys,
const float *values,
const float *attn_weights,
const float *grad_attn_out,
float *grad_q,
float *grad_keys,
float *grad_values,
int pos_id
) {
int h = threadIdx.x;
if (h >= N_HEAD) {
return;
}
int hs = h * HEAD_DIM;
float grad_w[BLOCK_SIZE];
float grad_q_local[HEAD_DIM];
float weighted_sum = 0.0f;
float scale = 1.0f / sqrtf((float)HEAD_DIM);
for (int j = 0; j < HEAD_DIM; ++j) {
grad_q_local[j] = 0.0f;
}
for (int t = 0; t <= pos_id; ++t) {
const float *v_t = values + (size_t)t * N_EMBD + hs;
float dot = 0.0f;
float weight = attn_weights[h * BLOCK_SIZE + t];
float *grad_v_t = grad_values + (size_t)t * N_EMBD + hs;
for (int j = 0; j < HEAD_DIM; ++j) {
float go = grad_attn_out[hs + j];
dot += go * v_t[j];
grad_v_t[j] += go * weight;
}
grad_w[t] = dot;
weighted_sum += dot * weight;
}
for (int t = 0; t <= pos_id; ++t) {
const float *k_t = keys + (size_t)t * N_EMBD + hs;
float *grad_k_t = grad_keys + (size_t)t * N_EMBD + hs;
float weight = attn_weights[h * BLOCK_SIZE + t];
float grad_logit = weight * (grad_w[t] - weighted_sum);
for (int j = 0; j < HEAD_DIM; ++j) {
grad_q_local[j] += grad_logit * k_t[j] * scale;
grad_k_t[j] += grad_logit * q[hs + j] * scale;
}
}
for (int j = 0; j < HEAD_DIM; ++j) {
grad_q[hs + j] = grad_q_local[j];
}
}
static void zero_device(float *x, int n);
static void copy_device(float *dst, const float *src, int n);
static void copy_to_host(float *dst, const float *src, int n);
static void copy_to_device(float *dst, const float *src, int n);
static void launch_vec_add(const float *a, const float *b, float *out, int n) {
int threads = n < 256 ? n : 256;
vec_add_kernel<<<1, threads>>>(a, b, out, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_vec_add_inplace(float *dst, const float *src, int n) {
int threads = n < 256 ? n : 256;
vec_add_inplace_kernel<<<1, threads>>>(dst, src, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_relu_forward(const float *x, float *y, int n) {
int threads = n < 256 ? n : 256;
relu_forward_kernel<<<1, threads>>>(x, y, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_relu_backward(const float *x, const float *grad_y, float *grad_x, int n) {
int threads = n < 256 ? n : 256;
relu_backward_kernel<<<1, threads>>>(x, grad_y, grad_x, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_matvec(const float *w, const float *x, float *y, int rows, int cols) {
const float alpha = 1.0f;
const float beta = 0.0f;
CHECK_CUBLAS(cublasSgemv(
g_cublas,
CUBLAS_OP_T,
cols,
rows,
&alpha,
w,
cols,
x,
1,
&beta,
y,
1
));
}
static void launch_linear_backward(
const float *w,
const float *x,
const float *grad_out,
float *grad_w,
float *grad_x,
int rows,
int cols
) {
const float alpha = 1.0f;
const float beta = 1.0f;
CHECK_CUBLAS(cublasSger(
g_cublas,
cols,
rows,
&alpha,
x,
1,
grad_out,
1,
grad_w,
cols
));
CHECK_CUBLAS(cublasSgemv(
g_cublas,
CUBLAS_OP_N,
cols,
rows,
&alpha,
w,
cols,
grad_out,
1,
&beta,
grad_x,
1
));
}
static void launch_rmsnorm_forward(const float *x, float *y, int n) {
rmsnorm_forward_kernel<<<1, 1>>>(x, y, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_rmsnorm_backward(const float *x, const float *grad_y, float *grad_x, int n) {
rmsnorm_backward_kernel<<<1, 1>>>(x, grad_y, grad_x, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_scaled_softmax(const float *logits, float *probs, float scale, int n) {
scaled_softmax_kernel<<<1, 256>>>(logits, probs, scale, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_cross_entropy_backward(const float *probs, float *grad_logits, int target_id, float loss_scale, int n) {
int threads = n < 256 ? n : 256;
cross_entropy_backward_kernel<<<1, threads>>>(probs, grad_logits, target_id, loss_scale, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_attention_forward(
const float *q,
const float *keys,
const float *values,
float *attn_logits,
float *attn_weights,
float *attn_out,
int pos_id
) {
attention_forward_kernel<<<1, N_HEAD>>>(q, keys, values, attn_logits, attn_weights, attn_out, pos_id);
CHECK_CUDA(cudaGetLastError());
}
static void launch_attention_backward(
const float *q,
const float *keys,
const float *values,
const float *attn_weights,
const float *grad_attn_out,
float *grad_q,
float *grad_keys,
float *grad_values,
int pos_id
) {
attention_backward_kernel<<<1, N_HEAD>>>(
q,
keys,
values,
attn_weights,
grad_attn_out,
grad_q,
grad_keys,
grad_values,
pos_id
);
CHECK_CUDA(cudaGetLastError());
}
static void launch_adam_update(float *data, float *grad, float *m, float *v, float lr_t, float beta1_corr, float beta2_corr, int n) {
int threads = n < 256 ? n : 256;
int blocks = (n + threads - 1) / threads;
adam_update_kernel<<<blocks, threads>>>(data, grad, m, v, lr_t, beta1_corr, beta2_corr, n);
CHECK_CUDA(cudaGetLastError());
}
static void launch_row_add(float *row_grad, const float *grad, int n) {
int threads = n < 256 ? n : 256;
row_add_kernel<<<1, threads>>>(row_grad, grad, n);
CHECK_CUDA(cudaGetLastError());
}
static void *xmalloc(size_t size) {
void *ptr = malloc(size);
if (ptr == NULL) {
fprintf(stderr, "malloc failed\n");
exit(1);
}
return ptr;
}
static char *xstrdup(const char *src) {
size_t len = strlen(src);
char *dst = (char *)xmalloc(len + 1);
memcpy(dst, src, len + 1);
return dst;
}
static void init_arena(FloatArena *arena, size_t count) {
arena->count = count;
arena->offset = 0;
CHECK_CUDA(cudaMalloc((void **)&arena->base, count * sizeof(float)));
CHECK_CUDA(cudaMemset(arena->base, 0, count * sizeof(float)));
}
static float *arena_alloc(FloatArena *arena, size_t count) {
if (arena->offset + count > arena->count) {
fprintf(stderr, "arena overflow: requested %zu floats, %zu remaining\n",
count, arena->count - arena->offset);
exit(1);
}
float *ptr = arena->base + arena->offset;
arena->offset += count;
CHECK_CUDA(cudaMemset(ptr, 0, count * sizeof(float)));
return ptr;
}
static void init_string_array(StringArray *arr) {
arr->items = NULL;
arr->count = 0;
arr->capacity = 0;
}
static void push_string(StringArray *arr, const char *value) {
if (arr->count == arr->capacity) {
int new_capacity = arr->capacity == 0 ? 64 : arr->capacity * 2;
char **new_items = (char **)realloc(arr->items, (size_t)new_capacity * sizeof(char *));
if (new_items == NULL) {
fprintf(stderr, "realloc failed\n");
exit(1);
}
arr->items = new_items;
arr->capacity = new_capacity;
}
arr->items[arr->count++] = xstrdup(value);
}
static void shuffle_docs(StringArray *arr) {
for (int i = arr->count - 1; i > 0; --i) {
int j = rand() % (i + 1);
char *tmp = arr->items[i];
arr->items[i] = arr->items[j];
arr->items[j] = tmp;
}
}
static void load_dataset(Dataset *dataset, const char *input_path) {
init_string_array(&dataset->docs);
for (int i = 0; i < 256; ++i) {
dataset->char_to_id[i] = -1;
dataset->id_to_char[i] = 0;
}
FILE *fp = fopen(input_path, "r");
if (fp == NULL) {
fprintf(stderr, "could not open %s\n", input_path);
exit(1);
}
char *line = NULL;
size_t cap = 0;
ssize_t read = 0;
while ((read = getline(&line, &cap, fp)) != -1) {
while (read > 0 && (line[read - 1] == '\n' || line[read - 1] == '\r')) {
line[--read] = '\0';
}
if (read > 0) {
push_string(&dataset->docs, line);
}
}
free(line);
fclose(fp);
shuffle_docs(&dataset->docs);
int vocab = 0;
int seen[256] = {0};
for (int d = 0; d < dataset->docs.count; ++d) {
const unsigned char *doc = (const unsigned char *)dataset->docs.items[d];
for (int i = 0; doc[i] != '\0'; ++i) {
seen[doc[i]] = 1;
}
}
for (int c = 0; c < 256; ++c) {
if (seen[c]) {
dataset->char_to_id[c] = vocab;
dataset->id_to_char[vocab] = (unsigned char)c;
++vocab;
}
}
dataset->bos_id = vocab;
dataset->vocab_size = vocab + 1;
}
static float rand_uniform(void) {
return ((float)rand() + 1.0f) / ((float)RAND_MAX + 2.0f);
}
static float rand_normal(float stddev) {
float u1 = rand_uniform();
float u2 = rand_uniform();
float mag = sqrtf(-2.0f * logf(u1));
float phase = 2.0f * (float)M_PI * u2;
return stddev * mag * cosf(phase);
}
static void init_param(Param *param, FloatArena *arena, const char *name, int rows, int cols) {
size_t count = (size_t)rows * (size_t)cols;
float *host_data = (float *)xmalloc(count * sizeof(float));
param->name = name;
param->rows = rows;
param->cols = cols;
param->data = arena_alloc(arena, count);
param->grad = arena_alloc(arena, count);
param->m = arena_alloc(arena, count);
param->v = arena_alloc(arena, count);
for (size_t i = 0; i < count; ++i) {
host_data[i] = rand_normal(INIT_STD);
}
copy_to_device(param->data, host_data, (int)count);
free(host_data);
}
static size_t param_count(const Param *param) {
return (size_t)param->rows * (size_t)param->cols;
}
static size_t param_storage_count(int vocab_size) {
size_t count = 0;
count += 4 * (size_t)vocab_size * N_EMBD;
count += 4 * (size_t)BLOCK_SIZE * N_EMBD;
count += 4 * (size_t)vocab_size * N_EMBD;
for (int layer = 0; layer < N_LAYER; ++layer) {
count += 4 * (size_t)N_EMBD * N_EMBD;
count += 4 * (size_t)N_EMBD * N_EMBD;
count += 4 * (size_t)N_EMBD * N_EMBD;
count += 4 * (size_t)N_EMBD * N_EMBD;
count += 4 * (size_t)MLP_HIDDEN * N_EMBD;
count += 4 * (size_t)N_EMBD * MLP_HIDDEN;
}
return count;
}
static void init_model(Model *model, FloatArena *arena, int vocab_size) {
init_param(&model->wte, arena, "wte", vocab_size, N_EMBD);
init_param(&model->wpe, arena, "wpe", BLOCK_SIZE, N_EMBD);
init_param(&model->lm_head, arena, "lm_head", vocab_size, N_EMBD);
for (int layer = 0; layer < N_LAYER; ++layer) {
init_param(&model->attn_wq[layer], arena, "attn_wq", N_EMBD, N_EMBD);
init_param(&model->attn_wk[layer], arena, "attn_wk", N_EMBD, N_EMBD);
init_param(&model->attn_wv[layer], arena, "attn_wv", N_EMBD, N_EMBD);
init_param(&model->attn_wo[layer], arena, "attn_wo", N_EMBD, N_EMBD);
init_param(&model->mlp_fc1[layer], arena, "mlp_fc1", MLP_HIDDEN, N_EMBD);
init_param(&model->mlp_fc2[layer], arena, "mlp_fc2", N_EMBD, MLP_HIDDEN);
}
}
static size_t total_params(const Model *model) {
size_t count = 0;
count += param_count(&model->wte);
count += param_count(&model->wpe);
count += param_count(&model->lm_head);
for (int layer = 0; layer < N_LAYER; ++layer) {
count += param_count(&model->attn_wq[layer]);
count += param_count(&model->attn_wk[layer]);
count += param_count(&model->attn_wv[layer]);
count += param_count(&model->attn_wo[layer]);
count += param_count(&model->mlp_fc1[layer]);
count += param_count(&model->mlp_fc2[layer]);
}
return count;
}
static size_t layer_cache_storage_count(void) {
size_t count = 0;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += (size_t)N_HEAD * BLOCK_SIZE;
count += (size_t)N_HEAD * BLOCK_SIZE;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += MLP_HIDDEN;
count += MLP_HIDDEN;
count += N_EMBD;
count += N_EMBD;
return count;
}
static void alloc_layer_cache(LayerCache *cache, FloatArena *arena) {
cache->x_in = arena_alloc(arena, N_EMBD);
cache->attn_norm = arena_alloc(arena, N_EMBD);
cache->q = arena_alloc(arena, N_EMBD);
cache->k = NULL;
cache->v = NULL;
cache->attn_logits = arena_alloc(arena, (size_t)N_HEAD * BLOCK_SIZE);
cache->attn_weights = arena_alloc(arena, (size_t)N_HEAD * BLOCK_SIZE);
cache->attn_out = arena_alloc(arena, N_EMBD);
cache->attn_proj = arena_alloc(arena, N_EMBD);
cache->mlp_in = arena_alloc(arena, N_EMBD);
cache->mlp_norm = arena_alloc(arena, N_EMBD);
cache->fc1 = arena_alloc(arena, MLP_HIDDEN);
cache->relu = arena_alloc(arena, MLP_HIDDEN);
cache->fc2 = arena_alloc(arena, N_EMBD);
cache->x_out = arena_alloc(arena, N_EMBD);
}
static size_t train_cache_storage_count(int vocab_size) {
size_t count = 0;
for (int layer = 0; layer < N_LAYER; ++layer) {
count += (size_t)BLOCK_SIZE * N_EMBD;
count += (size_t)BLOCK_SIZE * N_EMBD;
count += (size_t)BLOCK_SIZE * N_EMBD;
count += (size_t)BLOCK_SIZE * N_EMBD;
}
for (int pos = 0; pos < BLOCK_SIZE; ++pos) {
count += N_EMBD;
count += N_EMBD;
count += vocab_size;
count += vocab_size;
for (int layer = 0; layer < N_LAYER; ++layer) {
count += layer_cache_storage_count();
}
}
return count;
}
static void init_train_cache(TrainCache *cache, int vocab_size, FloatArena *arena) {
for (int layer = 0; layer < N_LAYER; ++layer) {
cache->keys[layer] = arena_alloc(arena, (size_t)BLOCK_SIZE * N_EMBD);
cache->values[layer] = arena_alloc(arena, (size_t)BLOCK_SIZE * N_EMBD);
cache->grad_keys[layer] = arena_alloc(arena, (size_t)BLOCK_SIZE * N_EMBD);
cache->grad_values[layer] = arena_alloc(arena, (size_t)BLOCK_SIZE * N_EMBD);
}
for (int pos = 0; pos < BLOCK_SIZE; ++pos) {
cache->pos[pos].embed_in = arena_alloc(arena, N_EMBD);
cache->pos[pos].x0 = arena_alloc(arena, N_EMBD);
cache->pos[pos].logits = arena_alloc(arena, vocab_size);
cache->pos[pos].probs = arena_alloc(arena, vocab_size);
for (int layer = 0; layer < N_LAYER; ++layer) {
alloc_layer_cache(&cache->pos[pos].layers[layer], arena);
cache->pos[pos].layers[layer].k = cache->keys[layer] + (size_t)pos * N_EMBD;
cache->pos[pos].layers[layer].v = cache->values[layer] + (size_t)pos * N_EMBD;
cache->grad_k_accum[layer][pos] = cache->grad_keys[layer] + (size_t)pos * N_EMBD;
cache->grad_v_accum[layer][pos] = cache->grad_values[layer] + (size_t)pos * N_EMBD;
}
}
}
static size_t scratch_storage_count(int vocab_size) {
size_t count = 0;
count += vocab_size;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += MLP_HIDDEN;
count += MLP_HIDDEN;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
count += N_EMBD;
return count;
}
static void init_scratch(Scratch *scratch, int vocab_size, FloatArena *arena) {
scratch->grad_logits = arena_alloc(arena, vocab_size);
scratch->grad_x = arena_alloc(arena, N_EMBD);
scratch->grad_x_in = arena_alloc(arena, N_EMBD);
scratch->grad_mlp_in = arena_alloc(arena, N_EMBD);
scratch->grad_relu = arena_alloc(arena, MLP_HIDDEN);
scratch->grad_fc1 = arena_alloc(arena, MLP_HIDDEN);
scratch->grad_mlp_norm = arena_alloc(arena, N_EMBD);
scratch->grad_norm_tmp = arena_alloc(arena, N_EMBD);
scratch->grad_attn_out = arena_alloc(arena, N_EMBD);
scratch->grad_q = arena_alloc(arena, N_EMBD);
scratch->grad_attn_norm = arena_alloc(arena, N_EMBD);
scratch->grad_embed_in = arena_alloc(arena, N_EMBD);
}
static void zero_device(float *x, int n) {
CHECK_CUDA(cudaMemset(x, 0, (size_t)n * sizeof(float)));
}
static void copy_device(float *dst, const float *src, int n) {
CHECK_CUDA(cudaMemcpy(dst, src, (size_t)n * sizeof(float), cudaMemcpyDeviceToDevice));
}
static void copy_to_host(float *dst, const float *src, int n) {
CHECK_CUDA(cudaMemcpy(dst, src, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost));
}
static void copy_to_device(float *dst, const float *src, int n) {
CHECK_CUDA(cudaMemcpy(dst, src, (size_t)n * sizeof(float), cudaMemcpyHostToDevice));
}
static int tokenize_doc(const Dataset *dataset, const char *doc, int *tokens) {
int count = 0;
tokens[count++] = dataset->bos_id;
for (int i = 0; doc[i] != '\0' && count < BLOCK_SIZE + 1; ++i) {
tokens[count++] = dataset->char_to_id[(unsigned char)doc[i]];
}
if (count < BLOCK_SIZE + 2) {
tokens[count++] = dataset->bos_id;
}
return count;
}
static const float *layer_input_ptr(const PositionCache *pos, int layer) {
if (layer == 0) {
return pos->x0;
}
return pos->layers[layer - 1].x_out;
}
static const float *final_hidden_ptr(const PositionCache *pos) {
if (N_LAYER == 0) {
return pos->x0;
}
return pos->layers[N_LAYER - 1].x_out;
}
static void forward_position(Model *model, TrainCache *cache, int vocab_size, int token_id, int pos_id) {
PositionCache *pos = &cache->pos[pos_id];
pos->token_id = token_id;
const float *tok_emb = model->wte.data + (size_t)token_id * N_EMBD;
const float *pos_emb = model->wpe.data + (size_t)pos_id * N_EMBD;
launch_vec_add(tok_emb, pos_emb, pos->embed_in, N_EMBD);
launch_rmsnorm_forward(pos->embed_in, pos->x0, N_EMBD);
for (int layer = 0; layer < N_LAYER; ++layer) {
LayerCache *lc = &pos->layers[layer];
const float *x_in = layer_input_ptr(pos, layer);
copy_device(lc->x_in, x_in, N_EMBD);
launch_rmsnorm_forward(lc->x_in, lc->attn_norm, N_EMBD);
launch_matvec(model->attn_wq[layer].data, lc->attn_norm, lc->q, N_EMBD, N_EMBD);
launch_matvec(model->attn_wk[layer].data, lc->attn_norm, lc->k, N_EMBD, N_EMBD);
launch_matvec(model->attn_wv[layer].data, lc->attn_norm, lc->v, N_EMBD, N_EMBD);
launch_attention_forward(
lc->q,
cache->keys[layer],
cache->values[layer],
lc->attn_logits,
lc->attn_weights,
lc->attn_out,
pos_id
);
launch_matvec(model->attn_wo[layer].data, lc->attn_out, lc->attn_proj, N_EMBD, N_EMBD);
launch_vec_add(lc->attn_proj, lc->x_in, lc->mlp_in, N_EMBD);
launch_rmsnorm_forward(lc->mlp_in, lc->mlp_norm, N_EMBD);
launch_matvec(model->mlp_fc1[layer].data, lc->mlp_norm, lc->fc1, MLP_HIDDEN, N_EMBD);
launch_relu_forward(lc->fc1, lc->relu, MLP_HIDDEN);
launch_matvec(model->mlp_fc2[layer].data, lc->relu, lc->fc2, N_EMBD, MLP_HIDDEN);
launch_vec_add(lc->fc2, lc->mlp_in, lc->x_out, N_EMBD);
}
launch_matvec(model->lm_head.data, final_hidden_ptr(pos), pos->logits, vocab_size, N_EMBD);
}
static float forward_sequence(Model *model, TrainCache *cache, const Dataset *dataset, const int *tokens, int n) {
float total_loss = 0.0f;
float *host_probs = (float *)xmalloc((size_t)dataset->vocab_size * sizeof(float));
for (int pos = 0; pos < n; ++pos) {
PositionCache *pc = &cache->pos[pos];
pc->target_id = tokens[pos + 1];
forward_position(model, cache, dataset->vocab_size, tokens[pos], pos);
launch_scaled_softmax(pc->logits, pc->probs, 1.0f, dataset->vocab_size);
copy_to_host(host_probs, pc->probs, dataset->vocab_size);
float prob = host_probs[pc->target_id];
if (prob < 1e-20f) {
prob = 1e-20f;
}
total_loss += -logf(prob);
}
free(host_probs);
return total_loss / (float)n;
}
static void zero_attention_accums(TrainCache *cache) {
for (int layer = 0; layer < N_LAYER; ++layer) {
CHECK_CUDA(cudaMemset(cache->grad_keys[layer], 0, (size_t)BLOCK_SIZE * N_EMBD * sizeof(float)));
CHECK_CUDA(cudaMemset(cache->grad_values[layer], 0, (size_t)BLOCK_SIZE * N_EMBD * sizeof(float)));
}
}
static void backward_sequence(Model *model, TrainCache *cache, Scratch *scratch, const Dataset *dataset, int n) {
zero_attention_accums(cache);
float loss_scale = 1.0f / (float)n;
for (int pos = n - 1; pos >= 0; --pos) {
PositionCache *pc = &cache->pos[pos];
launch_cross_entropy_backward(pc->probs, scratch->grad_logits, pc->target_id, loss_scale, dataset->vocab_size);
zero_device(scratch->grad_x, N_EMBD);
launch_linear_backward(
model->lm_head.data,
final_hidden_ptr(pc),
scratch->grad_logits,
model->lm_head.grad,
scratch->grad_x,
dataset->vocab_size,
N_EMBD
);
for (int layer = N_LAYER - 1; layer >= 0; --layer) {
LayerCache *lc = &pc->layers[layer];
copy_device(scratch->grad_mlp_in, scratch->grad_x, N_EMBD);
zero_device(scratch->grad_relu, MLP_HIDDEN);
launch_linear_backward(
model->mlp_fc2[layer].data,
lc->relu,
scratch->grad_x,
model->mlp_fc2[layer].grad,
scratch->grad_relu,
N_EMBD,
MLP_HIDDEN
);
zero_device(scratch->grad_fc1, MLP_HIDDEN);
launch_relu_backward(lc->fc1, scratch->grad_relu, scratch->grad_fc1, MLP_HIDDEN);
zero_device(scratch->grad_mlp_norm, N_EMBD);
launch_linear_backward(
model->mlp_fc1[layer].data,
lc->mlp_norm,
scratch->grad_fc1,
model->mlp_fc1[layer].grad,
scratch->grad_mlp_norm,
MLP_HIDDEN,
N_EMBD
);
zero_device(scratch->grad_norm_tmp, N_EMBD);
launch_rmsnorm_backward(lc->mlp_in, scratch->grad_mlp_norm, scratch->grad_norm_tmp, N_EMBD);
launch_vec_add_inplace(scratch->grad_mlp_in, scratch->grad_norm_tmp, N_EMBD);
copy_device(scratch->grad_x_in, scratch->grad_mlp_in, N_EMBD);
zero_device(scratch->grad_attn_out, N_EMBD);
launch_linear_backward(
model->attn_wo[layer].data,
lc->attn_out,
scratch->grad_mlp_in,
model->attn_wo[layer].grad,
scratch->grad_attn_out,
N_EMBD,
N_EMBD
);
zero_device(scratch->grad_attn_norm, N_EMBD);
launch_attention_backward(
lc->q,
cache->keys[layer],
cache->values[layer],
lc->attn_weights,
scratch->grad_attn_out,
scratch->grad_q,
cache->grad_keys[layer],
cache->grad_values[layer],
pos
);
launch_linear_backward(
model->attn_wq[layer].data,
lc->attn_norm,
scratch->grad_q,
model->attn_wq[layer].grad,
scratch->grad_attn_norm,
N_EMBD,
N_EMBD
);
launch_linear_backward(
model->attn_wk[layer].data,
lc->attn_norm,
cache->grad_k_accum[layer][pos],
model->attn_wk[layer].grad,
scratch->grad_attn_norm,
N_EMBD,
N_EMBD
);
launch_linear_backward(
model->attn_wv[layer].data,
lc->attn_norm,
cache->grad_v_accum[layer][pos],
model->attn_wv[layer].grad,
scratch->grad_attn_norm,
N_EMBD,
N_EMBD
);
zero_device(scratch->grad_norm_tmp, N_EMBD);
launch_rmsnorm_backward(lc->x_in, scratch->grad_attn_norm, scratch->grad_norm_tmp, N_EMBD);
launch_vec_add_inplace(scratch->grad_x_in, scratch->grad_norm_tmp, N_EMBD);
copy_device(scratch->grad_x, scratch->grad_x_in, N_EMBD);
}
zero_device(scratch->grad_embed_in, N_EMBD);
launch_rmsnorm_backward(pc->embed_in, scratch->grad_x, scratch->grad_embed_in, N_EMBD);
launch_row_add(model->wte.grad + (size_t)pc->token_id * N_EMBD, scratch->grad_embed_in, N_EMBD);
launch_row_add(model->wpe.grad + (size_t)pos * N_EMBD, scratch->grad_embed_in, N_EMBD);
}
}
static void adam_update_param(Param *param, int step, int total_steps, float beta1_pow, float beta2_pow) {
float lr_t = LEARNING_RATE * (1.0f - (float)step / (float)total_steps);
int count = (int)param_count(param);
launch_adam_update(param->data, param->grad, param->m, param->v, lr_t, 1.0f - beta1_pow, 1.0f - beta2_pow, count);
}
static void adam_update_model(Model *model, int step, int total_steps) {
float beta1_pow = powf(BETA1, (float)(step + 1));
float beta2_pow = powf(BETA2, (float)(step + 1));
adam_update_param(&model->wte, step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->wpe, step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->lm_head, step, total_steps, beta1_pow, beta2_pow);
for (int layer = 0; layer < N_LAYER; ++layer) {
adam_update_param(&model->attn_wq[layer], step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->attn_wk[layer], step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->attn_wv[layer], step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->attn_wo[layer], step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->mlp_fc1[layer], step, total_steps, beta1_pow, beta2_pow);
adam_update_param(&model->mlp_fc2[layer], step, total_steps, beta1_pow, beta2_pow);
}
}
static int sample_from_probs(const float *probs, int n) {
float r = rand_uniform();
float cumulative = 0.0f;
for (int i = 0; i < n; ++i) {
cumulative += probs[i];
if (r <= cumulative) {
return i;
}
}
return n - 1;
}
static void generate_samples(Model *model, TrainCache *cache, const Dataset *dataset, int num_samples, float temperature) {
float *host_probs = (float *)xmalloc((size_t)dataset->vocab_size * sizeof(float));
printf("\n--- inference (new, hallucinated names) ---\n");
for (int sample_idx = 0; sample_idx < num_samples; ++sample_idx) {
int token_id = dataset->bos_id;
char sample[BLOCK_SIZE + 1];
int length = 0;
for (int pos = 0; pos < BLOCK_SIZE; ++pos) {
forward_position(model, cache, dataset->vocab_size, token_id, pos);
PositionCache *pc = &cache->pos[pos];
launch_scaled_softmax(pc->logits, pc->probs, 1.0f / temperature, dataset->vocab_size);
copy_to_host(host_probs, pc->probs, dataset->vocab_size);
token_id = sample_from_probs(host_probs, dataset->vocab_size);
if (token_id == dataset->bos_id) {
break;
}
sample[length++] = (char)dataset->id_to_char[token_id];
}
sample[length] = '\0';
printf("sample %2d: %s\n", sample_idx + 1, sample);
}
free(host_probs);
}
static Options parse_options(int argc, char **argv) {
Options opts;
opts.num_steps = 1000;
opts.num_samples = 20;
opts.temperature = 0.5f;
opts.input_path = "input.txt";
for (int i = 1; i < argc; ++i) {
if (strcmp(argv[i], "--steps") == 0 && i + 1 < argc) {
opts.num_steps = atoi(argv[++i]);
} else if (strcmp(argv[i], "--samples") == 0 && i + 1 < argc) {
opts.num_samples = atoi(argv[++i]);
} else if (strcmp(argv[i], "--temperature") == 0 && i + 1 < argc) {
opts.temperature = (float)atof(argv[++i]);
} else if (strcmp(argv[i], "--input") == 0 && i + 1 < argc) {
opts.input_path = argv[++i];
} else {
fprintf(stderr, "usage: %s [--steps N] [--samples N] [--temperature T] [--input PATH]\n", argv[0]);
exit(1);
}
}
if (opts.num_steps <= 0) {
fprintf(stderr, "--steps must be positive\n");
exit(1);
}
if (opts.num_samples < 0) {
fprintf(stderr, "--samples must be non-negative\n");
exit(1);
}
if (opts.temperature <= 0.0f) {
fprintf(stderr, "--temperature must be > 0\n");
exit(1);
}
return opts;
}
int main(int argc, char **argv) {
srand(42);
Options opts = parse_options(argc, argv);
int device_count = 0;
cudaError_t device_err = cudaGetDeviceCount(&device_count);
if (device_err != cudaSuccess || device_count == 0) {
fprintf(stderr, "microgpt.cu requires a CUDA-capable GPU at runtime");
if (device_err != cudaSuccess) {
fprintf(stderr, " (%s)", cudaGetErrorString(device_err));
}
fprintf(stderr, "\n");
return 1;
}
CHECK_CUBLAS(cublasCreate(&g_cublas));
Dataset dataset;
load_dataset(&dataset, opts.input_path);
printf("num docs: %d\n", dataset.docs.count);
printf("vocab size: %d\n", dataset.vocab_size);
FloatArena param_arena;
init_arena(&param_arena, param_storage_count(dataset.vocab_size));
Model model;
init_model(&model, &param_arena, dataset.vocab_size);
printf("num params: %zu\n", total_params(&model));
FloatArena cache_arena;
init_arena(&cache_arena, train_cache_storage_count(dataset.vocab_size));
TrainCache cache;
init_train_cache(&cache, dataset.vocab_size, &cache_arena);
FloatArena scratch_arena;
init_arena(&scratch_arena, scratch_storage_count(dataset.vocab_size));
Scratch scratch;
init_scratch(&scratch, dataset.vocab_size, &scratch_arena);
int tokens[BLOCK_SIZE + 2];
for (int step = 0; step < opts.num_steps; ++step) {
const char *doc = dataset.docs.items[step % dataset.docs.count];
int token_count = tokenize_doc(&dataset, doc, tokens);
int n = token_count - 1;
if (n > BLOCK_SIZE) {
n = BLOCK_SIZE;
}
float loss = forward_sequence(&model, &cache, &dataset, tokens, n);
backward_sequence(&model, &cache, &scratch, &dataset, n);
adam_update_model(&model, step, opts.num_steps);
printf("step %4d / %4d | loss %.4f\r", step + 1, opts.num_steps, loss);
fflush(stdout);
}
generate_samples(&model, &cache, &dataset, opts.num_samples, opts.temperature);
CHECK_CUDA(cudaDeviceSynchronize());
CHECK_CUBLAS(cublasDestroy(g_cublas));
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment