Replaces standard LlamaAttention layers with differential SDPA attention layers in a Llama model.
import os
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaRotaryEmbedding,
    LlamaForCausalLM,
    LlamaConfig,
    apply_rotary_pos_emb,
    repeat_kv,
    Cache
)
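
# ---------------------------------------------------------------------------
# NOTE: `LlamaDiffSdpaAttention` is referenced below, but its definition is
# not shown on this page. What follows is a *hypothetical minimal sketch* of
# such a class, based on the Differential Transformer paper (Microsoft
# Research, arXiv:2410.05258): queries and keys are split into two halves,
# two softmax attention outputs are computed via SDPA, and their difference,
# weighted by a learned scalar lambda, replaces the single attention map.
# It assumes a transformers ~4.44-style LlamaAttention interface
# (self.num_heads, self.head_dim, Cache.update, position_embeddings);
# adapt attribute and signature details to your installed version.
# ---------------------------------------------------------------------------
class LlamaDiffSdpaAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Depth-dependent initialisation from the paper:
        # lambda_init = 0.8 - 0.6 * exp(-0.3 * layer_idx)
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx or 0))
        half = self.head_dim // 2
        self.lambda_q1 = nn.Parameter(torch.empty(half).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.empty(half).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.empty(half).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.empty(half).normal_(0, 0.1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary embeddings: newer transformers pass (cos, sin) in;
        # older versions compute them from the module's rotary_emb.
        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            cos, sin = self.rotary_emb(v, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : k.shape[-2]]
        is_causal = causal_mask is None and q_len > 1

        # Split queries/keys in half along the head dimension. (The official
        # implementation applies RoPE at half head-dim; splitting after RoPE
        # here is a simplification.)
        q1, q2 = q.chunk(2, dim=-1)
        k1, k2 = k.chunk(2, dim=-1)

        # lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init
        lam = (torch.exp(torch.dot(self.lambda_q1.float(), self.lambda_k1.float()))
               - torch.exp(torch.dot(self.lambda_q2.float(), self.lambda_k2.float()))
               + self.lambda_init).to(q.dtype)

        # Difference of two SDPA outputs == (softmax(A1) - lam * softmax(A2)) V
        attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=causal_mask, is_causal=is_causal)
        attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=causal_mask, is_causal=is_causal)
        attn_output = attn1 - lam * attn2

        # Simplified per-head RMS normalisation plus the paper's
        # (1 - lambda_init) rescaling, standing in for headwise GroupNorm.
        attn_output = attn_output * torch.rsqrt(attn_output.pow(2).mean(-1, keepdim=True) + 1e-6)
        attn_output = attn_output * (1.0 - self.lambda_init)

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value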
def replace_attention_layers(model: LlamaForCausalLM) -> LlamaForCausalLM:
    """
    Replaces standard LlamaAttention layers with differential SDPA attention layers in a Llama model.

    This function implements the Differential Transformer attention mechanism from Microsoft Research
    (MSR-TR-2024-42) by replacing each layer's self-attention with LlamaDiffSdpaAttention,
    preserving the original projection weights while changing the attention computation.

    Args:
        model (LlamaForCausalLM): Pre-trained Llama model to be modified.
            Expected to be loaded from 'meta-llama/Llama-3.2-1B-Instruct'.

    Returns:
        LlamaForCausalLM: Modified model with differential attention layers.

    Operation:
        1. Iterates through all model layers.
        2. Creates a new differential attention layer for each layer.
        3. Copies weights from the original attention (q_proj, k_proj, v_proj, o_proj).
        4. Replaces the original self-attention with the differential version.

    Example:
        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        >>> model = replace_attention_layers(model)

    Notes:
        - Reuses the model's original projection weights and parameters.
        - Compatible with the Llama 3.2 1B Instruct architecture.
        - Implements differential attention for improved context processing.
    """
    for i, layer in enumerate(model.model.layers):
        # Build a replacement LlamaDiffSdpaAttention layer for this position
        diff_attention = LlamaDiffSdpaAttention(
            config=model.config,
            layer_idx=i
        )
        # Match the original layer's dtype and device: the model is loaded in
        # bfloat16 with device_map="auto", while a freshly constructed layer
        # starts as fp32 on CPU
        diff_attention = diff_attention.to(
            dtype=layer.self_attn.q_proj.weight.dtype,
            device=layer.self_attn.q_proj.weight.device,
        )
        # Copy projection weights from the original attention layer
        diff_attention.q_proj.weight.data = layer.self_attn.q_proj.weight.data
        diff_attention.k_proj.weight.data = layer.self_attn.k_proj.weight.data
        diff_attention.v_proj.weight.data = layer.self_attn.v_proj.weight.data
        diff_attention.o_proj.weight.data = layer.self_attn.o_proj.weight.data
        # Swap in the differential attention layer
        layer.self_attn = diff_attention
    return model
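
# Optional sanity check (a hypothetical helper, not part of the original gist):
# confirms every layer was swapped and that the copied projections kept the
# model's dtype before generation is attempted.
def verify_attention_swap(model: LlamaForCausalLM) -> None:
    for i, layer in enumerate(model.model.layers):
        assert isinstance(layer.self_attn, LlamaDiffSdpaAttention), f"layer {i} was not replaced"
        assert layer.self_attn.q_proj.weight.dtype == model.dtype, f"layer {i} has a dtype mismatch"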
def generate_text(
    model: LlamaForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.5,
    top_p: float = 0.9,
    eos_tokens: Optional[List[int]] = None
) -> str:
    """Generate text from a prompt. Uses max_new_tokens rather than max_length
    so that long prompts do not exhaust the generation budget."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=eos_tokens,
        pad_token_id=tokenizer.pad_token_id,
        repetition_penalty=1.1
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def main():
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model and tokenizer
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    print("Replacing attention layers with differential attention...")
    model = replace_attention_layers(model)
    model.eval()
    test_prompts = [
        # Long-context comprehension and information retrieval
        """
        Read the following research paper abstract and answer key questions about the methodology and findings:

        Title: Impact of Climate Change on Marine Ecosystems

        Abstract: Recent studies have demonstrated significant changes in marine ecosystems due to rising ocean temperatures. A 20-year longitudinal study across 50 coral reef sites showed a 45% decline in biodiversity. Temperature increases of 1.5°C corresponded to a 30% reduction in fish populations, while acidification levels rose by 0.1 pH units annually. Researchers utilized satellite data combined with in-situ measurements to track changes in coral bleaching events, finding a 300% increase in frequency over the study period. Economic impacts on local fishing communities were estimated at $2.4 billion annually, affecting approximately 120,000 households across the Pacific region.

        Questions:
        1. What was the duration of the study?
        2. How many coral reef sites were monitored?
        3. What was the percentage decline in biodiversity?
        4. What was the economic impact on local communities?
        """,
        # Hallucination mitigation test with specific factual queries
        """
        Based only on the following passage, answer the questions below. If the information isn't provided in the passage, state "Information not provided."

        The first electric vehicle was created by Thomas Parker in London in 1884. The vehicle could reach speeds of up to 14 miles per hour and had a range of 25 miles on a single charge. Parker was primarily known for his work in electrifying the London Underground.

        Questions:
        1. Who invented the first electric vehicle?
        2. What was the maximum speed of Parker's vehicle?
        3. What was the battery capacity in kilowatt-hours?
        4. Where was Parker born?
        5. What was the vehicle's weight?
        """,
        # In-context learning with order variation
        """
        Learn from these examples and continue the pattern:

        Input: "The cat sleeps"
        Output: "The feline rests"

        Input: "The dog runs"
        Output: "The canine sprints"

        Input: "The bird flies"
        Output: "The avian soars"

        Now transform:
        Input: "The horse jumps"
        Output: ?
        """,
        # Complex reasoning with noise and distraction
        """
        Solve this logic puzzle while ignoring irrelevant information:

        Four friends (Alex, Bella, Carlos, and Diana) are sitting at a round table. We know that:
        - Alex is wearing a red shirt (though yesterday he wore blue)
        - Bella sits directly across from Carlos (who usually prefers to stand)
        - Diana sits to the right of someone wearing green (green is her least favorite color)
        - The person in blue sits between the person in red and the person in green
        - Carlos hates the color yellow, which reminds him of his old car
        - It's Tuesday afternoon and slightly cloudy
        - Bella is wearing green (she bought it last week at a sale)
        - The cafeteria is serving spaghetti today

        Question: What color is Diana wearing?
        """
    ]
    # Generate responses
    print("\nGenerating responses...")
    for prompt in test_prompts:
        print("\nPrompt:", prompt)
        print("-" * 50)
        response = generate_text(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=100,
            eos_tokens=model.config.eos_token_id
        )
        print("Response:", response)
        print("=" * 80)
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"Error occurred: {e}")