Skip to content

Instantly share code, notes, and snippets.

@falseywinchnet
Created November 4, 2025 23:26
Show Gist options
  • Select an option

  • Save falseywinchnet/beac838b5446298abd240e5422ff57c8 to your computer and use it in GitHub Desktop.

Select an option

Save falseywinchnet/beac838b5446298abd240e5422ff57c8 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
class CayleyDicksonEmbedding(nn.Module):
    """Embedding built by repeatedly "doubling" a seed vector.

    Each index is first looked up in a seed table of width ``base_dim``.
    Each of the ``lifts`` subsequent steps looks the same index up in a
    further learned table whose width matches the current vector, then
    combines the two halves with :meth:`cayley_dickson_pair`, doubling the
    width. The result is L2-normalised along the last dimension, so the
    output width is ``base_dim * 2**lifts``.
    """

    def __init__(self, num_embeddings: int, base_dim: int = 1, lifts: int = 3):
        """
        num_embeddings : number of unique indices
        base_dim : dimension of the seed embedding (usually 1)
        lifts : number of Cayley–Dickson doublings (output dim = base_dim * 2**lifts)

        Raises ValueError if base_dim < 1 or lifts < 0.
        """
        super().__init__()
        if base_dim < 1:
            raise ValueError(f"base_dim must be >= 1, got {base_dim}")
        if lifts < 0:
            raise ValueError(f"lifts must be >= 0, got {lifts}")
        self.num_embeddings = num_embeddings
        self.base_dim = base_dim
        self.lifts = lifts
        # Width of the vectors returned by forward().
        self.embedding_dim = base_dim * 2 ** lifts
        # base embedding table
        self.seed = nn.Embedding(num_embeddings, base_dim)
        # One learned table per lift, registered as a submodule so its weights
        # are trained, saved in state_dict and moved by .to()/.cuda(). Creating
        # tables inside forward() would give fresh random, untrainable weights
        # on every call. Lift i pairs a vector of width base_dim * 2**i with a
        # partner of the same width; a base_dim-wide partner would only
        # broadcast when base_dim == 1.
        self.lift_tables = nn.ModuleList(
            [nn.Embedding(num_embeddings, base_dim * 2 ** i) for i in range(lifts)]
        )

    def cayley_dickson_pair(self, a, b):
        """
        Combine two equal-width halves into one vector of twice the width.

        Element-wise this evaluates the pairing rule (a,b)·(c,d) -> (ac−bd, ad+bc)
        with (c,d) = (a,b). That is the self-product (a² − b², 2ab), i.e. the
        square of a + b·i taken per component. It does not multiply two
        distinct pairs and applies no conjugation.

        a, b : tensors of the same (or broadcastable) shape
        returns : torch.cat([a*a - b*b, a*b + b*a], dim=-1)
        """
        ac = a * a
        bd = b * b
        ad = a * b
        bc = b * a
        return torch.cat([ac - bd, ad + bc], dim=-1)

    def forward(self, idx):
        """
        Look up ``idx`` and apply every registered Cayley–Dickson lift.

        idx : integer index tensor accepted by nn.Embedding
        returns : tensor of shape (*idx.shape, base_dim * 2**lifts), L2-normalised
                  along the last dimension
        """
        x = self.seed(idx)  # (*idx.shape, base_dim)
        for table in self.lift_tables:
            # width doubles each step: base_dim * 2**i -> base_dim * 2**(i+1)
            x = self.cayley_dickson_pair(x, table(idx))
        return F.normalize(x, dim=-1)
# Example usage: embed every index 0..num_items-1 and inspect the result.
num_items = 10
cd_embed = CayleyDicksonEmbedding(
    num_embeddings=num_items,
    base_dim=1,
    lifts=3,
)
idx = torch.arange(0, num_items)
out = cd_embed(idx)
print(out.size())  # torch.Size([10, 8]): base_dim * 2**lifts = 1 * 2**3
print(out[0:3])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment