import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset


class AttentionPool(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.attention = nn.Linear(embed_dim, 1)

    def forward(self, x):  # x: (batch, seq_len, embed_dim)
        scores = self.attention(x)          # (batch, seq_len, 1)
        weights = F.softmax(scores, dim=1)  # (batch, seq_len, 1)
        return (weights * x).sum(dim=1)     # (batch, embed_dim)
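# Quick sanity check (a minimal sketch; batch and sequence sizes are arbitrary):
pool = AttentionPool(32)
pooled = pool(torch.randn(8, 40, 32))
assert pooled.shape == (8, 32)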
class ReasonCodeTransformer(nn.Module):
    def __init__(self, fourier_dim=32):
        super().__init__()
        # Embeddings
        self.reason_embeddings = nn.Embedding(160, 32)      # 160 reason codes
        self.cls_embedding = nn.Parameter(torch.randn(32))  # Learnable CLS token
        self.score_projection = nn.Linear(fourier_dim, 32)  # Fourier → embed_dim
        # Transformer
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=32,
                nhead=4,  # 32/4 = 8 dims per head
                dim_feedforward=128,
                dropout=0.1,
            ),
            num_layers=3,
        )
        # Output heads
        self.attention_pool = AttentionPool(32)  # For non-CLS tokens
        self.masking_head = nn.Linear(32, 160)   # Reason-code prediction

    def forward(self, reason_codes, scores, mask):
        batch_size = reason_codes.shape[0]
        # Create embeddings
        reason_embeds = self.reason_embeddings(reason_codes)             # (batch, seq, 32)
        score_embeds = self.score_projection(fourier_embedding(scores))  # (batch, 32)
        cls_tokens = self.cls_embedding.expand(batch_size, 1, 32)        # (batch, 1, 32)
        # Concatenate: [CLS, reasons, score]
        sequence = torch.cat([cls_tokens, reason_embeds, score_embeds.unsqueeze(1)], dim=1)
        # Transformer (no positional encoding); PyTorch expects (seq, batch, dim)
        # and a key_padding_mask where True = ignore, hence the flip
        hidden = self.transformer(
            sequence.transpose(0, 1),
            src_key_padding_mask=~mask,
        ).transpose(0, 1)
        # Split outputs
        cls_output = hidden[:, 0]     # For downstream model
        other_tokens = hidden[:, 1:]  # For masking task
        # Masking prediction
        pooled = self.attention_pool(other_tokens)
        masking_logits = self.masking_head(pooled)
        return cls_output, masking_logits
def fourier_embedding(scores, embed_dim=32):
    """
    scores: (batch_size,) values between 0 and 1
    returns: (batch_size, embed_dim) Fourier features
    """
    # Create frequency range
    freqs = torch.arange(1, embed_dim // 2 + 1, dtype=torch.float32)  # [1, 2, ..., 16]
    # Scale input for better frequency coverage
    scaled_scores = scores.unsqueeze(-1) * freqs * 2 * math.pi  # (batch, 16)
    # Compute sin and cos
    sin_features = torch.sin(scaled_scores)  # (batch, 16)
    cos_features = torch.cos(scaled_scores)  # (batch, 16)
    # Concatenate to get 32 dims
    return torch.cat([sin_features, cos_features], dim=-1)  # (batch, 32)

# Alternative for forward(): use the Fourier features directly
score_embeds = fourier_embedding(scores, 32)  # (batch, 32) - no linear layer needed!
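# End-to-end usage (a minimal sketch; sizes are arbitrary, and the mask helper
# defined further below would normally build the (batch, 1 + seq + 1) mask):
model = ReasonCodeTransformer()
codes = torch.randint(1, 160, (8, 10))      # (batch, max_codes); 0 reserved for padding
scores = torch.rand(8)                      # (batch,) values in [0, 1]
mask = torch.ones(8, 12, dtype=torch.bool)  # CLS + 10 codes + score, all valid
cls_out, masking_logits = model(codes, scores, mask)  # (8, 32), (8, 160)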
def parse_reason_codes(code_strings, max_codes=40):
    """
    code_strings: list of strings like "5,12,89" or "1,45" or ""
    returns: padded tensor of indices
    """
    batch_codes = []
    batch_lengths = []
    for code_str in code_strings:
        if code_str.strip() == "":
            codes = []
        else:
            codes = [int(x.strip()) for x in code_str.split(',')]
        codes = codes[:max_codes]  # Truncate if too long
        batch_codes.append(codes)
        batch_lengths.append(len(codes))  # Length after truncation
    # Pad to same length (use 0 as padding; make sure 0 isn't a valid code)
    max_len = max(len(codes) for codes in batch_codes) if batch_codes else 0
    padded = torch.zeros(len(batch_codes), max_len, dtype=torch.long)
    for i, codes in enumerate(batch_codes):
        if codes:
            padded[i, :len(codes)] = torch.tensor(codes)
    return padded, torch.tensor(batch_lengths)

# Usage:
code_indices, lengths = parse_reason_codes(df['reason_codes'].tolist())
def create_attention_mask(lengths, max_len, has_cls=True, has_score=True):
    """
    lengths: actual number of reason codes per sample
    max_len: padded sequence length
    """
    batch_size = len(lengths)
    # Total sequence = CLS + reason_codes + score
    total_len = max_len + (1 if has_cls else 0) + (1 if has_score else 0)
    mask = torch.zeros(batch_size, total_len, dtype=torch.bool)
    for i, length in enumerate(lengths):
        # Mark valid positions as True
        start_idx = 1 if has_cls else 0
        mask[i, :start_idx] = True                    # CLS (if present) is always valid
        mask[i, start_idx:start_idx + length] = True  # Valid reason codes
        if has_score:
            mask[i, -1] = True                        # Score token is always valid
    return mask

# Usage:
attention_mask = create_attention_mask(lengths, code_indices.shape[1])

# PyTorch's transformer expects a key_padding_mask where True = ignore,
# so the valid-position mask is flipped (as wired into ReasonCodeTransformer.forward):
hidden = self.transformer(
    sequence.transpose(0, 1),
    src_key_padding_mask=~attention_mask,
).transpose(0, 1)
# Convert once, save tensors
def preprocess_dataset(df):
    codes, lengths = parse_reason_codes(df['reason_codes'].tolist())
    scores = torch.tensor(df['scores'].values, dtype=torch.float32)
    masks = create_attention_mask(lengths, codes.shape[1])
    return TensorDataset(codes, scores, masks, lengths)

dataset = preprocess_dataset(df)
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)
class ReasonCodeDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        codes = parse_single_code_string(row['reason_codes'])  # Per-row variant of parse_reason_codes (helper not shown)
        score = torch.tensor(row['score'], dtype=torch.float32)
        return codes, score

    def __len__(self):
        return len(self.df)

# A custom collate function handles padding per batch; see the sketch below.
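# A minimal collate_fn sketch (an assumption, not the original code): it presumes
# parse_single_code_string returns a 1-D LongTensor, pads with 0, and builds the
# attention mask per batch.
def collate_reason_codes(batch):
    codes_list, scores = zip(*batch)
    lengths = torch.tensor([len(c) for c in codes_list])
    max_len = int(lengths.max().item()) if len(lengths) > 0 else 0
    padded = torch.zeros(len(codes_list), max_len, dtype=torch.long)
    for i, c in enumerate(codes_list):
        padded[i, :len(c)] = c
    masks = create_attention_mask(lengths, max_len)
    return padded, torch.stack(scores), masks, lengths

loader = DataLoader(ReasonCodeDataset(df), batch_size=512, shuffle=True,
                    collate_fn=collate_reason_codes)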
# Save preprocessed tensors
torch.save({
    'codes': codes,
    'scores': scores,
    'masks': masks,
    'lengths': lengths,
}, 'preprocessed_data.pt')

# Load for training
data = torch.load('preprocessed_data.pt')
dataset = TensorDataset(data['codes'], data['scores'], data['masks'])
def masked_bce_loss(predictions, targets, valid_codes_mask):
    """
    predictions: (batch, 160) - sigmoid outputs
    targets: (batch, 160) - binary targets
    valid_codes_mask: (batch, 160) - True where the person actually has that code
    """
    # Only compute loss on codes the person actually has
    loss = F.binary_cross_entropy(predictions, targets, reduction='none')  # (batch, 160)
    masked_loss = loss * valid_codes_mask.float()  # Zero out invalid positions
    # Average only over valid positions
    return masked_loss.sum() / valid_codes_mask.sum()
# Create targets and mask from the code indices:
def create_targets_and_mask(code_indices, lengths, num_codes=160):
    batch_size = code_indices.shape[0]
    targets = torch.zeros(batch_size, num_codes)
    mask = torch.zeros(batch_size, num_codes, dtype=torch.bool)
    for i, length in enumerate(lengths):
        valid_codes = code_indices[i, :length]  # Remove padding
        targets[i, valid_codes] = 1.0           # Multi-hot encoding
        mask[i, valid_codes] = True             # Mark as valid for the loss
    return targets, mask
# Randomly mask some codes for semi-supervised learning
def apply_random_masking(code_indices, lengths, mask_prob=0.15):
    # Implementation to randomly hide some codes during the forward pass
    # while still computing loss on them
    pass
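# One possible implementation sketch (an assumption, not the original code):
# replace each valid code with the padding index 0 with probability mask_prob.
def apply_random_masking_sketch(code_indices, lengths, mask_prob=0.15):
    masked = code_indices.clone()
    for i, length in enumerate(lengths):
        drop = torch.rand(int(length)) < mask_prob  # Bernoulli(mask_prob) per valid code
        masked[i, :length][drop] = 0                # 0 = padding / "hidden"
    return masked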
def create_contrastive_pairs(batch_codes, batch_lengths, mask_prob=0.15):
    batch_size = len(batch_codes)
    # Positive pairs: same person, different masked views
    pos_view1 = apply_random_masking(batch_codes, batch_lengths, mask_prob)
    pos_view2 = apply_random_masking(batch_codes, batch_lengths, mask_prob)
    # Negative pairs: different people (shuffle the batch)
    neg_indices = torch.randperm(batch_size)
    neg_view = batch_codes[neg_indices]
    return pos_view1, pos_view2, neg_view
def infonce_loss(anchor, positive, negatives, temperature=0.1):
    # anchor, positive: (batch, embed_dim) - the CLS embeddings
    # negatives: (batch, embed_dim)
    batch_size = anchor.shape[0]
    # Cosine similarities
    pos_sim = F.cosine_similarity(anchor, positive, dim=-1) / temperature  # (batch,)
    neg_sim = F.cosine_similarity(anchor.unsqueeze(1), negatives.unsqueeze(0), dim=-1) / temperature  # (batch, batch)
    # InfoNCE: -log(exp(pos) / (exp(pos) + sum(exp(neg))))
    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)  # (batch, batch+1)
    labels = torch.zeros(batch_size, dtype=torch.long)  # First column is the positive
    return F.cross_entropy(logits, labels)
# Multi-task loss (masking_head emits logits, so apply sigmoid before the BCE)
recon_loss = masked_bce_loss(torch.sigmoid(masking_logits), targets, mask)
contrastive_loss = infonce_loss(cls_anchor, cls_positive, cls_negatives)
total_loss = recon_loss + lambda_contrastive * contrastive_loss
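# A rough training-step sketch under the assumptions above (lambda_contrastive,
# the optimizer, and the masking/target plumbing are illustrative, not the
# original training loop):
lambda_contrastive = 0.5
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for codes, scores, masks, lengths in dataloader:
    targets, loss_mask = create_targets_and_mask(codes, lengths)
    view1 = apply_random_masking_sketch(codes, lengths)
    view2 = apply_random_masking_sketch(codes, lengths)
    cls1, logits1 = model(view1, scores, masks)
    cls2, _ = model(view2, scores, masks)
    neg = cls1[torch.randperm(cls1.shape[0])]  # Shuffled batch as negatives
    loss = (masked_bce_loss(torch.sigmoid(logits1), targets, loss_mask)
            + lambda_contrastive * infonce_loss(cls1, cls2, neg))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()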
class FeatureEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        # Each feature type gets its own pathway to 32-dim space
        self.reason_code_net = nn.Sequential(
            nn.Embedding(160, 64),
            nn.Linear(64, 32),
            nn.Tanh(),
        )
        self.score_net = nn.Sequential(
            nn.Linear(32, 64),  # From Fourier features
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.Tanh(),
        )
        self.age_net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.Tanh(),
        )
        self.income_net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.Tanh(),
        )
        # Add a learnable "missing" embedding for each feature type
        self.age_missing = nn.Parameter(torch.randn(32))
        self.income_missing = nn.Parameter(torch.randn(32))

    def forward(self, age, age_mask):
        age_embed = self.age_net(age)  # Normal embedding
        # Swap in the learnable "missing" vector wherever the mask says absent
        return torch.where(age_mask.unsqueeze(-1), age_embed, self.age_missing)

# Alternative 1: replace missing values with 0 and feed a missing flag as an extra input
def preprocess_with_missing(self, value, is_missing):
    clean_value = torch.where(is_missing, torch.zeros_like(value), value)
    missing_flag = is_missing.float()  # 0 or 1
    input_with_flag = torch.cat([clean_value.unsqueeze(-1), missing_flag.unsqueeze(-1)], dim=-1)
    return self.feature_net(input_with_flag)  # feature_net now takes 2 inputs instead of 1

# Alternative 2: a dedicated network produces the "missing" embedding
def forward(self, value, is_missing):
    present_embed = self.feature_net(value)
    missing_embed = self.missing_net(torch.ones_like(value))  # Constant input
    return torch.where(is_missing.unsqueeze(-1), missing_embed, present_embed)
def bjorck_orthogonalize(W, num_iters=3):
    # Björck iteration: W ← (3W - W Wᵀ W) / 2 converges to the nearest
    # orthogonal matrix when the singular values of W are near 1
    for _ in range(num_iters):
        W = (3 * W - W @ W.T @ W) / 2
    return W

class BJorckLinear(nn.Linear):  # Subclass nn.Linear so weight and bias exist
    def forward(self, x):
        W_ortho = bjorck_orthogonalize(self.weight)
        return F.linear(x, W_ortho, self.bias)
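# Usage sketch (illustrative): the weight applied at forward time is
# approximately orthogonal, so WᵀW should be close to the identity.
layer = BJorckLinear(32, 32, bias=False)
W = bjorck_orthogonalize(layer.weight.detach())
print(torch.dist(W.T @ W, torch.eye(32)))  # Small once the iteration has converged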
class UnifiedContinuousEmbedder(nn.Module):
    def __init__(self, fourier_dim=32, context_dim=16, output_dim=32):
        super().__init__()
        # Context embeddings for different feature types
        self.feature_type_embedding = nn.Embedding(10, context_dim)  # Up to 10 continuous features
        # Shared network processes [fourier_features + context]
        self.shared_net = nn.Sequential(
            nn.Linear(fourier_dim + context_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Tanh(),
        )
        # Feature type IDs
        self.SCORE_ID = 0
        self.COUNT_ID = 1
        self.AGE_ID = 2
        self.INCOME_ID = 3

    def forward(self, values, feature_types):
        # values: continuous values to embed
        # feature_types: tensor of feature type IDs
        fourier_features = fourier_embedding(values)                 # (batch, fourier_dim)
        context_embeds = self.feature_type_embedding(feature_types)  # (batch, context_dim)
        # Concatenate and process
        combined = torch.cat([fourier_features, context_embeds], dim=-1)
        return self.shared_net(combined)

# Usage:
embedder = UnifiedContinuousEmbedder()
score_embeds = embedder(scores, torch.full_like(scores, embedder.SCORE_ID, dtype=torch.long))
count_embeds = embedder(counts, torch.full_like(counts, embedder.COUNT_ID, dtype=torch.long))
class UniversalFeatureEmbedder(nn.Module):
    def __init__(self, fourier_dim=16, categorical_dim=16, context_dim=8, output_dim=32):
        super().__init__()
        self.fourier_dim = fourier_dim
        # Feature type context
        self.feature_type_embedding = nn.Embedding(10, context_dim)
        # Shared processor
        self.shared_net = nn.Sequential(
            nn.Linear(fourier_dim + categorical_dim + context_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Tanh(),
        )
        # Raw embeddings
        self.reason_code_embedding = nn.Embedding(160, categorical_dim)
        # Universal embeddings
        self.universal_categorical = nn.Parameter(torch.zeros(categorical_dim))  # For continuous features
        self.universal_numerical = nn.Parameter(torch.zeros(fourier_dim))        # For categorical features

    def forward(self, values, feature_types, is_categorical_mask):
        batch_size = values.shape[0]
        # Fourier processing
        fourier_features = torch.where(
            is_categorical_mask.unsqueeze(-1),
            self.universal_numerical.expand(batch_size, -1),  # Categorical gets the universal vector
            fourier_embedding(values, self.fourier_dim),      # Continuous gets Fourier features
        )
        # Categorical processing (torch.where evaluates both branches, so indices
        # are clamped to keep continuous values inside the embedding range)
        categorical_features = torch.where(
            is_categorical_mask.unsqueeze(-1),
            self.reason_code_embedding(values.long().clamp(0, 159)),  # Categorical gets its embedding
            self.universal_categorical.expand(batch_size, -1),        # Continuous gets the universal vector
        )
        # Context
        context_embeds = self.feature_type_embedding(feature_types)
        # Process through the shared network
        combined = torch.cat([fourier_features, categorical_features, context_embeds], dim=-1)
        return self.shared_net(combined)
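# Usage sketch (illustrative values; a mixed batch of one score and one code):
universal = UniversalFeatureEmbedder()
values = torch.tensor([0.73, 42.0])           # [continuous score, reason code 42]
feature_types = torch.tensor([0, 1])          # Per-feature type IDs
is_categorical = torch.tensor([False, True])
embeds = universal(values, feature_types, is_categorical)  # (2, 32)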
class HybridFeatureProcessor(nn.Module):
    def __init__(self, fourier_dim=16, categorical_dim=16, context_dim=8, output_dim=32):
        super().__init__()
        self.fourier_dim = fourier_dim
        self.categorical_dim = categorical_dim
        # Processing pathways
        self.fourier_processor = nn.Linear(fourier_dim, 24)
        self.categorical_processor = nn.Embedding(1000, categorical_dim)  # For special codes
        self.context_processor = nn.Embedding(10, context_dim)
        # Unified output
        self.output_net = nn.Sequential(
            nn.Linear(24 + categorical_dim + context_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Tanh(),
        )
        # Define special codes per feature type
        self.special_codes = {
            'score': [999, 999999, -1],
            'age': [999, -1],
            'income': [999999, -999, 0],
        }

    def forward(self, values, feature_type_name, feature_type_id):
        special_codes = self.special_codes[feature_type_name]
        is_special = torch.isin(values, torch.tensor(special_codes, dtype=values.dtype))
        # Route to the appropriate processor (torch.where evaluates both branches,
        # so indices are clamped; a full implementation would map special codes
        # like 999999 or -1 onto dedicated embedding slots)
        fourier_out = torch.where(
            is_special.unsqueeze(-1),
            torch.zeros(values.shape[0], 24),  # Special codes get zeros here
            self.fourier_processor(fourier_embedding(values, self.fourier_dim)),
        )
        categorical_out = torch.where(
            is_special.unsqueeze(-1),
            self.categorical_processor(values.long().clamp(0, 999)),  # Special codes get embeddings
            torch.zeros(values.shape[0], self.categorical_dim),
        )
        context_out = self.context_processor(feature_type_id)
        return self.output_net(torch.cat([fourier_out, categorical_out, context_out], dim=-1))
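# Usage sketch (illustrative; 999 is one of the 'age' special codes above,
# and the feature-type ID is a hypothetical choice):
processor = HybridFeatureProcessor()
ages = torch.tensor([0.34, 999.0])             # One normal value, one special code
type_ids = torch.tensor([2, 2])                # Hypothetical 'age' feature-type ID
age_embeds = processor(ages, 'age', type_ids)  # (2, 32)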