Skip to content

Instantly share code, notes, and snippets.

@curiousluke93x
Last active February 13, 2026 02:24
Show Gist options
  • Select an option

  • Save curiousluke93x/0324ebf3ccc858b05c6775bd86b69364 to your computer and use it in GitHub Desktop.

Select an option

Save curiousluke93x/0324ebf3ccc858b05c6775bd86b69364 to your computer and use it in GitHub Desktop.
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<string.h>
#include<time.h>
/* Number of context positions handed to fwd() (size of the ctx array). */
#define B 8
/* Width of every embedding / attention vector. */
#define E 32
/* Width of the tanh hidden layer (Hd). */
#define H 64
/* Vocabulary size: token 0 is '.', tokens 1..26 are 'a'..'z' (see c2i/i2c). */
#define V 27
/* Base step size; main() divides it by the number of predictions in a name. */
#define LR 0.001f
/*
 * Parameters, all stored row-major in flat arrays:
 *   wte  V x E  token embedding table (row = token id)
 *   wpe  B x E  position embedding table (row = context slot)
 *   wq, wk, wv  E x E  query / key / value projections
 *   w1   E x H  input weights of the hidden layer
 *   w2   H x E  output weights of the hidden layer
 *   wh   E x V  projection to the output logits
 */
float wte[V*E],wpe[B*E],wq[E*E],wk[E*E],wv[E*E],w1[E*H],w2[H*E],wh[E*V];
/*
 * Activations written by fwd() and read back by bwd():
 *   X  embedded context, Q  query of the last slot, K / Va  keys / values,
 *   A  attention weights, Z  attention output plus last-slot input,
 *   Hd hidden activations, O  hidden-layer output plus Z, P  output probabilities.
 */
float X[B*E],Q[E],K[B*E],Va[B*E],A[B],Z[E],Hd[H],O[E],P[V];
/*
 * Gradient accumulators, one per trained matrix. There are none for wte and
 * wpe: the visible code never updates the embedding tables after they are
 * randomly initialised.
 */
float gwq[E*E],gwk[E*E],gwv[E*E],gw1[E*H],gw2[H*E],gwh[E*V];
/* D: array of N heap-allocated training names, filled in by main(). */
char**D;int N;
/*
 * Map a character to its token id.
 *
 *   'a'..'z' -> 1..26
 *   'A'..'Z' -> 1..26 (case is folded)
 *   anything else, including '.', -> 0 (the boundary token)
 *
 * The result is always in 0..26, so it is always a valid row index into the
 * embedding table. Previously any character other than '.' or a lower-case
 * letter produced a negative or too-large id and an out-of-bounds read.
 * Assumes the letters are contiguous in the execution character set (ASCII).
 */
int c2i(char c)
{
    if (c >= 'a' && c <= 'z')
        return c - 'a' + 1;
    if (c >= 'A' && c <= 'Z')
        return c - 'A' + 1;
    return 0;
}
/*
 * Inverse of the token encoding: id 0 is the boundary marker '.',
 * every other id k is the k-th lower-case letter counted from 'a'.
 */
char i2c(int i)
{
    if (i == 0)
        return '.';
    return 'a' + i - 1;
}
void fwd(int*ctx){
int i,j,l;float mx,s;
for(i=0;i<B;i++)for(j=0;j<E;j++)X[i*E+j]=wte[ctx[i]*E+j]+wpe[i*E+j];
for(j=0;j<E;j++){Q[j]=0;for(l=0;l<E;l++)Q[j]+=X[(B-1)*E+l]*wq[l*E+j];}
for(i=0;i<B;i++)for(j=0;j<E;j++){K[i*E+j]=Va[i*E+j]=0;
for(l=0;l<E;l++){K[i*E+j]+=X[i*E+l]*wk[l*E+j];Va[i*E+j]+=X[i*E+l]*wv[l*E+j];}}
mx=-1e9;s=0;
for(i=0;i<B;i++){float d=0;for(j=0;j<E;j++)d+=Q[j]*K[i*E+j];A[i]=d/sqrtf(E);if(A[i]>mx)mx=A[i];}
for(i=0;i<B;i++)s+=(A[i]=expf(A[i]-mx));
for(i=0;i<B;i++)A[i]/=s;
for(j=0;j<E;j++){Z[j]=X[(B-1)*E+j];for(i=0;i<B;i++)Z[j]+=A[i]*Va[i*E+j];}
for(i=0;i<H;i++){float v=0;for(j=0;j<E;j++)v+=Z[j]*w1[j*H+i];Hd[i]=tanhf(v);}
for(j=0;j<E;j++){O[j]=Z[j];for(i=0;i<H;i++)O[j]+=Hd[i]*w2[i*E+j];}
mx=-1e9;s=0;
for(i=0;i<V;i++){float v=0;for(j=0;j<E;j++)v+=O[j]*wh[j*V+i];P[i]=v;if(v>mx)mx=v;}
for(i=0;i<V;i++)s+=(P[i]=expf(P[i]-mx));
for(i=0;i<V;i++)P[i]/=s;
}
void bwd(int tgt){
int i,j,l;
float dL[V],dO[E],dH[H],dZ[E],dA[B],dQ[E],dK[B*E],dV[B*E];
for(i=0;i<V;i++)dL[i]=P[i]; dL[tgt]-=1;
for(j=0;j<E;j++){dO[j]=0;for(i=0;i<V;i++){dO[j]+=dL[i]*wh[j*V+i];gwh[j*V+i]+=O[j]*dL[i];}}
for(i=0;i<H;i++){dH[i]=0;for(j=0;j<E;j++){dH[i]+=dO[j]*w2[i*E+j];gw2[i*E+j]+=Hd[i]*dO[j];}
dH[i]*=(1-Hd[i]*Hd[i]);}
for(j=0;j<E;j++){dZ[j]=dO[j];for(i=0;i<H;i++){dZ[j]+=dH[i]*w1[j*H+i];gw1[j*H+i]+=Z[j]*dH[i];}}
float dd=0;
for(i=0;i<B;i++){dA[i]=0;for(j=0;j<E;j++){dV[i*E+j]=A[i]*dZ[j];dA[i]+=Va[i*E+j]*dZ[j];}dd+=dA[i]*A[i];}
for(i=0;i<B;i++)dA[i]=A[i]*(dA[i]-dd);
for(j=0;j<E;j++)dQ[j]=0;
for(i=0;i<B;i++){float sc=dA[i]/sqrtf(E);for(j=0;j<E;j++){dK[i*E+j]=Q[j]*sc;dQ[j]+=K[i*E+j]*sc;}}
for(l=0;l<E;l++){
for(j=0;j<E;j++)gwq[l*E+j]+=X[(B-1)*E+l]*dQ[j];
for(i=0;i<B;i++)for(j=0;j<E;j++){gwk[l*E+j]+=X[i*E+l]*dK[i*E+j];gwv[l*E+j]+=X[i*E+l]*dV[i*E+j];}
}
}
/* Draw one initial weight, uniform in [-0.05, 0.05]. */
static float rand_weight(void)
{
    return ((float)rand() / RAND_MAX - .5f) * .1f;
}

/*
 * Reduce s in place to what the model can encode: ASCII letters, folded to
 * lower case. Newlines, '\r', spaces, digits and punctuation are dropped.
 * Returns the length of the cleaned string.
 */
static int keep_letters(char *s)
{
    int n = 0;
    for (const char *p = s; *p; p++) {
        char c = *p;
        if (c >= 'A' && c <= 'Z')
            c = (char)(c - 'A' + 'a');
        if (c >= 'a' && c <= 'z')
            s[n++] = c;
    }
    s[n] = 0;
    return n;
}

/* Release the first `count` names and the name table itself. */
static void free_names(int count)
{
    for (int k = 0; k < count; k++)
        free(D[k]);
    free(D);
    D = NULL;
    N = 0;
}

/*
 * Train the model on names.txt, then generate names interactively.
 *
 * 1. Randomly initialise every weight matrix.
 * 2. Load names.txt (one name per line); each name is cleaned to lower-case
 *    letters and stored as ".name." so '.' marks both start and end.
 * 3. Run 500000 steps of SGD, one randomly chosen name per step, averaging
 *    the gradient over the predictions in that name.
 * 4. Read prefixes from stdin until EOF and sample up to 20 more characters
 *    for each, stopping early when the end token is drawn.
 *
 * Returns 0 on success, 1 if the data file is missing, empty or memory runs out.
 */
int main(void)
{
    srand((unsigned)time(0));

    for (int i = 0; i < V * E; i++) wte[i] = rand_weight();
    for (int i = 0; i < B * E; i++) wpe[i] = rand_weight();
    for (int i = 0; i < E * E; i++) {
        wq[i] = rand_weight();
        wk[i] = rand_weight();
        wv[i] = rand_weight();
    }
    for (int i = 0; i < E * H; i++) w1[i] = rand_weight();
    for (int i = 0; i < H * E; i++) w2[i] = rand_weight();
    for (int i = 0; i < E * V; i++) wh[i] = rand_weight();

    FILE *f = fopen("names.txt", "r");
    if (!f) {
        puts("no names.txt");
        return 1;
    }

    /*
     * Pass 1 counts usable names, pass 2 stores them. Both passes apply the
     * same keep_letters() filter, so the count and the number of stored
     * entries agree even when the file contains blank or junk lines.
     */
    char b[99];
    N = 0;
    while (fgets(b, sizeof b, f)) {
        if (keep_letters(b) > 0)
            N++;
    }
    if (N == 0) {
        fclose(f);
        puts("names.txt contains no usable names");
        return 1;
    }

    rewind(f);
    D = malloc((size_t)N * sizeof *D);
    if (!D) {
        fclose(f);
        puts("out of memory");
        return 1;
    }

    int di = 0;
    while (di < N && fgets(b, sizeof b, f)) {   /* di < N guards the table */
        int l = keep_letters(b);
        if (l == 0)
            continue;
        D[di] = malloc((size_t)l + 3);          /* '.' + name + '.' + NUL */
        if (!D[di]) {
            fclose(f);
            free_names(di);
            puts("out of memory");
            return 1;
        }
        snprintf(D[di], (size_t)l + 3, ".%s.", b);
        di++;
    }
    fclose(f);
    if (di == 0) {
        free_names(0);
        puts("names.txt contains no usable names");
        return 1;
    }
    N = di;

    float sl = -1;                              /* smoothed loss; <0 = unset */
    printf("Training on %d names...\n", N);
    for (int s = 0; s < 500000; s++) {
        char *nm = D[rand() % N];
        int len = (int)strlen(nm);
        float loss = 0;

        memset(gwq, 0, sizeof(gwq)); memset(gwk, 0, sizeof(gwk)); memset(gwv, 0, sizeof(gwv));
        memset(gw1, 0, sizeof(gw1)); memset(gw2, 0, sizeof(gw2)); memset(gwh, 0, sizeof(gwh));

        /* One prediction per character after the leading '.'. */
        for (int t = 0; t < len - 1; t++) {
            int ctx[B], tgt = c2i(nm[t + 1]);
            for (int i = 0; i < B; i++) {
                int p = t - B + 1 + i;          /* positions before the name pad with 0 */
                ctx[i] = p < 0 ? 0 : c2i(nm[p]);
            }
            fwd(ctx);
            loss += -logf(P[tgt] + 1e-8f);
            bwd(tgt);
        }

        /* Stored names are at least ".x.", so len - 1 >= 2 here. */
        float sc = LR / (len - 1);
        for (int i = 0; i < E * E; i++) {
            wq[i] -= sc * gwq[i];
            wk[i] -= sc * gwk[i];
            wv[i] -= sc * gwv[i];
        }
        for (int i = 0; i < E * H; i++) w1[i] -= sc * gw1[i];
        for (int i = 0; i < H * E; i++) w2[i] -= sc * gw2[i];
        for (int i = 0; i < E * V; i++) wh[i] -= sc * gwh[i];

        float l2 = loss / (len - 1);
        if (sl < 0) sl = l2; else sl = .999f * sl + .001f * l2;
        if (s % 5000 == 0) printf("Step %6d | Loss %.4f\n", s, sl);
    }

    printf("\nType prefix:\n");
    for (;;) {
        printf("> ");
        fflush(stdout);

        /*
         * scanf returns EOF (negative) at end of input, so test for exactly
         * one converted item; the width keeps the read inside the buffer.
         */
        char in[99];
        if (scanf("%98s", in) != 1)
            break;

        int l = keep_letters(in);
        int ctx[B];
        memset(ctx, 0, sizeof(ctx));
        /* Right-align the prefix; only its last B letters fit the context. */
        int first = l > B ? l - B : 0;
        for (int i = first; i < l; i++)
            ctx[B - l + i] = c2i(in[i]);

        printf(" %s", in);
        for (int n = 0; n < 20; n++) {
            fwd(ctx);
            /* Inverse-CDF sampling from P; token 0 ends the name. */
            float r = (float)rand() / RAND_MAX, c = 0;
            int nx = 0;
            for (int i = 0; i < V; i++) {
                c += P[i];
                if (r <= c) { nx = i; break; }
            }
            if (!nx) break;
            printf("%c", i2c(nx));
            for (int i = 0; i < B - 1; i++) ctx[i] = ctx[i + 1];
            ctx[B - 1] = nx;
        }
        printf("\n");
    }

    free_names(N);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment