Created
August 13, 2025 18:30
-
-
Save theharshith/a4a355c6f20a499f7d5cdd9db020b03a to your computer and use it in GitHub Desktop.
Contrastive Decoding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import transformers as tr | |
| import torch | |
| import torch.nn.functional as F | |
| import math | |
# Hugging Face model identifiers passed to `from_pretrained`: the smaller
# "amateur" checkpoint and the larger "expert" checkpoint.
| amateur_path = 'Qwen/Qwen2.5-Coder-0.5B-Instruct' | |
| expert_path = 'Qwen/Qwen2.5-Coder-1.5B-Instruct' #models of the same family | |
# A single tokenizer, loaded from the amateur checkpoint, is used for the
# prompt and for decoding. NOTE(review): this presumes both checkpoints share
# the same vocabulary -- confirm before swapping in other models.
| tokenizer = tr.AutoTokenizer.from_pretrained(amateur_path) | |
# User turn of the demo conversation: asks for a short docstring for a
# JavaScript Elo-rating update function embedded in a fenced code block.
| user_message = """Give a very very brief docstring for the following function:\n```\nfunction updateEloScores( | |
| scores, | |
| results, | |
| kFactor = 4, | |
| ) { | |
| for (const result of results) { | |
| const { first, second, outcome } = result; | |
| const firstScore = scores[first] ?? 1000; | |
| const secondScore = scores[second] ?? 1000; | |
| const expectedScoreFirst = 1 / (1 + Math.pow(10, (secondScore - firstScore) / 400)); | |
| const expectedScoreSecond = 1 / (1 + Math.pow(10, (firstScore - secondScore) / 400)); | |
| let sa = 0.5; | |
| if (outcome === 1) { | |
| sa = 1; | |
| } else if (outcome === -1) { | |
| sa = 0; | |
| } | |
| scores[first] = firstScore + kFactor * (sa - expectedScoreFirst); | |
| scores[second] = secondScore + kFactor * (1 - sa - expectedScoreSecond); | |
| } | |
| return scores; | |
| }\n```""" | |
# Render the system + user messages through the tokenizer's chat template.
# `tokenize=False` yields a plain string (it is encoded later), and
# `add_generation_prompt=True` appends the marker that opens the assistant turn.
| prompt = tokenizer.apply_chat_template( | |
| [ | |
| {'role': 'system', 'content': 'You are a helpful assistant'}, | |
| {'role': 'user', 'content': user_message} | |
| ], | |
| add_generation_prompt=True, | |
| tokenize=False | |
| ) | |
def contrastive_generation(
    amateur,
    expert,
    prompt,
    max_tokens,
    *,
    alpha=0.1,
    temp_exp=1.0,
    temp_ama=0.5,
) -> str:
    """Generate a continuation of ``prompt`` with greedy contrastive decoding.

    At every step the chosen token maximizes
    ``log p_expert(token) - log p_amateur(token)`` among the tokens that pass
    the plausibility constraint
    ``p_expert(token) >= alpha * max_token p_expert(token)``.

    Args:
        amateur: The weaker causal language model. It is called as
            ``amateur(input_ids=..., past_key_values=..., use_cache=True)``
            and must return an object with ``.logits`` and
            ``.past_key_values`` (e.g. a Hugging Face causal LM).
        expert: The stronger causal language model, called the same way. Its
            ``.device`` attribute decides where the input ids are placed.
        prompt: Prompt text; it is encoded with the module-level ``tokenizer``.
        max_tokens: Maximum number of new tokens to generate.
        alpha: Plausibility threshold in ``(0, 1]``.
        temp_exp: Softmax temperature applied to the expert logits.
        temp_ama: Softmax temperature applied to the amateur logits. Values
            below 1 sharpen the amateur distribution.

    Returns:
        The decoded text of the whole sequence (prompt followed by the
        generated tokens) with special tokens skipped.

    Raises:
        ValueError: If ``alpha`` is outside ``(0, 1]`` or a temperature is not
            strictly positive.
    """
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"alpha must be in (0, 1], got {alpha}")
    if temp_exp <= 0 or temp_ama <= 0:
        raise ValueError("temp_exp and temp_ama must be strictly positive")

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(expert.device)
    generated_op = input_ids.clone()
    log_alpha = math.log(alpha)  # loop-invariant, so computed once

    # The first forward pass consumes the whole prompt. Later passes feed only
    # the newly chosen token and reuse each model's key/value cache, instead
    # of re-encoding the entire sequence at every step.
    step_input = input_ids
    expert_cache = None
    amateur_cache = None

    for _ in range(max_tokens):
        with torch.no_grad():
            expert_outputs = expert(
                input_ids=step_input,
                past_key_values=expert_cache,
                use_cache=True,
            )
            amateur_outputs = amateur(
                input_ids=step_input,
                past_key_values=amateur_cache,
                use_cache=True,
            )
        expert_cache = expert_outputs.past_key_values
        amateur_cache = amateur_outputs.past_key_values

        # logits: (batch, seq_len, vocab) -> last position: (batch, vocab).
        # Upcast to float32 so half-precision rounding does not distort the
        # log-probability differences compared below.
        expert_logits = expert_outputs.logits[:, -1, :].float() / temp_exp
        amateur_logits = amateur_outputs.logits[:, -1, :].float() / temp_ama

        # (batch, vocab) -> (vocab,); a single prompt gives batch size 1.
        log_probs_expert = F.log_softmax(expert_logits, dim=-1).squeeze(0)
        log_probs_amateur = F.log_softmax(amateur_logits, dim=-1).squeeze(0)

        # Plausibility constraint in log space:
        #   log p_expert >= max(log p_expert) + log(alpha)
        cutoff = log_probs_expert.max() + log_alpha
        feasible_idx = torch.where(log_probs_expert >= cutoff)[0]
        if feasible_idx.numel() == 0:
            # The arg-max token always meets the cutoff for finite values, so
            # this is only reachable when the log-probs contain NaN. Fall back
            # to the expert's arg-max instead of indexing an empty tensor.
            feasible_idx = log_probs_expert.argmax().unsqueeze(0)

        contrast_scores = (
            log_probs_expert[feasible_idx] - log_probs_amateur[feasible_idx]
        )
        next_token_id = feasible_idx[contrast_scores.argmax()].view(1, 1)
        generated_op = torch.cat([generated_op, next_token_id], dim=-1)

        # Stop once the end-of-sequence token has been produced; it is kept in
        # the sequence and dropped later by `skip_special_tokens`.
        if next_token_id.item() == tokenizer.eos_token_id:
            break
        step_input = next_token_id

    return tokenizer.decode(generated_op[0], skip_special_tokens=True)
def test():
    """Load both models, run contrastive decoding on ``prompt``, print the reply.

    Uses the module-level ``amateur_path``, ``expert_path`` and ``prompt``.
    The models are placed on CUDA when available, otherwise on the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Half precision is intended for the GPU. On the CPU it can be slow or
    # unsupported for some ops depending on the torch version, so fall back
    # to float32 there.
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    amateur_model = tr.AutoModelForCausalLM.from_pretrained(
        amateur_path, torch_dtype=dtype
    )
    expert_model = tr.AutoModelForCausalLM.from_pretrained(
        expert_path, torch_dtype=dtype
    )
    amateur_model.eval()
    expert_model.eval()
    amateur_model = amateur_model.to(device)
    expert_model = expert_model.to(device)

    output = contrastive_generation(
        amateur=amateur_model,
        expert=expert_model,
        prompt=prompt,
        max_tokens=512,
    )

    # The decoded text still contains the rendered chat prompt, so keep only
    # what follows the last "assistant\n" marker. This is done outside the
    # f-string because a backslash inside an f-string replacement field is a
    # SyntaxError before Python 3.12.
    reply = output.split('assistant\n')[-1].strip()
    print(f"Op: \n\n\n{reply}")
# Run the demo only when this file is executed as a script, not when it is
# imported as a module.
| if __name__ == "__main__": | |
| test() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment