Skip to content

Instantly share code, notes, and snippets.

@LukeB42
Created March 11, 2026 10:25
Show Gist options
  • Select an option

  • Save LukeB42/295a81b17f4a7642f147cc27a4678e24 to your computer and use it in GitHub Desktop.

Select an option

Save LukeB42/295a81b17f4a7642f147cc27a4678e24 to your computer and use it in GitHub Desktop.
/* lstm.c
*
* Character-level recurrent network (a single tanh RNN layer trained with
* truncated BPTT; despite the file name, no LSTM gates are implemented).
*
* Features
* • No back‑space / in‑place updates – every status line is printed on a new line.
* • At the start of training it prints:
* Loading <file>. Training for <num_epochs> epochs.
* • Every 100 epochs the loss value is printed, followed by a 128‑character
* generated sample on the next line.
* • In inference mode the `‑l <num_chars>` flag is mandatory.
*
* Compile (debug build recommended):
* gcc -g -Wall -Wextra -std=c99 -O2 lstm.c -lm -o lstm
*
* Usage
* Training : ./lstm <train_file> -e <num_epochs>
* Infer : ./lstm <file>.weights -l <num_chars> -p "<prime_str>"
*
* The program creates “<train_file>.weights” after training.
*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
/* ----------------------------------------------------------------- */
/* Helper macros */
/* MIN/MAX: classic comparison macros — beware: each argument may be
 * evaluated twice, so never pass expressions with side effects. */
#define MIN(a,b) ((a) < (b) ? (a) : (b))
#define MAX(a,b) ((a) > (b) ? (a) : (b))
/* ASSERT: if cond is false, print msg to stderr and exit(1). */
#define ASSERT(cond,msg) do { if(!(cond)) { fprintf(stderr, "%s\n", msg); exit(1); } } while (0)
/* RANDOM: uniform double in [0, 1]; uses rand(), so seed with srand(). */
#define RANDOM() ((double)rand() / (double)RAND_MAX)
/* ----------------------------------------------------------------- */
/* Simple matrix type */
/* Dense row-major matrix: element (i, j) lives at data[i*cols + j]. */
typedef struct {
int rows, cols;
double *data;
} Matrix;
/* ----------------------------------------------------------------- */
/* Matrix helpers ---------------------------------------------------- */
/* mat_idx: lvalue accessor for element (i, j).  Works on both const and
 * non-const Matrix pointers because it expands to a plain array access. */
#define mat_idx(m,i,j) ((m)->data[(i)*(m)->cols + (j)])
/* Allocate an r x c matrix with all entries zero-initialised.
 * BUG FIX: the original never checked malloc/calloc, so an allocation
 * failure dereferenced NULL.  Aborts via ASSERT instead, so callers
 * never receive a partially-built matrix. */
static Matrix *mat_create(int r, int c) {
    Matrix *m = malloc(sizeof *m);
    ASSERT(m, "Out of memory allocating Matrix");
    m->rows = r;
    m->cols = c;
    m->data = calloc((size_t)r * (size_t)c, sizeof(double));
    ASSERT(m->data, "Out of memory allocating Matrix data");
    return m;
}
/* Release a matrix created by mat_create().  NULL-safe, so cleanup
 * paths can call it unconditionally. */
static void mat_free(Matrix *m) {
    if (!m)
        return;
    free(m->data);
    free(m);
}
static Matrix *mat_random(int r, int c) {
Matrix *m = mat_create(r,c);
for (int i=0;i<r;i++)
for (int j=0;j<c;j++)
mat_idx(m,i,j) = (RANDOM()*2-1)*0.01; /* direct assignment – no '*' */
return m;
}
/* zero‑out a matrix – needed in train_one_epoch */
/* Reset every entry of m to 0.0 (clears gradient accumulators).
 * memset to all-bits-zero yields 0.0 for IEEE-754 doubles, the same
 * assumption calloc already makes elsewhere in this file. */
static void mat_zero(Matrix *m) {
    memset(m->data, 0, (size_t)(m->rows * m->cols) * sizeof(double));
}
/* core linear‑algebra primitives we actually use */
/* Return a freshly allocated C = A * B; aborts if inner dims disagree.
 * Caller owns (and must mat_free) the result. */
static Matrix *mat_matmul(const Matrix *A, const Matrix *B) {
    ASSERT(A->cols == B->rows, "Incompatible dimensions for matmul");
    Matrix *C = mat_create(A->rows, B->cols);
    for (int row = 0; row < A->rows; row++) {
        for (int col = 0; col < B->cols; col++) {
            double acc = 0.0;
            for (int k = 0; k < A->cols; k++)
                acc += mat_idx(A, row, k) * mat_idx(B, k, col);
            mat_idx(C, row, col) = acc;
        }
    }
    return C;
}
/* In-place element-wise addition A += B; returns A for chaining. */
static Matrix *mat_add(Matrix *A, const Matrix *B) {
    ASSERT(A->rows == B->rows && A->cols == B->cols, "Add dims mismatch");
    int total = A->rows * A->cols;
    for (int n = 0; n < total; n++)
        A->data[n] += B->data[n];
    return A;
}
/* In-place scalar multiply A *= s; returns A for chaining. */
static Matrix *mat_scale(Matrix *A, double s) {
    int total = A->rows * A->cols;
    for (int n = 0; n < total; n++)
        A->data[n] *= s;
    return A;
}
/* ----------------------------------------------------------------- */
/* Activations ------------------------------------------------------- */
static double tanh_s(double x){ return tanh(x); }
/* ----------------------------------------------------------------- */
/* One‑hot utility --------------------------------------------------- */
/* Write a one-hot encoding of `idx` into vec[0 .. vocab_sz-1]:
 * all entries 0.0 except vec[idx] = 1.0. */
static void one_hot(int idx, int vocab_sz, double *vec) {
    for (int i = 0; i < vocab_sz; i++)
        vec[i] = 0.0;
    vec[idx] = 1.0;
}
/* ----------------------------------------------------------------- */
/* Model container ---------------------------------------------------- */
/* Model parameters.  Despite the name this is a single-layer vanilla
 * (tanh) RNN — the code contains no input/forget/output gates and no
 * cell state.  Shapes follow how main() allocates them. */
typedef struct{
Matrix *W_i; /* input -> hidden, hidden_sz x vocab_sz */
Matrix *W_h; /* hidden -> hidden (recurrent), hidden_sz x hidden_sz */
Matrix *W_o; /* hidden -> output, vocab_sz x hidden_sz */
int hidden_sz; /* number of hidden units */
} LSTMModel;
/* ----------------------------------------------------------------- */
/* Forward step – produces next hidden state and softmax output -------- */
/* One forward step of the (tanh) recurrent cell:
 *   h_next = tanh(W_i . x + W_h . h_prev)
 *   y_out  = softmax(W_o . h_next)
 *
 * x      : one-hot input, length = W_i->cols (vocab size)
 * h_prev : hidden state at time t, length hidden_sz
 * y_out  : output probability distribution, length W_o->rows; written here
 * h_next : next hidden state, length hidden_sz; written here
 *
 * BUG FIXES vs. the original:
 *  - the W_i.x loop ran j over hidden_sz instead of W_i->cols, reading
 *    past the end of each W_i row (heap over-read whenever
 *    vocab_sz < hidden_sz);
 *  - intermediate Matrix objects were released with free() instead of
 *    mat_free(), leaking their data arrays every step.  This version
 *    avoids heap allocation entirely.
 */
static void forward_step(const LSTMModel *mod,
                         const double *x,      /* one-hot input */
                         double *h_prev,       /* hidden at time t */
                         double *y_out,        /* output distribution */
                         double *h_next)       /* buffer for next hidden */
{
    int H = mod->hidden_sz;
    int V = mod->W_o->rows;      /* vocab size */
    int in_sz = mod->W_i->cols;  /* input width (== vocab size) */
    for (int i = 0; i < H; i++) {
        double acc = 0.0;
        for (int j = 0; j < in_sz; j++)
            acc += mat_idx(mod->W_i, i, j) * x[j];
        for (int j = 0; j < H; j++)
            acc += mat_idx(mod->W_h, i, j) * h_prev[j];
        h_next[i] = tanh_s(acc);
    }
    /* softmax with the max-shift trick for numerical stability */
    double maxlog = -1e9;
    for (int i = 0; i < V; i++) {
        double acc = 0.0;
        for (int j = 0; j < H; j++)
            acc += mat_idx(mod->W_o, i, j) * h_next[j];
        y_out[i] = acc;
        if (acc > maxlog) maxlog = acc;
    }
    double exp_sum = 0.0;
    for (int i = 0; i < V; i++) {
        y_out[i] = exp(y_out[i] - maxlog);
        exp_sum += y_out[i];
    }
    for (int i = 0; i < V; i++)
        y_out[i] /= exp_sum;
}
/* ----------------------------------------------------------------- */
/* Training – one epoch, returns the accumulated loss ----------------- */
/* One full pass over `data` with truncated backpropagation (each step
 * treats the previous hidden state as a constant — gradients are not
 * carried across time steps), followed by a single averaged
 * gradient-descent update.  Returns the summed cross-entropy loss.
 *
 * data     : NUL-terminated training text
 * char2idx : 256-entry byte-value -> vocab-index map
 * vocab_sz : number of distinct characters
 * lr       : learning rate
 *
 * BUG FIXES vs. the original:
 *  - dLy was set to -1/y, ignoring the softmax Jacobian; the gradient
 *    of softmax + cross-entropy wrt the logits is y - onehot(target);
 *  - dW_o and the tanh derivative used the PREVIOUS hidden state h_t,
 *    but y_t was produced from the new hidden state;
 *  - the update ADDED lr * gradient (gradient ascent); it must subtract;
 *  - strlen(data) < 2 divided by zero when computing the scale.
 */
static double train_one_epoch(LSTMModel *mod,
                              const char *data,
                              const int *char2idx,
                              int vocab_sz,
                              double lr)
{
    int T = (int)strlen(data);
    if (T < 2) return 0.0;   /* no (input, target) pairs to train on */
    int H = mod->hidden_sz;
    double *h_t = calloc(H, sizeof(double));    /* hidden state, zeroed */
    double *h_new = calloc(H, sizeof(double));
    double *y_t = calloc(vocab_sz, sizeof(double));
    double *dh_raw = calloc(H, sizeof(double)); /* grad wrt pre-activation */
    double *dLy = calloc(vocab_sz, sizeof(double));
    Matrix *dW_i = mat_create(H, vocab_sz);
    Matrix *dW_h = mat_create(H, H);
    Matrix *dW_o = mat_create(vocab_sz, H);
    mat_zero(dW_i); mat_zero(dW_h); mat_zero(dW_o); /* defensive: mat_create zero-fills */
    double x_onehot[256] = {0};
    double loss = 0.0;
    for (int t = 0; t < T - 1; ++t) {
        int ix_input = char2idx[(unsigned char)data[t]];
        int ix_target = char2idx[(unsigned char)data[t + 1]];
        one_hot(ix_input, vocab_sz, x_onehot);
        forward_step(mod, x_onehot, h_t, y_t, h_new);
        loss += -log(MAX(y_t[ix_target], 1e-12));
        /* softmax + cross-entropy gradient wrt the logits */
        for (int i = 0; i < vocab_sz; i++)
            dLy[i] = y_t[i] - (i == ix_target ? 1.0 : 0.0);
        /* output weights: y_t came from h_new, so use h_new here */
        for (int i = 0; i < vocab_sz; i++)
            for (int j = 0; j < H; j++)
                mat_idx(dW_o, i, j) += dLy[i] * h_new[j];
        /* back through W_o, then through tanh (h_new = tanh(raw)) */
        for (int i = 0; i < H; i++) {
            double s = 0.0;
            for (int j = 0; j < vocab_sz; j++)
                s += dLy[j] * mat_idx(mod->W_o, j, i);
            dh_raw[i] = s * (1.0 - h_new[i] * h_new[i]);
        }
        for (int i = 0; i < H; i++)
            for (int j = 0; j < H; j++)
                mat_idx(dW_h, i, j) += dh_raw[i] * h_t[j];
        /* x is one-hot: only column ix_input of dW_i receives gradient */
        for (int i = 0; i < H; i++)
            mat_idx(dW_i, i, ix_input) += dh_raw[i];
        memcpy(h_t, h_new, H * sizeof(double));
    }
    /* gradient DESCENT on the mean gradient: subtract, don't add */
    double scale = -lr / (double)(T - 1);
    mat_add(mod->W_i, mat_scale(dW_i, scale));
    mat_add(mod->W_h, mat_scale(dW_h, scale));
    mat_add(mod->W_o, mat_scale(dW_o, scale));
    free(h_t); free(h_new); free(y_t); free(dh_raw); free(dLy);
    mat_free(dW_i); mat_free(dW_h); mat_free(dW_o);
    return loss; /* summed loss for this epoch, reported by main() */
}
/* ----------------------------------------------------------------- */
/* Helper: generate `len` characters with the *current* model ------------- */
/* Generate `gen_len` characters from the current model, starting from a
 * zero hidden state and the one-hot of `start_idx`, printing each
 * sampled character followed by a final newline.
 *
 * BUG FIXES vs. the original:
 *  - forward_step already returns a softmax distribution; the old code
 *    exponentiated it a SECOND time, flattening the distribution toward
 *    uniform before sampling;
 *  - if floating-point rounding made the cumulative sum fall short of
 *    the random draw, the sample silently defaulted to index 0; the
 *    conventional fallback is the last vocab index.
 */
static void print_sample(LSTMModel *mod,
                         int vocab_sz,
                         int hidden_sz,
                         const int *char2idx,
                         const int *idx2char,
                         int start_idx,    /* starting character index */
                         int gen_len)      /* main() always passes 128 */
{
    (void)char2idx;  /* unused; kept for interface compatibility */
    double *h = calloc(hidden_sz, sizeof(double));      /* zero initial state */
    double *h_next = calloc(hidden_sz, sizeof(double));
    double x_onehot[256] = {0};
    double y_out[256] = {0};
    one_hot(start_idx, vocab_sz, x_onehot);
    for (int i = 0; i < gen_len; i++) {
        forward_step(mod, x_onehot, h, y_out, h_next);
        memcpy(h, h_next, hidden_sz * sizeof(double));
        /* sample from y_out (already a probability distribution) */
        double r = RANDOM();
        double cum = 0.0;
        int next_idx = vocab_sz - 1;  /* fallback against fp rounding */
        for (int j = 0; j < vocab_sz; j++) {
            cum += y_out[j];
            if (r <= cum) { next_idx = j; break; }
        }
        putchar(idx2char[next_idx]);
        fflush(stdout);
        /* feed the produced character back as the next input */
        one_hot(next_idx, vocab_sz, x_onehot);
    }
    putchar('\n');
    fflush(stdout);
    free(h);
    free(h_next);
}
/* ----------------------------------------------------------------- */
/* Full inference (used when the user supplies -l) ------------------- */
/* Inference: load a "<train_file>.weights" file produced by main(),
 * condition the hidden state on `prime_str`, then sample and print
 * `gen_len` characters followed by a newline.
 *
 * BUG FIXES vs. the original:
 *  - the trainer writes an extra raw byte (fputc) after each int vocab
 *    id; the old reader did not skip it, so every subsequent weight
 *    read was misaligned by vocab_sz bytes;
 *  - fread results were cast to void, so a truncated file silently
 *    produced garbage dimensions/weights;
 *  - the prime loop sampled, printed, and advanced the hidden state
 *    TWICE per prime character; priming should only condition the state;
 *  - Matrix temporaries were released with free() instead of mat_free(),
 *    leaking their data arrays (this version reuses forward_step and
 *    allocates no temporaries in the loops).
 */
static void generate(const char *weights_path,
                     const char *prime_str,
                     int gen_len)   /* number of chars to emit */
{
    FILE *fp = fopen(weights_path, "rb");
    ASSERT(fp, "Cannot open weight file");
    int vocab_sz = 0, hidden_sz = 0;
    ASSERT(fread(&vocab_sz, sizeof(int), 1, fp) == 1, "Corrupt weight file");
    ASSERT(fread(&hidden_sz, sizeof(int), 1, fp) == 1, "Corrupt weight file");
    ASSERT(vocab_sz > 0 && vocab_sz <= 256 && hidden_sz > 0,
           "Implausible dimensions in weight file");
    int *idx2char = calloc(vocab_sz, sizeof(int));
    int *char2idx = calloc(256, sizeof(int)); /* int, not char: index may exceed 127 */
    for (int i = 0; i < vocab_sz; i++) {
        int id;
        ASSERT(fread(&id, sizeof(int), 1, fp) == 1, "Corrupt vocab table");
        (void)fgetc(fp); /* skip the extra byte the trainer writes per entry */
        idx2char[i] = id;
        char2idx[(unsigned char)id] = i;
    }
    Matrix *W_i = mat_create(hidden_sz, vocab_sz);
    Matrix *W_h = mat_create(hidden_sz, hidden_sz);
    Matrix *W_o = mat_create(vocab_sz, hidden_sz);
    ASSERT(fread(W_i->data, sizeof(double), (size_t)hidden_sz * vocab_sz, fp)
               == (size_t)hidden_sz * vocab_sz, "Corrupt W_i block");
    ASSERT(fread(W_h->data, sizeof(double), (size_t)hidden_sz * hidden_sz, fp)
               == (size_t)hidden_sz * hidden_sz, "Corrupt W_h block");
    ASSERT(fread(W_o->data, sizeof(double), (size_t)vocab_sz * hidden_sz, fp)
               == (size_t)vocab_sz * hidden_sz, "Corrupt W_o block");
    fclose(fp);
    LSTMModel mod = { .W_i = W_i, .W_h = W_h, .W_o = W_o, .hidden_sz = hidden_sz };
    double *h = calloc(hidden_sz, sizeof(double));      /* zero initial state */
    double *h_next = calloc(hidden_sz, sizeof(double));
    double x_onehot[256] = {0};
    double y_out[256] = {0};
    int last_idx = 0;
    /* ----- prime: advance the hidden state through the prime string;
     * after the loop y_out holds the distribution for the next char ----- */
    for (const char *p = prime_str; *p; ++p) {
        last_idx = char2idx[(unsigned char)*p];
        one_hot(last_idx, vocab_sz, x_onehot);
        forward_step(&mod, x_onehot, h, y_out, h_next);
        memcpy(h, h_next, hidden_sz * sizeof(double));
    }
    if (prime_str[0] == '\0') {
        /* no prime: bootstrap y_out from vocab index 0 */
        one_hot(last_idx, vocab_sz, x_onehot);
        forward_step(&mod, x_onehot, h, y_out, h_next);
        memcpy(h, h_next, hidden_sz * sizeof(double));
    }
    /* ----- sample gen_len characters, feeding each back as input ----- */
    for (int i = 0; i < gen_len; ++i) {
        double r = RANDOM();
        double cum = 0.0;
        int next_idx = vocab_sz - 1;  /* fallback against fp rounding */
        for (int j = 0; j < vocab_sz; j++) {
            cum += y_out[j];
            if (r <= cum) { next_idx = j; break; }
        }
        putchar(idx2char[next_idx]);
        fflush(stdout);
        one_hot(next_idx, vocab_sz, x_onehot);
        forward_step(&mod, x_onehot, h, y_out, h_next);
        memcpy(h, h_next, hidden_sz * sizeof(double));
    }
    putchar('\n');
    fflush(stdout);
    mat_free(W_i); mat_free(W_h); mat_free(W_o);
    free(idx2char); free(char2idx);
    free(h); free(h_next);
}
/* ----------------------------------------------------------------- */
/* Main ------------------------------------------------------------- */
/* Entry point.  Two modes:
 *   Training : ./lstm <train_file> -e <num_epochs>
 *              (writes "<train_file>.weights" when finished)
 *   Inference: ./lstm <file>.weights -l <num_chars> -p "<prime_str>"
 *
 * BUG FIXES vs. the original:
 *  - vocabulary: char2idx[i] was set to seen[i]-1, i.e. 0 for EVERY
 *    seen character, so all inputs/targets collapsed to vocab index 0;
 *    it now assigns consecutive indices in byte order, matching idx2char;
 *  - inference: a single positional argument lands in train_path, but
 *    the old branch tested weight_path only, making inference
 *    unreachable; we now fall back to train_path;
 *  - the training-file fread result is checked instead of ignored.
 */
int main(int argc, char **argv){
    if (argc < 2) {
        fprintf(stderr,
                "Usage:\n"
                " Training : ./lstm <train_file> -e <num_epochs>\n"
                " Infer : ./lstm <file>.weights -l <num_chars> -p \"<prime_str>\"\n");
        return 1;
    }
    char *train_path = NULL;   /* first positional argument */
    char *weight_path = NULL;  /* second positional argument, if any */
    int train_epochs = 0;
    int gen_len = 0;           /* mandatory for inference */
    char prime_str[1024] = "";
    int checkpoint = 100;      /* report every 100 epochs */
    int argi = 1;
    while (argi < argc) {
        if (strcmp(argv[argi], "-e") == 0 && argi + 1 < argc) {
            train_epochs = atoi(argv[argi + 1]); argi += 2;
        } else if (strcmp(argv[argi], "-l") == 0 && argi + 1 < argc) {
            gen_len = atoi(argv[argi + 1]); argi += 2;
        } else if (strcmp(argv[argi], "-p") == 0 && argi + 1 < argc) {
            strncpy(prime_str, argv[argi + 1], sizeof(prime_str) - 1); argi += 2;
        } else {
            if (train_path == NULL) train_path = argv[argi];
            else weight_path = argv[argi];
            ++argi;
        }
    }
    srand((unsigned)time(NULL));
    /* ------------------- TRAINING ------------------- */
    if (train_path && train_epochs > 0) {
        printf("Loading %s. Training for %d epochs.\n", train_path, train_epochs);
        fflush(stdout);
        FILE *fp = fopen(train_path, "rb");
        ASSERT(fp, "Cannot open training file");
        fseek(fp, 0, SEEK_END);
        long fsize = ftell(fp);
        fseek(fp, 0, SEEK_SET);
        char *data = malloc(fsize + 1);
        ASSERT(data, "Out of memory reading training file");
        ASSERT(fread(data, 1, fsize, fp) == (size_t)fsize,
               "Short read on training file");
        data[fsize] = '\0';
        fclose(fp);
        /* ---- build vocabulary: consecutive indices in byte order ---- */
        unsigned char seen[256] = {0};
        int vocab_sz = 0;
        for (long i = 0; i < fsize; i++) {
            unsigned char c = (unsigned char)data[i];
            if (!seen[c]) { seen[c] = 1; ++vocab_sz; }
        }
        ASSERT(vocab_sz > 0, "Training file is empty");
        int *char2idx = calloc(256, sizeof(int)); /* unseen bytes map to 0 */
        int *idx2char = calloc(vocab_sz, sizeof(int));
        int cur = 0;
        for (int i = 0; i < 256; i++) {
            if (seen[i]) {
                char2idx[i] = cur;
                idx2char[cur] = i;
                ++cur;
            }
        }
        /* ---- initialise model ---- */
        const int hidden_sz = 256;
        LSTMModel mod;
        mod.W_i = mat_random(hidden_sz, vocab_sz);
        mod.W_h = mat_random(hidden_sz, hidden_sz);
        mod.W_o = mat_random(vocab_sz, hidden_sz);
        mod.hidden_sz = hidden_sz;
        double lr = 0.01;
        for (int epoch = 1; epoch <= train_epochs; ++epoch) {
            double epoch_loss = train_one_epoch(&mod, data, char2idx, vocab_sz, lr);
            if (epoch % checkpoint == 0) {
                printf("Epoch %4d loss: %8.4f\n", epoch, epoch_loss);
                fflush(stdout);
                /* 128-character sample starting from vocab index 0 */
                print_sample(&mod, vocab_sz, hidden_sz,
                             char2idx, idx2char, 0, 128);
            }
        }
        putchar('\n'); fflush(stdout); /* final newline after training */
        /* ---- save weights: vocab_sz, hidden_sz, then per vocab entry
         * an int id followed by one raw byte (kept for compatibility
         * with the existing on-disk format), then the three matrices ---- */
        char weight_file[256];
        snprintf(weight_file, sizeof(weight_file), "%s.weights", train_path);
        FILE *wfp = fopen(weight_file, "wb");
        ASSERT(wfp, "Cannot write weight file");
        (void)fwrite(&vocab_sz, sizeof(int), 1, wfp);
        (void)fwrite(&hidden_sz, sizeof(int), 1, wfp);
        for (int i = 0; i < vocab_sz; i++) {
            int id = idx2char[i];
            (void)fwrite(&id, sizeof(int), 1, wfp);
            (void)fputc((char)id, wfp);
        }
        (void)fwrite(mod.W_i->data, sizeof(double), mod.W_i->rows * mod.W_i->cols, wfp);
        (void)fwrite(mod.W_h->data, sizeof(double), mod.W_h->rows * mod.W_h->cols, wfp);
        (void)fwrite(mod.W_o->data, sizeof(double), mod.W_o->rows * mod.W_o->cols, wfp);
        ASSERT(fclose(wfp) == 0, "Error flushing weight file");
        fprintf(stderr, "Training finished – weights written to `%s`\n", weight_file);
        mat_free(mod.W_i); mat_free(mod.W_h); mat_free(mod.W_o);
        free(char2idx); free(idx2char); free(data);
    }
    /* ------------------- INFERENCE ------------------- */
    else if ((weight_path || train_path) && gen_len > 0 && prime_str[0]) {
        generate(weight_path ? weight_path : train_path, prime_str, gen_len);
    }
    else {
        fprintf(stderr, "Invalid arguments – training needs -e <num_epochs>, "
                        "inference needs -l <num_chars> and -p \"<prime_str>\"\n");
        return 1;
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment