Created June 19, 2025 14:12
-
-
Save btcross26/92d27b13554ae16cb3d33091d9379fbd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AttentionPool(nn.Module):
    """Pool a sequence of token embeddings into a single vector with learned attention.

    A linear layer scores every token, the scores are softmax-normalised over the
    sequence axis, and the tokens are summed with those weights.
    """

    def __init__(self, embed_dim):
        # nn.Module.__init__ must run before any submodule is assigned; without it,
        # registering ``self.attention`` raises AttributeError.
        super().__init__()
        self.attention = nn.Linear(embed_dim, 1)

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over its sequence axis.

        Args:
            x: (batch, seq_len, embed_dim) token embeddings.
            mask: optional (batch, seq_len) bool tensor, True for tokens that may be
                pooled. Excluded tokens receive zero weight. Every row must keep at
                least one True entry, otherwise its softmax is undefined (NaN).

        Returns:
            (batch, embed_dim) pooled embeddings.
        """
        scores = self.attention(x)  # (batch, seq_len, 1)
        if mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            scores = scores.masked_fill(~mask.unsqueeze(-1), float("-inf"))
        weights = F.softmax(scores, dim=1)  # (batch, seq_len, 1)
        return (weights * x).sum(dim=1)  # (batch, embed_dim)
class ReasonCodeTransformer(nn.Module):
    """Transformer encoder over a ``[CLS, reason codes..., score]`` token sequence.

    The CLS output is intended for a downstream model. An attention-pooled summary of the
    remaining tokens feeds a 160-way head used for the masked reason-code task. No
    positional encoding is added.
    """

    def __init__(self, fourier_dim=32):
        """
        Args:
            fourier_dim: width of the score's Fourier features; 32 is the default
                ``embed_dim`` of fourier_embedding.
        """
        super().__init__()
        self.fourier_dim = fourier_dim
        # Embeddings
        self.reason_embeddings = nn.Embedding(160, 32)  # 160 reason codes
        self.cls_embedding = nn.Parameter(torch.randn(32))  # learnable CLS token
        self.score_projection = nn.Linear(fourier_dim, 32)  # Fourier features -> d_model
        # Transformer (batch_first left at its default False, hence the transposes in forward)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=32,
                nhead=4,  # 32 / 4 = 8 dims per head
                dim_feedforward=128,
                dropout=0.1,
            ),
            num_layers=3,
        )
        # Output heads
        self.attention_pool = AttentionPool(32)  # pools the non-CLS tokens
        self.masking_head = nn.Linear(32, 160)  # reason-code prediction logits

    def forward(self, reason_codes, scores, mask=None):
        """Encode a batch and return ``(cls_output, masking_logits)``.

        Args:
            reason_codes: (batch, seq) long tensor of padded reason-code indices.
            scores: (batch,) score values, Fourier-encoded then projected to 32 dims.
            mask: optional (batch, seq + 2) bool tensor, True for real tokens, laid out as
                [CLS, codes..., score] (the layout create_attention_mask produces). When
                given, padded positions are excluded from attention and from pooling.

        Returns:
            cls_output: (batch, 32) CLS embeddings.
            masking_logits: (batch, 160) unnormalised reason-code logits.
        """
        batch_size = reason_codes.shape[0]
        # Create embeddings
        reason_embeds = self.reason_embeddings(reason_codes)  # (batch, seq, 32)
        score_embeds = self.score_projection(
            fourier_embedding(scores, self.fourier_dim)
        )  # (batch, 32)
        cls_tokens = self.cls_embedding.expand(batch_size, 1, 32)  # (batch, 1, 32)
        # Concatenate: [CLS, reasons, score]
        sequence = torch.cat([cls_tokens, reason_embeds, score_embeds.unsqueeze(1)], dim=1)

        # src_key_padding_mask uses True = "ignore", the inverse of ``mask``.
        key_padding_mask = None if mask is None else ~mask
        hidden = self.transformer(
            sequence.transpose(0, 1), src_key_padding_mask=key_padding_mask
        ).transpose(0, 1)

        cls_output = hidden[:, 0]  # for the downstream model
        other_tokens = hidden[:, 1:]  # reason codes + score, for the masking task

        if mask is None:
            pooled = self.attention_pool(other_tokens)
        else:
            # Masked pooling written inline so it does not depend on AttentionPool's
            # forward signature. With create_attention_mask output the score token is
            # always valid, so every row keeps at least one finite pooling score.
            pool_scores = self.attention_pool.attention(other_tokens)  # (batch, seq + 1, 1)
            pool_scores = pool_scores.masked_fill(
                ~mask[:, 1:].unsqueeze(-1), float("-inf")
            )
            pooled = (F.softmax(pool_scores, dim=1) * other_tokens).sum(dim=1)

        masking_logits = self.masking_head(pooled)
        return cls_output, masking_logits
def fourier_embedding(scores, embed_dim=32):
    """Encode scalar values as sine/cosine Fourier features.

    A value ``s`` becomes ``[sin(2*pi*k*s) for k in 1..K] + [cos(2*pi*k*s) for k in 1..K]``
    with ``K = embed_dim // 2``. The frequencies are integers, so the encoding has period 1;
    inputs are expected to lie in [0, 1] (values differing by an integer are indistinguishable).

    Args:
        scores: tensor of shape (...,), typically (batch_size,), with values in [0, 1].
        embed_dim: output width; must be a positive even integer.

    Returns:
        Tensor of shape (..., embed_dim) on the same device as ``scores``
        (sin features first, then cos features).

    Raises:
        ValueError: if ``embed_dim`` is not a positive even integer (an odd width would
            otherwise silently yield ``embed_dim - 1`` columns).
    """
    if embed_dim <= 0 or embed_dim % 2:
        raise ValueError(f"embed_dim must be a positive even integer, got {embed_dim}")
    # Create frequencies on the input's device so GPU inputs do not hit a device mismatch.
    freqs = torch.arange(1, embed_dim // 2 + 1, dtype=torch.float32, device=scores.device)
    scaled_scores = scores.unsqueeze(-1) * freqs * 2 * math.pi  # (..., embed_dim // 2)
    return torch.cat([torch.sin(scaled_scores), torch.cos(scaled_scores)], dim=-1)
# Inside forward(): the Fourier encoding already has the 32-dim model width, so it can serve
# directly as the score token without a projection layer.
# NOTE(review): `scores` is not defined at module level; this line belongs in a forward() body.
score_embeds = fourier_embedding(scores, embed_dim=32)  # (batch, 32)
def parse_reason_codes(code_strings, max_codes=40):
    """Parse comma-separated reason-code strings into a zero-padded index tensor.

    Args:
        code_strings: iterable of strings like "5,12,89", "1,45" or "". ``None`` and NaN
            entries (e.g. empty cells read by pandas) are treated as "". Empty tokens, such
            as the trailing one in "5,12,", are ignored.
        max_codes: keep at most this many codes per row; extra codes are dropped.

    Returns:
        (padded, lengths):
            padded: (batch, max_len) long tensor padded with 0, so 0 must not be a real code.
            lengths: (batch,) long tensor with the number of codes actually kept per row
                (never larger than ``padded.shape[1]``).
    """
    batch_codes = []
    for code_str in code_strings:
        if code_str is None or (isinstance(code_str, float) and math.isnan(code_str)):
            code_str = ""
        # int() tolerates surrounding whitespace; blank tokens are skipped.
        codes = [int(token) for token in code_str.split(",") if token.strip()]
        # Truncate BEFORE measuring, so lengths always agree with the padded rows.
        batch_codes.append(codes[:max_codes])

    lengths = [len(codes) for codes in batch_codes]
    max_len = max(lengths, default=0)
    padded = torch.zeros(len(batch_codes), max_len, dtype=torch.long)
    for i, codes in enumerate(batch_codes):
        if codes:
            padded[i, : len(codes)] = torch.tensor(codes, dtype=torch.long)
    # Explicit dtype: torch.tensor([]) would otherwise default to float32 for an empty batch.
    return padded, torch.tensor(lengths, dtype=torch.long)
# Usage: turn the DataFrame's 'reason_codes' column into padded index rows plus per-row counts.
code_indices, lengths = parse_reason_codes(
    code_strings=df['reason_codes'].tolist(),
)
def create_attention_mask(lengths, max_len, has_cls=True, has_score=True):
    """Build a validity mask for ``[CLS] + reason codes + [score]`` sequences.

    True marks a real position and False marks padding. Pass ``~mask`` as
    ``nn.TransformerEncoder``'s ``src_key_padding_mask``, where True means "ignore".

    Args:
        lengths: per-row number of real reason codes (sequence of ints or 1-D tensor).
            Values larger than ``max_len`` are clamped to ``max_len``.
        max_len: padded width of the reason-code section.
        has_cls: whether a CLS token is prepended; it is always valid.
        has_score: whether a score token is appended; it is always valid.

    Returns:
        Bool tensor of shape (len(lengths), max_len + has_cls + has_score), on the same
        device as ``lengths`` when it is a tensor.
    """
    lengths = torch.as_tensor(lengths, dtype=torch.long).reshape(-1)
    batch_size = lengths.shape[0]
    # Position j of row i is a real code iff j < lengths[i].
    positions = torch.arange(max_len, device=lengths.device)
    code_valid = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_len)
    always_valid = torch.ones(batch_size, 1, dtype=torch.bool, device=lengths.device)

    parts = []
    if has_cls:
        parts.append(always_valid)
    parts.append(code_valid)
    # The score column exists (and is marked valid) only when has_score is set; the
    # previous code flipped the last column on unconditionally, marking padding valid.
    if has_score:
        parts.append(always_valid)
    return torch.cat(parts, dim=1)
# Usage: build the validity mask (True = real token) for [CLS] + codes + [score].
# NOTE(review): `padded_codes` is not defined anywhere in this file; it presumably means the
# padded code tensor returned by parse_reason_codes (bound to `code_indices` in the earlier
# usage example) — confirm before running.
attention_mask = create_attention_mask(lengths, max_len=padded_codes.shape[1])
# The call below is meant to live inside a module's forward(), where `self` and `sequence` exist.
# PyTorch's src_key_padding_mask treats True as "ignore", the opposite convention of
# create_attention_mask, so the mask is inverted.
hidden = self.transformer(
    sequence.transpose(1, 0),
    src_key_padding_mask=torch.logical_not(attention_mask),
).transpose(1, 0)
# Convert once, save tensors
def preprocess_dataset(df):
    """Parse a DataFrame once into tensors and wrap them in a TensorDataset.

    Reads the 'reason_codes' column (comma-separated code strings) and the 'scores' column.
    Each dataset item is ``(codes, scores, masks, lengths)`` for one row.

    NOTE(review): ReasonCodeDataset reads a column named 'score' (singular); confirm which
    column name the real data uses.
    """
    code_tensor, code_lengths = parse_reason_codes(df['reason_codes'])
    score_tensor = torch.tensor(df['scores'].values, dtype=torch.float32)
    padded_width = code_tensor.shape[1]
    mask_tensor = create_attention_mask(code_lengths, padded_width)
    return TensorDataset(code_tensor, score_tensor, mask_tensor, code_lengths)
# Build the tensors once, then serve shuffled mini-batches of 512 rows.
dataset = preprocess_dataset(df)
dataloader = DataLoader(dataset, shuffle=True, batch_size=512)
class ReasonCodeDataset(Dataset):
    """Lazily parse one DataFrame row per item into ``(codes, score)`` tensors.

    Expects columns 'reason_codes' (comma-separated code strings) and 'score'. Rows carry
    different numbers of codes, so batch with ``collate_fn``; the default collate cannot
    stack variable-length tensors. Unlike parse_reason_codes, no ``max_codes`` truncation
    is applied here.
    """

    def __init__(self, df):
        self.df = df

    @staticmethod
    def _parse_code_string(code_str):
        """Parse one comma-separated string into a 1-D long tensor.

        Empty strings, ``None`` and NaN give an empty tensor; blank tokens are ignored.
        """
        if code_str is None or (isinstance(code_str, float) and math.isnan(code_str)):
            return torch.zeros(0, dtype=torch.long)
        tokens = str(code_str).split(",")
        return torch.tensor([int(tok) for tok in tokens if tok.strip()], dtype=torch.long)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Previously called an undefined `parse_single_code_string` (NameError).
        codes = self._parse_code_string(row['reason_codes'])
        score = torch.tensor(row['score'], dtype=torch.float32)
        return codes, score

    def __len__(self):
        return len(self.df)

    # Custom collate function handles padding per batch
    @staticmethod
    def collate_fn(batch):
        """Pad a list of ``(codes, score)`` items into batch tensors.

        Use as ``DataLoader(dataset, collate_fn=ReasonCodeDataset.collate_fn)``.

        Returns:
            (codes, scores, lengths): ``codes`` is (batch, max_len) long, zero-padded (the
            same padding value parse_reason_codes uses); ``scores`` is (batch,) float;
            ``lengths`` is (batch,) long with each row's real code count.
        """
        codes_list = [codes for codes, _ in batch]
        lengths = torch.tensor([codes.numel() for codes in codes_list], dtype=torch.long)
        max_len = max((codes.numel() for codes in codes_list), default=0)
        padded = torch.zeros(len(codes_list), max_len, dtype=torch.long)
        for i, codes in enumerate(codes_list):
            padded[i, : codes.numel()] = codes
        scores = (
            torch.stack([score for _, score in batch])
            if batch
            else torch.zeros(0, dtype=torch.float32)
        )
        return padded, scores, lengths
# Save preprocessed tensors. They are the four tensors held by the TensorDataset returned by
# preprocess_dataset(df) (codes, scores, masks, lengths — in that order); the bare names were
# previously only locals inside preprocess_dataset.
codes, scores, masks, lengths = dataset.tensors
torch.save({
    'codes': codes,
    'scores': scores,
    'masks': masks,
    'lengths': lengths,
}, 'preprocessed_data.pt')
# Load for training. Keep `lengths` so items have the same four fields as the dataset built
# by preprocess_dataset (create_targets_and_mask needs the per-row lengths).
data = torch.load('preprocessed_data.pt')
dataset = TensorDataset(data['codes'], data['scores'], data['masks'], data['lengths'])
def masked_bce_loss(predictions, targets, valid_codes_mask, from_logits=False):
    """Binary cross-entropy averaged over the positions selected by ``valid_codes_mask``.

    Args:
        predictions: (batch, num_codes) probabilities in [0, 1] (sigmoid outputs), or raw
            logits when ``from_logits=True``.
        targets: (batch, num_codes) binary targets.
        valid_codes_mask: (batch, num_codes) bool, True where the loss is counted.
        from_logits: treat ``predictions`` as logits and use the numerically stable
            ``binary_cross_entropy_with_logits`` (e.g. for ReasonCodeTransformer's
            ``masking_logits``).

    Returns:
        Scalar tensor: mean loss over the selected positions, or 0 when nothing is selected
        (instead of NaN from 0/0).

    NOTE(review): with masks from create_targets_and_mask, the mask marks exactly the
    positions whose target is 1, so this loss never penalises predicting codes a person
    does not have. Consider also counting negative positions.
    """
    if from_logits:
        loss = F.binary_cross_entropy_with_logits(predictions, targets, reduction='none')
    else:
        loss = F.binary_cross_entropy(predictions, targets, reduction='none')
    weights = valid_codes_mask.to(loss.dtype)  # 1.0 where counted, 0.0 elsewhere
    # clamp_min(1) avoids 0/0 = NaN for a batch with no selected positions.
    return (loss * weights).sum() / weights.sum().clamp_min(1.0)
# Create targets and mask from your code indices:
def create_targets_and_mask(code_indices, lengths, num_codes=160):
    """Turn padded code indices into multi-hot targets plus a matching bool mask.

    Row ``i`` uses only the first ``lengths[i]`` entries of ``code_indices``; the rest is
    padding. Both outputs have shape (batch, num_codes). ``targets`` is float with 1.0 at
    every present code, and ``mask`` is True at exactly those same positions.
    """
    n_rows = code_indices.shape[0]
    targets = torch.zeros(n_rows, num_codes)
    for row, n_valid in enumerate(lengths):
        present_codes = code_indices[row][:n_valid]  # drop padding
        targets[row][present_codes] = 1.0  # multi-hot
    # Every written position holds 1.0 and all others stay 0.0, so the mask is simply
    # the nonzero pattern of the targets.
    mask = targets.bool()
    return targets, mask
# Randomly mask some codes for semi-supervised learning
def apply_random_masking(code_indices, lengths, mask_prob=0.15, mask_token=0, generator=None):
    """Return a copy of ``code_indices`` with a random subset of real codes hidden.

    Each real (non-padding) code, i.e. position ``j < lengths[i]`` in row ``i``, is
    independently replaced by ``mask_token`` with probability ``mask_prob``. Padding is never
    touched and the input tensor is not modified. Build loss targets from the UNmasked
    ``code_indices`` (e.g. with create_targets_and_mask) so hidden codes are still predicted.

    Args:
        code_indices: (batch, max_len) long tensor of padded codes.
        lengths: per-row number of real codes (sequence of ints or 1-D tensor).
        mask_prob: probability of hiding each real code, in [0, 1].
        mask_token: value written into hidden positions. The default 0 is the padding value
            used by parse_reason_codes, so a hidden code looks like padding to the embedding
            while a lengths-based attention mask still treats the slot as a real token.
        generator: optional torch.Generator for reproducible masking.

    Returns:
        New (batch, max_len) tensor with the same dtype and device as ``code_indices``.

    Raises:
        ValueError: if ``mask_prob`` is outside [0, 1].
    """
    if not 0.0 <= mask_prob <= 1.0:
        raise ValueError(f"mask_prob must be in [0, 1], got {mask_prob}")
    device = code_indices.device
    lengths = torch.as_tensor(lengths, dtype=torch.long, device=device).reshape(-1)
    positions = torch.arange(code_indices.shape[1], device=device)
    is_real = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_len)
    draws = torch.rand(code_indices.shape, generator=generator, device=device)  # in [0, 1)
    hide = is_real & (draws < mask_prob)
    return code_indices.masked_fill(hide, mask_token)
def create_contrastive_pairs(batch_codes, batch_lengths, mask_prob=0.15, return_neg_lengths=False):
    """Build two randomly masked views of each row plus a shuffled negative batch.

    Positives: two independent apply_random_masking views of the same rows.
    Negatives: the batch rotated by a random non-zero offset, so no row is paired with
    itself (a plain ``randperm`` can map i -> i). With a batch of one, no distinct row
    exists and the single row is its own "negative".

    Args:
        batch_codes: (batch, max_len) long tensor of padded codes.
        batch_lengths: per-row real-code counts (sequence of ints or 1-D tensor).
        mask_prob: per-code masking probability for the positive views.
        return_neg_lengths: also return the lengths matching ``neg_view``'s rows (the
            original ``batch_lengths`` do not line up with the shuffled negatives).

    Returns:
        ``(pos_view1, pos_view2, neg_view)``, plus ``neg_lengths`` when
        ``return_neg_lengths`` is True.
    """
    batch_size = len(batch_codes)
    # Positive pairs: same person, different masked views
    pos_view1 = apply_random_masking(batch_codes, batch_lengths, mask_prob)
    pos_view2 = apply_random_masking(batch_codes, batch_lengths, mask_prob)
    # Negative pairs: a random cyclic shift is a permutation without fixed points (B > 1).
    shift = int(torch.randint(1, batch_size, (1,))) if batch_size > 1 else 0
    neg_indices = (torch.arange(batch_size, device=batch_codes.device) + shift) % batch_size
    neg_view = batch_codes[neg_indices]
    if return_neg_lengths:
        neg_lengths = torch.as_tensor(batch_lengths, device=neg_indices.device)[neg_indices]
        return pos_view1, pos_view2, neg_view, neg_lengths
    return pos_view1, pos_view2, neg_view
def infonce_loss(anchor, positive, negatives, temperature=0.1):
    """InfoNCE contrastive loss with one positive per anchor and a shared pool of negatives.

    For row i the logits are ``[cos(anchor_i, positive_i), cos(anchor_i, negatives_j) for all j]``
    divided by ``temperature``. The loss is cross-entropy with column 0 (the positive) as the
    correct class, averaged over the batch.

    Args:
        anchor, positive: (batch, embed_dim) embeddings, e.g. CLS outputs of two views.
        negatives: (num_neg, embed_dim); every anchor is compared with every negative.
        temperature: softmax temperature (> 0).

    Returns:
        Scalar loss tensor.

    NOTE(review): if ``negatives`` are embeddings of a permutation of the same batch (as
    create_contrastive_pairs produces), each anchor's own sample presumably appears once
    among its negatives, acting as a false negative. Confirm this is intended.
    """
    # Previously an undefined name (NameError); derive it from the inputs.
    batch_size = anchor.shape[0]
    pos_sim = F.cosine_similarity(anchor, positive, dim=-1) / temperature  # (batch,)
    neg_sim = F.cosine_similarity(
        anchor.unsqueeze(1), negatives.unsqueeze(0), dim=-1
    ) / temperature  # (batch, num_neg)
    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)  # (batch, 1 + num_neg)
    # Labels must live on the same device as the logits.
    labels = torch.zeros(batch_size, dtype=torch.long, device=anchor.device)
    return F.cross_entropy(logits, labels)
# Multi-task loss: masked-code reconstruction plus a weighted contrastive term.
# NOTE(review): masking_predictions, targets, mask, cls_anchor, cls_positive, cls_negatives and
# lambda_contrastive are not defined in this file; this belongs inside a training step.
# masked_bce_loss expects probabilities, so if masking_predictions are ReasonCodeTransformer's
# raw masking_logits they need a sigmoid first — confirm.
recon_loss = masked_bce_loss(masking_predictions, targets, valid_codes_mask=mask)
contrastive_loss = infonce_loss(
    anchor=cls_anchor,
    positive=cls_positive,
    negatives=cls_negatives,
)
total_loss = recon_loss + contrastive_loss * lambda_contrastive
class FeatureEmbedder(nn.Module):
    """Per-feature-type networks that map each input into a shared 32-dim space.

    Ages and incomes also get a learnable "missing" embedding. Only the age pathway is
    used by ``forward``; reason_code_net, score_net and income_net are defined but not
    called here.
    """

    def __init__(self):
        # Required before assigning submodules/parameters on an nn.Module.
        super().__init__()
        # Each feature type gets its own pathway to 32-dim space
        self.reason_code_net = nn.Sequential(
            nn.Embedding(160, 64),
            nn.Linear(64, 32),
            nn.Tanh(),
        )
        self.score_net = nn.Sequential(
            nn.Linear(32, 64),  # from 32-dim Fourier features
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.Tanh(),
        )
        self.age_net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.Tanh(),
        )
        self.income_net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.Tanh(),
        )
        # Learnable "missing" embedding for each feature type
        self.age_missing = nn.Parameter(torch.randn(32))
        self.income_missing = nn.Parameter(torch.randn(32))

    def forward(self, age, age_mask):
        """Embed ages, substituting ``age_missing`` where the age is absent.

        Args:
            age: (..., 1) float ages (age_net takes a single input feature). Entries for
                missing rows may hold any value, including NaN.
            age_mask: (...) bool, True where the age is present.

        Returns:
            (..., 32) embeddings.
        """
        present = age_mask.unsqueeze(-1)  # broadcasts over the feature axis
        # Zero missing inputs before the network: torch.where discards the NaN outputs,
        # but NaN inputs would still turn age_net's weight gradients into NaN.
        safe_age = torch.where(present, age, torch.zeros_like(age))
        age_embed = self.age_net(safe_age)
        return torch.where(present, age_embed, self.age_missing)
# NOTE(review): the two methods below use `self.feature_net` and `self.missing_net`, which no
# class in this file defines (FeatureEmbedder has age_net/income_net/... instead). They appear
# to be two alternative designs for embedding one feature with missing values; confirm which
# module they belong to before use. feature_net receives a 2-column input in
# preprocess_with_missing but `value` directly in forward(), so one network cannot serve both.
| def preprocess_with_missing(self, value, is_missing): | |
# Option A: give the network the cleaned value plus an explicit 0/1 missing indicator,
# i.e. feature_net is applied to a (..., 2) tensor [value with missing set to 0.0, flag].
| # Replace missing with 0, add missing flag as extra input | |
| clean_value = torch.where(is_missing, 0.0, value) | |
| missing_flag = is_missing.float() # 0 or 1 | |
| input_with_flag = torch.cat([clean_value.unsqueeze(-1), missing_flag.unsqueeze(-1)], dim=-1) | |
| return self.feature_net(input_with_flag) # Now takes 2 inputs instead of 1 | |
| def forward(self, value, is_missing): | |
# Option B: embed every value with feature_net, embed a constant all-ones input with
# missing_net, then take missing_embed where is_missing is True and present_embed elsewhere.
# NOTE(review): missing entries still pass through feature_net; if they hold NaN, the
# discarded torch.where branch can still produce NaN gradients — consider zeroing them first.
| present_embed = self.feature_net(value) | |
| missing_embed = self.missing_net(torch.ones_like(value)) # Constant input | |
| return torch.where(is_missing.unsqueeze(-1), missing_embed, present_embed) | |
def bjorck_orthogonalize(W, num_iters=3, beta=0.5):
    """Push ``W`` towards the nearest orthogonal matrix with Björck iterations.

    Each step applies ``W <- (1 + beta) * W - beta * W @ W.T @ W``. The default
    ``beta=0.5`` is ``(3 * W - W @ W.T @ W) / 2``, the classic first-order update. Rectangular
    ``W`` is pushed towards orthonormal rows or columns.

    For the default beta, each singular value s evolves as ``1.5*s - 0.5*s**3``. This converges
    to 1 for s in (0, sqrt(3)) but can diverge for larger values. Scale ``W`` first (e.g. by its
    spectral norm) if its singular values may exceed that range.

    Args:
        W: 2-D tensor; not modified in place.
        num_iters: number of iteration steps.
        beta: step coefficient.

    Returns:
        Tensor with the same shape as ``W``, differentiable with respect to ``W``.
    """
    for _ in range(num_iters):
        W = (1 + beta) * W - beta * (W @ W.T @ W)
    return W
class BJorckLinear(nn.Module):
    """Linear layer whose weight is orthogonalised with Björck iterations on every forward.

    Gradients flow through the iterations to the unconstrained ``weight`` parameter.

    Args:
        in_features, out_features: layer sizes. If either is omitted, no parameters are
            created (the previous behaviour), and a subclass or caller must assign
            ``weight`` and ``bias`` before calling forward().
        bias: whether to learn an additive bias.
        num_iters: Björck iterations per forward pass.
    """

    # Class-level default so instances built without this __init__ still work in forward().
    num_iters = 3

    def __init__(self, in_features=None, out_features=None, bias=True, num_iters=3):
        super().__init__()
        self.num_iters = num_iters
        if in_features is None or out_features is None:
            return
        weight = torch.empty(out_features, in_features)
        # Orthogonal init: all singular values are 1, a fixed point of the iteration.
        nn.init.orthogonal_(weight)
        self.weight = nn.Parameter(weight)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """Return ``F.linear(x, bjorck_orthogonalize(weight), bias)``."""
        W_ortho = bjorck_orthogonalize(self.weight, num_iters=self.num_iters)
        return F.linear(x, W_ortho, self.bias)
class UnifiedContinuousEmbedder(nn.Module):
    """Embed scalar features of several types with one shared network.

    Each value is Fourier-encoded and concatenated with a learned embedding of its feature
    type, so a single MLP can treat e.g. scores and counts differently.
    """

    # Feature type IDs (rows of ``feature_type_embedding``); also reachable via instances.
    SCORE_ID = 0
    COUNT_ID = 1
    AGE_ID = 2
    INCOME_ID = 3

    def __init__(self, fourier_dim=32, context_dim=16, output_dim=32):
        super().__init__()
        # Remembered so forward() encodes at the width shared_net was built for.
        self.fourier_dim = fourier_dim
        # Context embeddings for different feature types (up to 10)
        self.feature_type_embedding = nn.Embedding(10, context_dim)
        # Shared network processes [fourier_features + context]
        self.shared_net = nn.Sequential(
            nn.Linear(fourier_dim + context_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Tanh(),
        )

    def forward(self, values, feature_types):
        """Embed ``values`` conditioned on their feature types.

        Args:
            values: (batch,) continuous values; fourier_embedding expects them in [0, 1].
            feature_types: (batch,) long tensor of feature type IDs (< 10).

        Returns:
            (batch, output_dim) embeddings.
        """
        # Pass the configured width: the helper's default of 32 mismatches shared_net
        # whenever fourier_dim != 32.
        fourier_features = fourier_embedding(values, self.fourier_dim)  # (batch, fourier_dim)
        context_embeds = self.feature_type_embedding(feature_types)  # (batch, context_dim)
        combined = torch.cat([fourier_features, context_embeds], dim=-1)
        return self.shared_net(combined)
# Usage: tag each value with its feature-type ID (a long tensor shaped like the values)
# so the shared network knows which feature it is embedding.
score_embeds = embedder(
    scores,
    torch.full_like(scores, fill_value=embedder.SCORE_ID, dtype=torch.long),
)
count_embeds = embedder(
    counts,
    torch.full_like(counts, fill_value=embedder.COUNT_ID, dtype=torch.long),
)
class UniversalFeatureEmbedder(nn.Module):
    """Embed categorical (reason code) and continuous values with one shared network.

    Every row gets both a Fourier part and a categorical part. The part that does not apply
    to the row is replaced by a learned placeholder vector. A feature-type context embedding
    is appended before the shared MLP.
    """

    def __init__(self, fourier_dim=16, categorical_dim=16, context_dim=8, output_dim=32):
        super().__init__()
        # Remembered so forward() encodes at the width the layers were built for.
        self.fourier_dim = fourier_dim
        # Feature type context
        self.feature_type_embedding = nn.Embedding(10, context_dim)
        # Shared processor
        self.shared_net = nn.Sequential(
            nn.Linear(fourier_dim + categorical_dim + context_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Tanh(),
        )
        # Raw embeddings
        self.reason_code_embedding = nn.Embedding(160, categorical_dim)
        # Learned placeholders for the part a row does not use
        self.universal_categorical = nn.Parameter(torch.zeros(categorical_dim))  # for continuous rows
        self.universal_numerical = nn.Parameter(torch.zeros(fourier_dim))  # for categorical rows

    def forward(self, values, feature_types, is_categorical_mask):
        """Embed a batch of mixed categorical/continuous values.

        Args:
            values: (batch,) values; reason-code indices (< 160) on categorical rows, and
                continuous values (fourier_embedding expects [0, 1]) elsewhere.
            feature_types: (batch,) long feature type IDs (< 10).
            is_categorical_mask: (batch,) bool, True for categorical rows.

        Returns:
            (batch, output_dim) embeddings.
        """
        batch_size = values.shape[0]
        is_cat = is_categorical_mask.unsqueeze(-1)  # (batch, 1)
        # Fourier processing; the configured width must match universal_numerical.
        fourier_features = torch.where(
            is_cat,
            self.universal_numerical.expand(batch_size, -1),  # categorical gets placeholder
            fourier_embedding(values, self.fourier_dim),  # continuous gets Fourier
        )
        # torch.where evaluates both branches, so continuous values must not reach the
        # embedding lookup (e.g. 0.7 is fine, but 50000 or NaN would fail); use index 0
        # for them, whose embedding is discarded anyway.
        code_idx = torch.where(is_categorical_mask, values, torch.zeros_like(values)).long()
        categorical_features = torch.where(
            is_cat,
            self.reason_code_embedding(code_idx),  # categorical gets embedding
            self.universal_categorical.expand(batch_size, -1),  # continuous gets placeholder
        )
        # Context
        context_embeds = self.feature_type_embedding(feature_types)
        combined = torch.cat([fourier_features, categorical_features, context_embeds], dim=-1)
        return self.shared_net(combined)
class HybridFeatureProcessor(nn.Module):
    """Embed a continuous feature that may contain feature-specific sentinel ("special") codes.

    Ordinary values take a Fourier-feature + linear path. Values equal to one of the
    feature's special codes take a learned-embedding path instead. A feature-type context
    embedding is appended and a shared MLP produces the output.
    """

    def __init__(self, fourier_dim=16, categorical_dim=16, context_dim=8, output_dim=32):
        super().__init__()
        # Kept for forward(): Fourier width and zero-filler width must match the layers.
        self.fourier_dim = fourier_dim
        self.categorical_dim = categorical_dim
        # Processing pathways
        self.fourier_processor = nn.Linear(fourier_dim, 24)
        self.categorical_processor = nn.Embedding(1000, categorical_dim)  # rows for special codes
        self.context_processor = nn.Embedding(10, context_dim)
        # Unified output
        self.output_net = nn.Sequential(
            nn.Linear(24 + categorical_dim + context_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Tanh(),
        )
        # Define special codes per feature type
        self.special_codes = {
            'score': [999, 999999, -1],
            'age': [999, -1],
            'income': [999999, -999, 0],
        }

    def _special_row_offset(self, feature_type_name):
        """Return the first ``categorical_processor`` row reserved for this feature's codes.

        Rows are handed out consecutively in ``special_codes`` order, giving each
        (feature, code) pair its own in-range row. Raw codes such as 999999 or -1 are
        not valid indices into the 1000-row table.
        """
        offset = 0
        for name, codes in self.special_codes.items():
            if name == feature_type_name:
                return offset
            offset += len(codes)
        raise KeyError(feature_type_name)

    def forward(self, values, feature_type_name, feature_type_id):
        """Embed ``values`` of one feature type.

        Args:
            values: (batch,) raw feature values, possibly containing special codes.
            feature_type_name: key of ``special_codes`` ('score', 'age' or 'income').
            feature_type_id: indices into ``context_processor`` (< 10); presumably (batch,)
                — confirm against callers.

        Returns:
            (batch, output_dim) embeddings.

        Raises:
            KeyError: if ``feature_type_name`` has no entry in ``special_codes``.
        """
        special_codes = self.special_codes[feature_type_name]
        code_tensor = torch.tensor(special_codes, dtype=values.dtype, device=values.device)
        matches = values.unsqueeze(-1) == code_tensor  # (batch, n_codes)
        is_special = matches.any(dim=-1)  # (batch,)
        # The codes in each list are distinct, so at most one column matches; summing
        # match * position recovers that position (0 for ordinary values).
        positions = torch.arange(len(special_codes), device=values.device)
        slot = (matches.long() * positions).sum(dim=-1)
        rows = torch.where(
            is_special,
            slot + self._special_row_offset(feature_type_name),
            torch.zeros_like(slot),  # ordinary values: row 0, result discarded below
        )
        is_special_col = is_special.unsqueeze(-1)
        # Keep sentinel values out of the Fourier path (its output is discarded for them).
        # NOTE(review): fourier_embedding expects values in [0, 1]; raw ages/incomes
        # presumably need scaling first — confirm.
        clean_values = torch.where(is_special, torch.zeros_like(values), values)
        fourier_out = torch.where(
            is_special_col,
            values.new_zeros(values.shape[0], 24),  # special codes get zeros
            self.fourier_processor(fourier_embedding(clean_values, self.fourier_dim)),
        )
        categorical_out = torch.where(
            is_special_col,
            self.categorical_processor(rows),  # special codes get embeddings
            values.new_zeros(values.shape[0], self.categorical_dim),
        )
        context_out = self.context_processor(feature_type_id)
        return self.output_net(torch.cat([fourier_out, categorical_out, context_out], dim=-1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.