Created
March 11, 2026 10:25
-
-
Save LukeB42/295a81b17f4a7642f147cc27a4678e24 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* lstm.c | |
| * | |
| * Character‑level recurrent network (a plain tanh RNN – no LSTM gates, despite the filename) trained with one‑step truncated BPTT. | |
| * | |
| * Features | |
| * • No back‑space / in‑place updates – every status line is printed on a new line. | |
| * • At the start of training it prints: | |
| * Loading <file>. Training for <num_epochs> epochs. | |
| * • Every 100 epochs the loss value is printed, followed by a 128‑character | |
| * generated sample on the next line. | |
| * • In inference mode the `‑l <num_chars>` flag is mandatory. | |
| * | |
| * Compile (debug build recommended): | |
| * gcc -g -Wall -Wextra -std=c99 -O2 lstm.c -lm -o lstm | |
| * | |
| * Usage | |
| * Training : ./lstm <train_file> -e <num_epochs> | |
| * Infer : ./lstm <file>.weights -l <num_chars> -p "<prime_str>" | |
| * | |
| * The program creates “<train_file>.weights” after training. | |
| */ | |
| #include <stdio.h> | |
| #include <stdlib.h> | |
| #include <string.h> | |
| #include <math.h> | |
| #include <time.h> | |
| /* ----------------------------------------------------------------- */ | |
| /* Helper macros */ | |
/* Clamp helpers. NOTE: both arguments are evaluated twice -- do not pass
 * expressions with side effects (e.g. MAX(i++, j)). */
#define MIN(a,b) ((a) < (b) ? (a) : (b))
#define MAX(a,b) ((a) > (b) ? (a) : (b))
/* Print `msg` to stderr and abort the whole process when `cond` is false. */
#define ASSERT(cond,msg) do { if(!(cond)) { fprintf(stderr, "%s\n", msg); exit(1); } } while (0)
/* Uniform double in [0, 1]; uses rand(), so seed with srand() first. */
#define RANDOM() ((double)rand() / (double)RAND_MAX)
| /* ----------------------------------------------------------------- */ | |
| /* Simple matrix type */ | |
/* Dense row-major matrix: element (i,j) lives at data[i*cols + j]. */
typedef struct {
    int rows, cols;   /* dimensions */
    double *data;     /* rows*cols doubles, owned by this struct */
} Matrix;
| /* ----------------------------------------------------------------- */ | |
| /* Matrix helpers ---------------------------------------------------- */ | |
/* Row-major element lvalue; usable with const and non-const Matrix*. */
#define mat_idx(m,i,j) ((m)->data[(i)*(m)->cols + (j)])
| static Matrix *mat_create(int r, int c) { | |
| Matrix *m = malloc(sizeof(Matrix)); | |
| m->rows = r; m->cols = c; | |
| m->data = calloc(r*c, sizeof(double)); | |
| return m; | |
| } | |
| static void mat_free(Matrix *m) { free(m->data); free(m); } | |
| static Matrix *mat_random(int r, int c) { | |
| Matrix *m = mat_create(r,c); | |
| for (int i=0;i<r;i++) | |
| for (int j=0;j<c;j++) | |
| mat_idx(m,i,j) = (RANDOM()*2-1)*0.01; /* direct assignment – no '*' */ | |
| return m; | |
| } | |
| /* zero‑out a matrix – needed in train_one_epoch */ | |
| static void mat_zero(Matrix *m) { /* definition added */ | |
| for (int i=0;i<m->rows*m->cols;i++) m->data[i]=0.0; | |
| } | |
| /* core linear‑algebra primitives we actually use */ | |
| static Matrix *mat_matmul(const Matrix *A, const Matrix *B) { | |
| int r = A->rows, k = A->cols, c = B->cols; | |
| ASSERT(k == B->rows, "Incompatible dimensions for matmul"); | |
| Matrix *C = mat_create(r,c); | |
| for (int i=0;i<r;i++) | |
| for (int j=0;j<c;j++) { | |
| double sum = 0; | |
| for (int p=0;p<k;p++) | |
| sum += mat_idx(A,i,p) * mat_idx(B,p,j); | |
| mat_idx(C,i,j) = sum; | |
| } | |
| return C; | |
| } | |
| static Matrix *mat_add(Matrix *A, const Matrix *B) { | |
| ASSERT(A->rows == B->rows && A->cols == B->cols, "Add dims mismatch"); | |
| for (int i=0;i<A->rows*A->cols;i++) A->data[i] += B->data[i]; | |
| return A; | |
| } | |
| static Matrix *mat_scale(Matrix *A, double s) { | |
| for (int i=0;i<A->rows*A->cols;i++) A->data[i] *= s; | |
| return A; | |
| } | |
| /* ----------------------------------------------------------------- */ | |
| /* Activations ------------------------------------------------------- */ | |
| static double tanh_s(double x){ return tanh(x); } | |
| /* ----------------------------------------------------------------- */ | |
| /* One‑hot utility --------------------------------------------------- */ | |
| static void one_hot(int idx,int vocab_sz,double *vec){ | |
| memset(vec,0,vocab_sz*sizeof(double)); | |
| vec[idx]=1.0; | |
| } | |
| /* ----------------------------------------------------------------- */ | |
| /* Model container ---------------------------------------------------- */ | |
/* Model parameters for the single recurrent layer.
 * NOTE(review): despite the name there are no gates anywhere in this
 * file -- the model is a plain tanh RNN, not an LSTM. */
typedef struct{
    Matrix *W_i; /* input -> hidden, hidden_sz x vocab_sz */
    Matrix *W_h; /* hidden -> hidden (recurrent), hidden_sz x hidden_sz */
    Matrix *W_o; /* hidden -> output logits, vocab_sz x hidden_sz */
    int hidden_sz; /* number of hidden units */
} LSTMModel;
| /* ----------------------------------------------------------------- */ | |
| /* Forward step – produces next hidden state and softmax output -------- */ | |
| static void forward_step(const LSTMModel *mod, | |
| const double *x, /* one‑hot input */ | |
| double *h_prev, /* hidden at time t */ | |
| double *y_out, /* output distribution */ | |
| double *h_next) /* buffer for next hidden */ | |
| { | |
| /* hidden = tanh( W_i·x + W_h·h_prev ) */ | |
| Matrix *tmp = mat_create(mod->hidden_sz,1); | |
| Matrix *Wh_h = mat_create(mod->hidden_sz,1); | |
| for (int i=0;i<mod->hidden_sz;i++) { | |
| mat_idx(tmp,i,0)=0; | |
| for (int j=0;j<mod->hidden_sz;j++) | |
| mat_idx(tmp,i,0)+= mat_idx(mod->W_i,i,j)*x[j]; | |
| } | |
| for (int i=0;i<mod->hidden_sz;i++) { | |
| mat_idx(Wh_h,i,0)=0; | |
| for (int j=0;j<mod->hidden_sz;j++) | |
| mat_idx(Wh_h,i,0)+= mat_idx(mod->W_h,i,j)*h_prev[j]; | |
| } | |
| for (int i=0;i<mod->hidden_sz;i++) | |
| mat_idx(tmp,i,0)+= mat_idx(Wh_h,i,0); | |
| for (int i=0;i<mod->hidden_sz;i++) | |
| mat_idx(tmp,i,0)=tanh_s(mat_idx(tmp,i,0)); | |
| /* store next hidden state */ | |
| memcpy(h_next, tmp->data, mod->hidden_sz*sizeof(double)); | |
| free(tmp); free(Wh_h); | |
| /* output = softmax( W_o·h ) */ | |
| Matrix *h_col = mat_create(mod->hidden_sz,1); | |
| memcpy(h_col->data, h_next, mod->hidden_sz*sizeof(double)); | |
| Matrix *logits = mat_matmul(mod->W_o, h_col); | |
| int vocab_sz = mod->W_o->rows; | |
| double maxlog = -1e9; | |
| for (int i=0;i<vocab_sz;i++) if (mat_idx(logits,i,0)>maxlog) maxlog=mat_idx(logits,i,0); | |
| double exp_sum = 0; | |
| for (int i=0;i<vocab_sz;i++) { | |
| double e = exp(mat_idx(logits,i,0)-maxlog); | |
| mat_idx(logits,i,0)=e; | |
| exp_sum+=e; | |
| } | |
| for (int i=0;i<vocab_sz;i++) mat_idx(logits,i,0)/=exp_sum; | |
| for (int i=0;i<vocab_sz;i++) y_out[i]=mat_idx(logits,i,0); | |
| mat_free(logits); mat_free(h_col); | |
| } | |
| /* ----------------------------------------------------------------- */ | |
| /* Training – one epoch, returns the accumulated loss ----------------- */ | |
| static double train_one_epoch(LSTMModel *mod, | |
| const char *data, | |
| const int *char2idx, | |
| int vocab_sz, | |
| double lr) | |
| { | |
| int T = (int)strlen(data); | |
| double *h_t = calloc(mod->hidden_sz,sizeof(double)); | |
| double *new_h = calloc(mod->hidden_sz,sizeof(double)); | |
| Matrix *dW_i = mat_create(mod->hidden_sz, vocab_sz); | |
| Matrix *dW_h = mat_create(mod->hidden_sz, mod->hidden_sz); | |
| Matrix *dW_o = mat_create(vocab_sz, mod->hidden_sz); | |
| mat_zero(dW_i); mat_zero(dW_h); mat_zero(dW_o); /* now defined */ | |
| double loss = 0.0; | |
| double *y_t = calloc(vocab_sz,sizeof(double)); | |
| for (int t=0; t<T-1; ++t) { | |
| int ix_input = char2idx[(unsigned char)data[t]]; | |
| int ix_target = char2idx[(unsigned char)data[t+1]]; | |
| double x_onehot[256]={0}; | |
| one_hot(ix_input, vocab_sz, x_onehot); | |
| double next_h[mod->hidden_sz]; | |
| forward_step(mod, x_onehot, h_t, y_t, next_h); | |
| memcpy(new_h, next_h, mod->hidden_sz*sizeof(double)); | |
| loss += -log(MAX(y_t[ix_target],1e-12)); | |
| double dLy[256]={0}; | |
| dLy[ix_target] = -1.0 / y_t[ix_target]; | |
| for (int i=0;i<vocab_sz;i++) | |
| for (int j=0;j<mod->hidden_sz;j++) | |
| mat_idx(dW_o,i,j)+= dLy[i]*h_t[j]; | |
| double *dh_next = calloc(mod->hidden_sz,sizeof(double)); | |
| for (int i=0;i<mod->hidden_sz;i++) { | |
| double sum=0; | |
| for (int j=0;j<vocab_sz;j++) | |
| sum+= dLy[j]*mat_idx(mod->W_o,j,i); | |
| dh_next[i] += sum; | |
| } | |
| double *dh_tilde = calloc(mod->hidden_sz,sizeof(double)); | |
| for (int i=0;i<mod->hidden_sz;i++) { | |
| double tanh_val = h_t[i]; | |
| double dtanh = 1 - tanh_val*tanh_val; | |
| dh_tilde[i]=dh_next[i]*dtanh; | |
| } | |
| for (int i=0;i<mod->hidden_sz;i++) | |
| for (int j=0;j<mod->hidden_sz;j++) | |
| mat_idx(dW_h,i,j)+= dh_tilde[i]*h_t[j]; | |
| for (int i=0;i<mod->hidden_sz;i++) | |
| for (int j=0;j<vocab_sz;j++) | |
| mat_idx(dW_i,i,j)+= dh_tilde[i]*x_onehot[j]; | |
| memcpy(h_t, new_h, mod->hidden_sz*sizeof(double)); | |
| free(dh_next); | |
| free(dh_tilde); | |
| } | |
| double scale = lr/(T-1); | |
| mat_scale(dW_i, scale); | |
| mat_scale(dW_h, scale); | |
| mat_scale(dW_o, scale); | |
| mod->W_i = mat_add(mod->W_i, dW_i); | |
| mod->W_h = mat_add(mod->W_h, dW_h); | |
| mod->W_o = mat_add(mod->W_o, dW_o); | |
| free(h_t); | |
| free(new_h); | |
| free(y_t); | |
| mat_free(dW_i); mat_free(dW_h); mat_free(dW_o); | |
| return loss; /* loss for printing */ | |
| } | |
| /* ----------------------------------------------------------------- */ | |
| /* Helper: generate `len` characters with the *current* model ------------- */ | |
| static void print_sample(LSTMModel *mod, | |
| int vocab_sz, | |
| int hidden_sz, | |
| const int *char2idx, | |
| const int *idx2char, | |
| int start_idx, /* starting character index */ | |
| int gen_len) /* we will always call with 128 */ | |
| { | |
| double *h = calloc(hidden_sz,sizeof(double)); | |
| Matrix *h_col = mat_create(hidden_sz,1); | |
| memcpy(h_col->data, h, hidden_sz*sizeof(double)); | |
| double x_onehot[256]={0}; | |
| one_hot(start_idx, vocab_sz, x_onehot); | |
| for (int i=0;i<gen_len;i++) { | |
| double next_h[hidden_sz]; | |
| double y_out[256]={0}; | |
| forward_step(mod, x_onehot, (double*)h_col->data, y_out, next_h); | |
| memcpy(h_col->data, next_h, hidden_sz*sizeof(double)); | |
| double maxlog = -1e9; | |
| for (int j=0;j<vocab_sz;j++) if (y_out[j]>maxlog) maxlog=y_out[j]; | |
| double exp_sum = 0; | |
| for (int j=0;j<vocab_sz;j++) { | |
| double e = exp(y_out[j]-maxlog); | |
| y_out[j]=e; | |
| exp_sum+=e; | |
| } | |
| double prob = RANDOM(); | |
| double cum = 0; int next_idx = 0; | |
| for (int j=0;j<vocab_sz;j++) { | |
| cum+= y_out[j]/exp_sum; | |
| if (prob <= cum){ next_idx=j; break; } | |
| } | |
| putchar(idx2char[next_idx]); /* print the character */ | |
| fflush(stdout); | |
| /* feed the produced character back as the next input */ | |
| x_onehot[0]=0; /* clear array */ | |
| one_hot(next_idx, vocab_sz, x_onehot); | |
| start_idx = next_idx; | |
| } | |
| putchar('\n'); fflush(stdout); | |
| mat_free(h_col); | |
| free(h); | |
| } | |
| /* ----------------------------------------------------------------- */ | |
| /* Full inference (used when the user supplies -l) ------------------- */ | |
| static void generate(const char *weights_path, | |
| const char *prime_str, | |
| int gen_len) /* number of chars to emit */ | |
| { | |
| FILE *fp = fopen(weights_path, "rb"); | |
| ASSERT(fp, "Cannot open weight file"); | |
| int vocab_sz, hidden_sz; | |
| (void)fread(&vocab_sz, sizeof(int), 1, fp); | |
| (void)fread(&hidden_sz, sizeof(int), 1, fp); | |
| int *idx2char = calloc(vocab_sz, sizeof(int)); | |
| char *char2idx = calloc(256, sizeof(char)); | |
| for (int i=0;i<vocab_sz;i++) { | |
| int id; (void)fread(&id, sizeof(int), 1, fp); | |
| idx2char[i]=id; | |
| char2idx[(unsigned char)id]=(char)i; | |
| } | |
| Matrix *W_i = mat_create(hidden_sz, vocab_sz); | |
| Matrix *W_h = mat_create(hidden_sz, hidden_sz); | |
| Matrix *W_o = mat_create(vocab_sz, hidden_sz); | |
| (void)fread(W_i->data, sizeof(double), hidden_sz*vocab_sz, fp); | |
| (void)fread(W_h->data, sizeof(double), hidden_sz*hidden_sz, fp); | |
| (void)fread(W_o->data, sizeof(double), vocab_sz*hidden_sz, fp); | |
| fclose(fp); | |
| LSTMModel mod = { .W_i=W_i, .W_h=W_h, .W_o=W_o, .hidden_sz=hidden_sz }; | |
| double *h = calloc(hidden_sz, sizeof(double)); | |
| Matrix *h_col = mat_create(hidden_sz,1); | |
| memcpy(h_col->data, h, hidden_sz*sizeof(double)); | |
| /* ----- prime string ----- */ | |
| for (int i=0;i<(int)strlen(prime_str);++i) { | |
| unsigned char c = prime_str[i]; | |
| int idx = char2idx[(unsigned char)c]; | |
| double x_onehot[256]={0}; | |
| one_hot(idx, vocab_sz, x_onehot); | |
| Matrix *tmp = mat_create(hidden_sz,1); | |
| for (int j=0;j<hidden_sz;j++) { | |
| double sum=0; | |
| for (int k=0;k<vocab_sz;k++) sum+= mat_idx(mod.W_i,j,k)*x_onehot[k]; | |
| for (int k=0;k<hidden_sz;k++) sum+= mat_idx(mod.W_h,j,k)*h_col->data[k]; | |
| mat_idx(tmp,j,0)=tanh_s(sum); | |
| } | |
| memcpy(h_col->data, tmp->data, hidden_sz*sizeof(double)); | |
| free(tmp); | |
| Matrix *logits = mat_matmul(mod.W_o, h_col); | |
| int vsz = mod.W_o->rows; | |
| double maxlog=-1e9; | |
| for (int j=0;j<vsz;j++) if (mat_idx(logits,j,0)>maxlog) maxlog=mat_idx(logits,j,0); | |
| double exp_sum=0; | |
| for (int j=0;j<vsz;j++) { | |
| double e = exp(mat_idx(logits,j,0)-maxlog); | |
| mat_idx(logits,j,0)=e; | |
| exp_sum+=e; | |
| } | |
| double prob = RANDOM(); | |
| double cum = 0; int next_idx = 0; | |
| for (int j=0;j<vsz;j++) { | |
| cum+= mat_idx(logits,j,0)/exp_sum; | |
| if (prob <= cum){ next_idx=j; break; } | |
| } | |
| putchar(idx2char[next_idx]); fflush(stdout); | |
| double x_next[256]={0}; | |
| one_hot(next_idx, vocab_sz, x_next); | |
| Matrix *next_tmp = mat_create(hidden_sz,1); | |
| for (int j=0;j<hidden_sz;j++) { | |
| double sum=0; | |
| for (int k=0;k<vocab_sz;k++) sum+= mat_idx(mod.W_i,j,k)*x_next[k]; | |
| for (int k=0;k<hidden_sz;k++) sum+= mat_idx(mod.W_h,j,k)*h_col->data[k]; | |
| mat_idx(next_tmp,j,0)=tanh_s(sum); | |
| } | |
| memcpy(h_col->data, next_tmp->data, hidden_sz*sizeof(double)); | |
| free(next_tmp); | |
| free(logits); | |
| } | |
| /* generate the remaining characters requested by the user */ | |
| for (int i=0;i<gen_len; ++i) { | |
| Matrix *logits = mat_matmul(mod.W_o, h_col); | |
| int vsz = mod.W_o->rows; | |
| double maxlog=-1e9; | |
| for (int j=0;j<vsz;j++) if (mat_idx(logits,j,0)>maxlog) maxlog=mat_idx(logits,j,0); | |
| double exp_sum=0; | |
| for (int j=0;j<vsz;j++) { | |
| double e = exp(mat_idx(logits,j,0)-maxlog); | |
| mat_idx(logits,j,0)=e; | |
| exp_sum+=e; | |
| } | |
| double prob = RANDOM(); | |
| double cum = 0; int next_idx = 0; | |
| for (int j=0;j<vsz;j++) { | |
| cum+= mat_idx(logits,j,0)/exp_sum; | |
| if (prob <= cum){ next_idx=j; break; } | |
| } | |
| putchar(idx2char[next_idx]); fflush(stdout); | |
| double x_next[256]={0}; | |
| one_hot(next_idx, vocab_sz, x_next); | |
| Matrix *next_tmp = mat_create(hidden_sz,1); | |
| for (int j=0;j<hidden_sz;j++) { | |
| double sum=0; | |
| for (int k=0;k<vocab_sz;k++) sum+= mat_idx(mod.W_i,j,k)*x_next[k]; | |
| for (int k=0;k<hidden_sz;k++) sum+= mat_idx(mod.W_h,j,k)*h_col->data[k]; | |
| mat_idx(next_tmp,j,0)=tanh_s(sum); | |
| } | |
| memcpy(h_col->data, next_tmp->data, hidden_sz*sizeof(double)); | |
| free(next_tmp); | |
| free(logits); | |
| } | |
| putchar('\n'); fflush(stdout); | |
| mat_free(W_i); mat_free(W_h); mat_free(W_o); | |
| mat_free(h_col); | |
| free(idx2char); free(char2idx); | |
| free(h); | |
| } | |
| /* ----------------------------------------------------------------- */ | |
| /* Main ------------------------------------------------------------- */ | |
| int main(int argc, char **argv){ | |
| if(argc<2){ | |
| fprintf(stderr, | |
| "Usage:\n" | |
| " Training : ./lstm <train_file> -e <num_epochs>\n" | |
| " Infer : ./lstm <file>.weights -l <num_chars> -p \"<prime_str>\"\n"); | |
| return 1; | |
| } | |
| char *train_path = NULL; | |
| int train_epochs = 0; | |
| char *weight_path = NULL; | |
| int gen_len = 0; /* mandatory for inference */ | |
| char prime_str[1024]=""; | |
| int checkpoint = 100; /* print every 100 epochs */ | |
| int argi = 1; | |
| while(argi<argc){ | |
| if(strcmp(argv[argi],"-e")==0 && argi+1<argc){ | |
| train_epochs = atoi(argv[argi+1]); argi+=2; | |
| }else if(strcmp(argv[argi],"-l")==0 && argi+1<argc){ | |
| gen_len = atoi(argv[argi+1]); argi+=2; /* mandatory for inference */ | |
| }else if(strcmp(argv[argi],"-p")==0 && argi+1<argc){ | |
| strncpy(prime_str, argv[argi+1], sizeof(prime_str)-1); argi+=2; | |
| }else{ | |
| if(train_path==NULL) train_path=argv[argi]; | |
| else weight_path=argv[argi]; | |
| ++argi; | |
| } | |
| } | |
| srand((unsigned)time(NULL)); | |
| /* ------------------- TRAINING ------------------- */ | |
| if(train_path && train_epochs>0){ | |
| /* ---- print the required banner ---- */ | |
| printf("Loading %s. Training for %d epochs.\n", train_path, train_epochs); | |
| fflush(stdout); | |
| /* ---- read training data ---- */ | |
| FILE *fp = fopen(train_path,"rb"); | |
| ASSERT(fp,"Cannot open training file"); | |
| fseek(fp,0,SEEK_END); | |
| long fsize = ftell(fp); | |
| fseek(fp,0,SEEK_SET); | |
| char *data = malloc(fsize+1); | |
| (void)fread(data,1,fsize,fp); /* ignore return value */ | |
| data[fsize]='\0'; | |
| fclose(fp); | |
| /* ---- build vocabulary ---- */ | |
| unsigned char seen[256]={0}; | |
| int vocab_sz=0; | |
| for(long i=0;i<fsize;i++) if(!seen[(unsigned char)data[i]]){ seen[(unsigned char)data[i]]=1; ++vocab_sz; } | |
| int *char2idx = calloc(256,sizeof(int)); | |
| for(int i=0;i<256;i++) char2idx[i]=255; /* sentinel */ | |
| for(int i=0;i<256;i++) if(seen[i]) char2idx[i]=seen[i]-1; /* 0 … vocab_sz‑1 */ | |
| int *idx2char = calloc(vocab_sz,sizeof(int)); | |
| int cur=0; | |
| for(int i=0;i<256;i++) if(seen[i]) idx2char[cur++]=i; | |
| /* ---- initialise model ---- */ | |
| const int hidden_sz = 256; | |
| LSTMModel mod; | |
| mod.W_i = mat_random(hidden_sz, vocab_sz); | |
| mod.W_h = mat_random(hidden_sz, hidden_sz); | |
| mod.W_o = mat_random(vocab_sz, hidden_sz); | |
| mod.hidden_sz = hidden_sz; | |
| double lr = 0.01; | |
| for(int epoch=1; epoch<=train_epochs; ++epoch){ | |
| double epoch_loss = train_one_epoch(&mod, data, char2idx, vocab_sz, lr); | |
| if(epoch % checkpoint == 0){ | |
| /* print loss on its own line */ | |
| printf("Epoch %4d loss: %8.4f\n", epoch, epoch_loss); | |
| fflush(stdout); | |
| /* generate a deterministic 128‑character sample */ | |
| int start_idx = 0; /* first character in vocab (index 0) */ | |
| print_sample(&mod, vocab_sz, hidden_sz, | |
| char2idx, idx2char, | |
| start_idx, 128); | |
| } | |
| } | |
| putchar('\n'); fflush(stdout); /* final newline after training */ | |
| /* ---- save weights ---- */ | |
| char weight_file[256]; | |
| snprintf(weight_file,sizeof(weight_file),"%s.weights",train_path); | |
| FILE *wfp = fopen(weight_file,"wb"); | |
| ASSERT(wfp,"Cannot write weight file"); | |
| (void)fwrite(&vocab_sz,sizeof(int),1,wfp); | |
| (void)fwrite(&hidden_sz,sizeof(int),1,wfp); | |
| for(int i=0;i<vocab_sz;i++){ | |
| int id = idx2char[i]; | |
| (void)fwrite(&id,sizeof(int),1,wfp); | |
| (void)fputc((char)id,wfp); | |
| } | |
| (void)fwrite(mod.W_i->data, sizeof(double), mod.W_i->rows*mod.W_i->cols, wfp); | |
| (void)fwrite(mod.W_h->data, sizeof(double), mod.W_h->rows*mod.W_h->cols, wfp); | |
| (void)fwrite(mod.W_o->data, sizeof(double), mod.W_o->rows*mod.W_o->cols, wfp); | |
| fclose(wfp); | |
| fprintf(stderr,"Training finished – weights written to `%s`\n",weight_file); | |
| /* cleanup */ | |
| mat_free(mod.W_i); mat_free(mod.W_h); mat_free(mod.W_o); | |
| free(char2idx); free(idx2char); free(data); | |
| } | |
| /* ------------------- INFERENCE ------------------- */ | |
| else if(weight_path && gen_len>0 && prime_str[0]){ | |
| generate(weight_path, prime_str, gen_len); | |
| } | |
| else{ | |
| fprintf(stderr,"Invalid arguments – training needs -e <num_epochs>, " | |
| "inference needs -l <num_chars> and -p \"<prime_str>\"\n"); | |
| return 1; | |
| } | |
| return 0; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment