import torch

# two leaf tensors with identical values — presumably so that two equivalent
# computation paths can be compared value-for-value and gradient-for-gradient
x1 = torch.randn(5)
x2 = x1.clone()
x1.requires_grad = True
x2.requires_grad = True
# arbitrary scalar coefficients
alpha = 1 - 0.5   # 0.5
beta = 1 + 0.53   # 1.53
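A minimal sketch of what a setup like this is typically for (my assumption; the gist may use these tensors differently): scaling before or after a matmul is algebraically identical, so both paths should produce matching outputs and matching gradients — the same associativity that lets MLA absorb one projection matrix into another at inference time. W here is a hypothetical weight, not taken from the gist.

W = torch.randn(5, 5)          # hypothetical weight, not from the gist
y1 = (alpha * x1) @ W          # scale the input first
y2 = x2 @ (alpha * W)          # absorb the scale into the weight
y1.sum().backward()
y2.sum().backward()
assert torch.allclose(y1, y2)
assert torch.allclose(x1.grad, x2.grad)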
joey00072 / mla.py
Created December 28, 2024 16:25
multi-head latent attention (MLA)
# https://x.com/shxf0072/status/1873038335427658011
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from collections import OrderedDict
from ohara.modules.norm import RMSNorm  # RMSNorm from the author's ohara library (github.com/joey00072/ohara)
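# Rough idea of MLA (my summary, not part of the original gist): instead of
# caching full per-head keys and values, each token is projected down to a
# small shared latent (d_model -> kv_latent_dim), and the per-head K/V are
# re-expanded from that latent at attention time. This shrinks the KV cache
# by roughly a factor of (n_heads * head_dim) / kv_latent_dim.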