@DSamuelHodge
Created November 6, 2024 03:52
Replaces standard LlamaAttention layers with differential SDPA attention layers in a Llama model.
import os
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaRotaryEmbedding,
    LlamaForCausalLM,
    LlamaConfig,
    apply_rotary_pos_emb,
    repeat_kv,
    Cache,
)
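

# NOTE: The `LlamaDiffSdpaAttention` class used below is referenced by this
# gist but not shown in this capture. As an illustrative sketch of the core
# computation it implements (differential attention, MSR-TR-2024-42), the
# helper below subtracts a second softmax attention map from the first. The
# name, the fixed `lam` value, and the head-dim split are assumptions made
# for illustration, not the gist's actual implementation.
def diff_attention_sketch(
    q: torch.Tensor,  # (batch, heads, seq_len, head_dim)
    k: torch.Tensor,  # (batch, heads, seq_len, head_dim)
    v: torch.Tensor,  # (batch, heads, seq_len, head_dim)
    lam: float = 0.5,  # learned scalar in the paper; fixed here for brevity
) -> torch.Tensor:
    """Differential attention: (softmax(Q1K1ᵀ/√d) - λ·softmax(Q2K2ᵀ/√d))·V.

    Q and K are split in half along the head dimension to form two attention
    maps; subtracting the λ-scaled second map cancels attention noise common
    to both maps, sharpening focus on relevant context. In the paper, λ is
    re-parameterized per layer with λ_init = 0.8 - 0.6·exp(-0.3·(l - 1)) for
    layer l.
    """
    q1, q2 = q.chunk(2, dim=-1)
    k1, k2 = k.chunk(2, dim=-1)
    scale = q1.shape[-1] ** -0.5
    a1 = torch.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    a2 = torch.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    return (a1 - lam * a2) @ v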


def replace_attention_layers(model: LlamaForCausalLM) -> LlamaForCausalLM:
    """
    Replaces standard LlamaAttention layers with differential SDPA attention
    layers in a Llama model.

    This function implements the Differential Transformer attention mechanism
    from Microsoft Research (MSR-TR-2024-42) by replacing each layer's
    self-attention with LlamaDiffSdpaAttention. The original projection
    weights are preserved; only the attention computation changes.

    Args:
        model (LlamaForCausalLM): Pre-trained Llama model to be modified.
            Expected to be loaded from 'meta-llama/Llama-3.2-1B-Instruct'.

    Returns:
        LlamaForCausalLM: Modified model with differential attention layers.

    Operation:
        1. Iterates through all model layers.
        2. Creates a new differential attention layer for each layer.
        3. Copies weights from the original attention
           (q_proj, k_proj, v_proj, o_proj).
        4. Replaces the original self-attention with the differential version.

    Example:
        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        >>> model = replace_attention_layers(model)

    Notes:
        - Maintains the model's original weights and parameters.
        - Compatible with the Llama 3.2 1B Instruct model architecture.
        - Implements differential attention for improved context processing.
    """
    for i, layer in enumerate(model.model.layers):
        # Create a new differential attention (LlamaDiffSdpaAttention) layer
        diff_attention = LlamaDiffSdpaAttention(
            config=model.config,
            layer_idx=i,
        )
        # Match the original layer's device and dtype before copying weights:
        # freshly constructed modules default to fp32 on CPU, while the model
        # may be loaded in bfloat16 with device_map="auto"
        orig = layer.self_attn
        diff_attention = diff_attention.to(
            device=orig.q_proj.weight.device,
            dtype=orig.q_proj.weight.dtype,
        )
        # Copy projection weights from the original attention layer
        diff_attention.q_proj.weight.data = orig.q_proj.weight.data
        diff_attention.k_proj.weight.data = orig.k_proj.weight.data
        diff_attention.v_proj.weight.data = orig.v_proj.weight.data
        diff_attention.o_proj.weight.data = orig.o_proj.weight.data
        # Replace the attention layer
        layer.self_attn = diff_attention
    return model
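

# Optional sanity check (hypothetical helper, not part of the original gist):
# confirms every layer's self-attention was actually swapped for the
# differential variant before running generation.
def verify_replacement(model: LlamaForCausalLM) -> None:
    for i, layer in enumerate(model.model.layers):
        name = type(layer.self_attn).__name__
        assert name == "LlamaDiffSdpaAttention", f"layer {i} still uses {name}"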


def generate_text(
    model: LlamaForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.5,
    top_p: float = 0.9,
    eos_tokens: Optional[List[int]] = None,
) -> str:
    """Generate text using the model.

    Uses max_new_tokens rather than max_length so the budget applies to the
    generated continuation; with long prompts like the ones below, a
    max_length of 100 would be exhausted by the prompt alone.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=eos_tokens,
        pad_token_id=tokenizer.pad_token_id,
        repetition_penalty=1.1,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
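

# Quick usage sketch (hypothetical snippet; assumes `model` and `tokenizer`
# are already loaded as in main() below):
#
#     reply = generate_text(
#         model, tokenizer,
#         prompt="Summarize differential attention in one sentence.",
#         max_new_tokens=64,
#     )
#     print(reply)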


def main():
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model and tokenizer
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print("Replacing attention layers with differential attention...")
    model = replace_attention_layers(model)
    model.eval()
    test_prompts = [
        # Long-context comprehension and information retrieval
        """
        Read the following research paper abstract and answer key questions about the methodology and findings:

        Title: Impact of Climate Change on Marine Ecosystems

        Abstract: Recent studies have demonstrated significant changes in marine ecosystems due to rising ocean temperatures. A 20-year longitudinal study across 50 coral reef sites showed a 45% decline in biodiversity. Temperature increases of 1.5°C corresponded to a 30% reduction in fish populations, while acidification levels rose by 0.1 pH units annually. Researchers utilized satellite data combined with in-situ measurements to track changes in coral bleaching events, finding a 300% increase in frequency over the study period. Economic impacts on local fishing communities were estimated at $2.4 billion annually, affecting approximately 120,000 households across the Pacific region.

        Questions:
        1. What was the duration of the study?
        2. How many coral reef sites were monitored?
        3. What was the percentage decline in biodiversity?
        4. What was the economic impact on local communities?
        """,
        # Hallucination mitigation test with specific factual queries
        """
        Based only on the following passage, answer the questions below. If the information isn't provided in the passage, state "Information not provided."

        The first electric vehicle was created by Thomas Parker in London in 1884. The vehicle could reach speeds of up to 14 miles per hour and had a range of 25 miles on a single charge. Parker was primarily known for his work in electrifying the London Underground.

        Questions:
        1. Who invented the first electric vehicle?
        2. What was the maximum speed of Parker's vehicle?
        3. What was the battery capacity in kilowatt-hours?
        4. Where was Parker born?
        5. What was the vehicle's weight?
        """,
        # In-context learning with order variation
        """
        Learn from these examples and continue the pattern:

        Input: "The cat sleeps"
        Output: "The feline rests"

        Input: "The dog runs"
        Output: "The canine sprints"

        Input: "The bird flies"
        Output: "The avian soars"

        Now transform:
        Input: "The horse jumps"
        Output: ?
        """,
        # Complex reasoning with noise and distraction
        """
        Solve this logic puzzle while ignoring irrelevant information:

        Four friends (Alex, Bella, Carlos, and Diana) are sitting at a round table. We know that:
        - Alex is wearing a red shirt (though yesterday he wore blue)
        - Bella sits directly across from Carlos (who usually prefers to stand)
        - Diana sits to the right of someone wearing green (green is her least favorite color)
        - The person in blue sits between the person in red and the person in green
        - Carlos hates the color yellow, which reminds him of his old car
        - It's Tuesday afternoon and slightly cloudy
        - Bella is wearing green (she bought it last week at a sale)
        - The cafeteria is serving spaghetti today

        Question: What color is Diana wearing?
        """,
    ]
    # Generate responses
    print("\nGenerating responses...")
    for prompt in test_prompts:
        print("\nPrompt:", prompt)
        print("-" * 50)
        response = generate_text(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=100,
            eos_tokens=model.config.eos_token_id,
        )
        print("Response:", response)
        print("=" * 80)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"Error occurred: {e}")