@theharshith · Created August 13, 2025 18:30
Contrastive Decoding
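
A minimal, single-file implementation of contrastive decoding with two Qwen2.5-Coder models from the same family: a small "amateur" (0.5B) and a larger "expert" (1.5B). At each decoding step the expert's distribution defines a plausibility set V = { x : log p_EXP(x) >= max log p_EXP + log(alpha) }, and the next token is chosen greedily as the argmax over V of log p_EXP(x) - log p_AMA(x). Running the amateur at a temperature below 1 sharpens its (undesired) distribution, so continuations that the small model also rates highly are penalized.
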
import transformers as tr
import torch
import torch.nn.functional as F
import math
amateur_path = 'Qwen/Qwen2.5-Coder-0.5B-Instruct'
expert_path = 'Qwen/Qwen2.5-Coder-1.5B-Instruct' #models of the same family
tokenizer = tr.AutoTokenizer.from_pretrained(amateur_path)
user_message = """Give a very very brief docstring for the following function:\n```\nfunction updateEloScores(
scores,
results,
kFactor = 4,
) {
for (const result of results) {
const { first, second, outcome } = result;
const firstScore = scores[first] ?? 1000;
const secondScore = scores[second] ?? 1000;
const expectedScoreFirst = 1 / (1 + Math.pow(10, (secondScore - firstScore) / 400));
const expectedScoreSecond = 1 / (1 + Math.pow(10, (firstScore - secondScore) / 400));
let sa = 0.5;
if (outcome === 1) {
sa = 1;
} else if (outcome === -1) {
sa = 0;
}
scores[first] = firstScore + kFactor * (sa - expectedScoreFirst);
scores[second] = secondScore + kFactor * (1 - sa - expectedScoreSecond);
}
return scores;
}\n```"""
prompt = tokenizer.apply_chat_template(
    [
        {'role': 'system', 'content': 'You are a helpful assistant'},
        {'role': 'user', 'content': user_message},
    ],
    add_generation_prompt=True,
    tokenize=False,
)
def contrastive_generation(amateur, expert, prompt, max_tokens) -> str:
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(expert.device)
    generated_op = input_ids.clone()
    alpha = 0.1  # plausibility threshold
    temp_exp = 1.0
    temp_ama = 0.5  # in [0.5, 1.0]; t < 1 sharpens the amateur's (undesired) distribution

    for _ in range(max_tokens):
        with torch.no_grad():
            expert_outputs = expert(input_ids=generated_op)
            expert_logits = expert_outputs.logits[:, -1, :] / temp_exp  # (b, seq_len, vocab_size) -> (b, vocab_size)
            amateur_outputs = amateur(input_ids=generated_op)
            amateur_logits = amateur_outputs.logits[:, -1, :] / temp_ama

        # (b, vocab_size) -> (vocab_size)
        log_probs_expert = F.log_softmax(expert_logits, dim=-1).squeeze(0)
        log_probs_amateur = F.log_softmax(amateur_logits, dim=-1).squeeze(0)

        # plausibility constraint:
        # keep only tokens whose expert log-prob is within log(alpha) of the expert's maximum,
        # i.e. log_probs_expert >= max_log_exp + log(alpha)
        max_log_exp = log_probs_expert.max()
        cutoff = max_log_exp + math.log(alpha)
        mask = log_probs_expert >= cutoff
        feasible_idx = torch.where(mask)[0]
        if len(feasible_idx) == 0:
            feasible_idx = torch.tensor([log_probs_expert.argmax()], device=expert.device)

        # contrastive score: expert log-prob minus amateur log-prob, restricted to the plausible set
        contrast_scores = log_probs_expert[feasible_idx] - log_probs_amateur[feasible_idx]
        best_idx = contrast_scores.argmax()
        next_token_id = feasible_idx[best_idx].unsqueeze(0)  # shape (1,)
        generated_op = torch.cat([generated_op, next_token_id.unsqueeze(0)], dim=-1)

        # stop at the end-of-sequence token
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    output_text = tokenizer.decode(generated_op[0], skip_special_tokens=True)
    return output_text
def test():
    amateur_model = tr.AutoModelForCausalLM.from_pretrained(amateur_path, torch_dtype=torch.float16)
    expert_model = tr.AutoModelForCausalLM.from_pretrained(expert_path, torch_dtype=torch.float16)
    amateur_model.eval()
    expert_model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    amateur_model = amateur_model.to(device)
    expert_model = expert_model.to(device)

    output = contrastive_generation(
        amateur=amateur_model,
        expert=expert_model,
        prompt=prompt,
        max_tokens=512,
    )
    # split outside the f-string (backslashes inside f-string expressions need Python 3.12+)
    completion = output.split("assistant\n")[-1].strip()
    print(f"Op: \n\n\n{completion}")


if __name__ == "__main__":
    test()
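
Note: this loop re-runs both models over the entire prefix at every step. For longer generations, caching past key/values (use_cache=True and feeding past_key_values back into each forward call) would avoid recomputing the prefix; the greedy contrastive selection logic stays the same.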