Created
November 4, 2025 23:26
-
-
Save falseywinchnet/beac838b5446298abd240e5422ff57c8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class CayleyDicksonEmbedding(nn.Module):
    """Embedding built from a seed table plus repeated Cayley–Dickson-style doublings.

    A seed vector of width ``base_dim`` is looked up for every index and then
    doubled ``lifts`` times; each doubling combines the running vector with a
    vector from that lift's own learned table.  The result has width
    ``base_dim * 2**lifts`` and is L2-normalised along the last dimension.

    All lift tables are registered sub-modules, so they are trained, moved by
    ``.to(...)`` and saved in ``state_dict()`` together with the seed table.
    """

    def __init__(self, num_embeddings: int, base_dim: int = 1, lifts: int = 3):
        """
        num_embeddings : number of unique indices
        base_dim       : dimension of the seed embedding (usually 1)
        lifts          : number of Cayley–Dickson doublings
                         (output dim = base_dim * 2**lifts)
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.base_dim = base_dim
        self.lifts = lifts
        # base embedding table
        self.seed = nn.Embedding(num_embeddings, base_dim)
        # One table per lift, each the same size as the seed table.  They are
        # created once here (not inside forward) so that they are real
        # parameters: persistent across calls, visible to the optimizer, and
        # moved along with the module to whatever device it lives on.
        self.lift_tables = nn.ModuleList(
            [nn.Embedding(num_embeddings, base_dim) for _ in range(lifts)]
        )

    def cayley_dickson_pair(self, a, b):
        """
        Perform a single doubling step element-wise.

        Uses the product rule (a,b)·(c,d) -> (ac−bd, ad+bc) with c = a and
        d = b, i.e. it returns ``cat([a*a − b*b, 2*a*b], dim=-1)``.  The last
        dimension of the result is twice the (broadcast) last dimension of
        the inputs.
        """
        cross = a * b
        # cross + cross is ad + bc with c = a, d = b.
        return torch.cat([a * a - b * b, cross + cross], dim=-1)

    def forward(self, idx):
        """
        Recursively apply the doubling step to obtain a structured embedding.

        idx : integer index tensor of any shape ``(...)``
        returns : tensor of shape ``(..., base_dim * 2**lifts)`` with unit
                  L2 norm along the last dimension
        """
        x = self.seed(idx)  # (..., base_dim)
        for i, table in enumerate(self.lift_tables):
            y = table(idx)  # (..., base_dim)
            # After i lifts x has width base_dim * 2**i.  Tile y to the same
            # width; for base_dim == 1 this is exactly what broadcasting a
            # width-1 vector would do, and it keeps base_dim > 1 well-defined.
            reps = 2 ** i
            if reps > 1:
                y = torch.cat([y] * reps, dim=-1)
            x = self.cayley_dickson_pair(x, y)
        return F.normalize(x, dim=-1)
# Example usage: embed ten item indices with three doublings of a 1-d seed.
num_items = 10
cd_embed = CayleyDicksonEmbedding(num_embeddings=num_items, base_dim=1, lifts=3)
idx = torch.arange(num_items)
out = cd_embed(idx)
# Expected shape: torch.Size([10, 8]), since 1 * 2**3 = 8.
print(out.shape)
# Show the embeddings of the first three indices.
print(out[0:3])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment