gpt-oss-20b-ft-lora-sample-code
import json
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from openai import OpenAI
import os
from typing import List, Dict
import random


class SyntheticDataGenerator:
    def __init__(self, api_key: str = None, model: str = "gpt-4"):
        """
        Initialize the synthetic data generator.

        Args:
            api_key: OpenAI API key (if None, reads from OPENAI_API_KEY env var)
            model: LLM model to use for generation
        """
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.model = model
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def load_csv(self, csv_path: str) -> pd.DataFrame:
        """
        Load CSV file into a DataFrame.

        Args:
            csv_path: Path to the CSV file

        Returns:
            DataFrame containing the CSV data
        """
        df = pd.read_csv(csv_path)
        # Validate required columns
        required_columns = ['context']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"CSV missing required columns: {missing_columns}")
        print(f"Loaded CSV with {len(df)} rows")
        print(f"Columns: {list(df.columns)}")
        return df

    def split_text(self, text: str) -> List[str]:
        """
        Split text into chunks using RecursiveCharacterTextSplitter.

        Args:
            text: Input text to split

        Returns:
            List of text chunks
        """
        chunks = self.text_splitter.split_text(text)
        return chunks

    def generate_qna_from_chunk(self, chunk: str, row_metadata: Dict, chunk_id: str) -> Dict:
        """
        Generate a QnA pair with contexts from a text chunk using an LLM.

        Args:
            chunk: Text chunk to generate QnA from
            row_metadata: Metadata from the CSV row (ticker, filing, etc.)
            chunk_id: Unique ID for this chunk

        Returns:
            Dictionary containing question, answers, and contexts
        """
        prompt = f"""Based on the following text, generate a synthetic question-answer pair with supporting and non-supporting context documents.

Text:
{chunk}

Generate a JSON object with this EXACT structure:
{{
  "question": "A specific question based on the text",
  "answers": ["A detailed answer to the question"],
  "supporting_contexts": [
    {{
      "title": "Relevant title",
      "text": "Text snippet that contains the answer (extract or paraphrase from source)"
    }},
    {{
      "title": "Another relevant title",
      "text": "Another text snippet that supports the answer"
    }},
    {{
      "title": "Third relevant title",
      "text": "Third text snippet that supports the answer"
    }}
  ],
  "non_supporting_contexts": [
    {{
      "title": "Related but different topic title",
      "text": "Text about a related but different topic that doesn't answer the question"
    }},
    {{
      "title": "Another unrelated title",
      "text": "Another text that doesn't contain the answer"
    }},
    {{
      "title": "Third unrelated title",
      "text": "Third text that doesn't answer the question"
    }},
    {{
      "title": "Fourth unrelated title",
      "text": "Fourth text that doesn't contain the answer"
    }}
  ]
}}

Requirements:
- Question should be specific and answerable from the text
- Generate exactly 3 supporting contexts (has_answer: True)
- Generate exactly 4 non-supporting contexts (has_answer: False)
- Non-supporting contexts should be plausible but not contain the answer
- All contexts should seem realistic

Return ONLY the JSON object, no additional text."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that generates structured synthetic training data for question-answering systems. Always respond with valid JSON only."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            response_text = response.choices[0].message.content.strip()
            # Try to extract JSON if wrapped in markdown code blocks
            if response_text.startswith("```"):
                response_text = response_text.split("```")[1]
                if response_text.startswith("json"):
                    response_text = response_text[4:]
                response_text = response_text.strip()
            data = json.loads(response_text)
            # Format the data to match the desired structure
            formatted_data = {
                "question": data["question"],
                "answers": data["answers"],
                "ctxs": []
            }
            # Add metadata from the CSV row if available
            if row_metadata:
                formatted_data["metadata"] = row_metadata
            # Add supporting contexts
            for i, ctx in enumerate(data["supporting_contexts"]):
                formatted_data["ctxs"].append({
                    "id": f"{chunk_id}_sup_{i}",
                    "title": ctx["title"],
                    "text": ctx["text"],
                    "score": f"{random.uniform(0.85, 0.98):.2f}",
                    "has_answer": True
                })
            # Add non-supporting contexts
            for i, ctx in enumerate(data["non_supporting_contexts"]):
                formatted_data["ctxs"].append({
                    "id": f"{chunk_id}_nonsup_{i}",
                    "title": ctx["title"],
                    "text": ctx["text"],
                    "score": f"{random.uniform(0.35, 0.65):.2f}",
                    "has_answer": False
                })
            # Shuffle contexts to mix supporting and non-supporting
            random.shuffle(formatted_data["ctxs"])
            return formatted_data
        except Exception as e:
            print(f"Error generating QnA for chunk {chunk_id}: {e}")
            return None

    def generate_synthetic_data_from_csv(self, csv_path: str, output_path: str,
                                         max_rows: int = None,
                                         include_metadata: bool = True) -> List[Dict]:
        """
        Main pipeline: Load CSV → Process each row's context → Split → Generate QnA → Save JSON.

        Args:
            csv_path: Path to input CSV file
            output_path: Path to save output JSON file
            max_rows: Maximum number of rows to process (None for all)
            include_metadata: Whether to include CSV metadata (ticker, filing, etc.) in output

        Returns:
            List of generated QnA data
        """
        print(f"Loading CSV: {csv_path}")
        df = self.load_csv(csv_path)
        if max_rows:
            df = df.head(max_rows)
            print(f"Processing first {max_rows} rows")
        synthetic_data = []
        for idx, row in df.iterrows():
            context_text = str(row['context'])
            # Prepare metadata from the other columns
            metadata = {}
            if include_metadata:
                for col in df.columns:
                    if col != 'context':
                        metadata[col] = row[col]
            print(f"\nProcessing row {idx + 1}/{len(df)}")
            print(f"Context length: {len(context_text)} characters")
            # Split the context into chunks
            chunks = self.split_text(context_text)
            print(f"Split into {len(chunks)} chunks")
            # Generate QnA for each chunk
            for chunk_idx, chunk in enumerate(chunks):
                chunk_id = f"row{idx}_chunk{chunk_idx}"
                print(f"  Generating QnA for chunk {chunk_idx + 1}/{len(chunks)}...")
                qna_data = self.generate_qna_from_chunk(chunk, metadata, chunk_id)
                if qna_data:
                    synthetic_data.append(qna_data)
        # Save to JSON file
        print(f"\nSaving synthetic data to: {output_path}")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(synthetic_data, f, indent=2, ensure_ascii=False)
        print(f"Successfully generated {len(synthetic_data)} QnA pairs from {len(df)} rows")
        return synthetic_data


# Example usage
if __name__ == "__main__":
    # Initialize generator
    generator = SyntheticDataGenerator(
        api_key=None,   # Uses OPENAI_API_KEY env var
        model="gpt-4"   # or "gpt-3.5-turbo" for faster/cheaper generation
    )
    # Generate synthetic data from CSV
    csv_path = "your_data.csv"  # Replace with your CSV path
    output_path = "synthetic_data.json"
    # Process CSV and generate data
    synthetic_data = generator.generate_synthetic_data_from_csv(
        csv_path=csv_path,
        output_path=output_path,
        max_rows=5,             # Process first 5 rows for testing; set to None for all rows
        include_metadata=True   # Include ticker, filing, and other columns in output
    )
    # Print sample
    if synthetic_data:
        print("\n" + "="*50)
        print("Sample generated data:")
        print("="*50)
        print(json.dumps(synthetic_data[0], indent=2))
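# --- Optional sanity check: a minimal sketch that validates the JSON written above.
# It assumes the "synthetic_data.json" path used in the example and the 3 supporting /
# 4 non-supporting split that the generation prompt requests; the LLM may not always
# comply, so it warns rather than failing hard.
import json

def quick_validate(path: str = "synthetic_data.json") -> None:
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    for i, rec in enumerate(records):
        assert rec.get("question") and rec.get("answers"), f"record {i} missing question/answers"
        ctxs = rec.get("ctxs", [])
        supporting = sum(1 for c in ctxs if c["has_answer"])
        non_supporting = len(ctxs) - supporting
        if supporting != 3 or non_supporting != 4:
            print(f"Warning: record {i} has {supporting} supporting / {non_supporting} non-supporting contexts")
    print(f"Checked {len(records)} records")

# quick_validate()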
import json
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from openai import OpenAI
import os
from typing import List, Dict
import random


class SyntheticDataGenerator:
    def __init__(self, api_key: str = None, model: str = "gpt-4"):
        """
        Initialize the synthetic data generator.

        Args:
            api_key: OpenAI API key (if None, reads from OPENAI_API_KEY env var)
            model: LLM model to use for generation
        """
        self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
        self.model = model
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def load_csv(self, csv_path: str) -> pd.DataFrame:
        """
        Load CSV file into a DataFrame.

        Args:
            csv_path: Path to the CSV file

        Returns:
            DataFrame containing the CSV data
        """
        df = pd.read_csv(csv_path)
        # Validate required columns
        required_columns = ['context', 'question', 'answer']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"CSV missing required columns: {missing_columns}")
        print(f"Loaded CSV with {len(df)} rows")
        print(f"Columns: {list(df.columns)}")
        return df

    def split_text(self, text: str) -> List[str]:
        """
        Split text into chunks using RecursiveCharacterTextSplitter.

        Args:
            text: Input text to split

        Returns:
            List of text chunks
        """
        chunks = self.text_splitter.split_text(text)
        return chunks

    def generate_contexts_from_chunk(self, chunk: str, question: str, answer: str,
                                     row_metadata: Dict, chunk_id: str) -> Dict:
        """
        Generate supporting and non-supporting contexts for an existing QnA pair.

        Args:
            chunk: Text chunk to generate contexts from
            question: Original question from CSV
            answer: Original answer from CSV
            row_metadata: Metadata from the CSV row (ticker, filing, etc.)
            chunk_id: Unique ID for this chunk

        Returns:
            Dictionary containing question, answers, and contexts
        """
        prompt = f"""Given the following question, answer, and source text, generate supporting and non-supporting context documents.

Question: {question}
Answer: {answer}

Source Text:
{chunk}

Generate a JSON object with this EXACT structure:
{{
  "supporting_contexts": [
    {{
      "title": "Relevant title",
      "text": "Text snippet that contains information supporting the answer (extract or paraphrase from source)"
    }},
    {{
      "title": "Another relevant title",
      "text": "Another text snippet that supports the answer"
    }},
    {{
      "title": "Third relevant title",
      "text": "Third text snippet that supports the answer"
    }}
  ],
  "non_supporting_contexts": [
    {{
      "title": "Related but different topic title",
      "text": "Text about a related but different topic that doesn't contain the answer"
    }},
    {{
      "title": "Another unrelated title",
      "text": "Another text that doesn't contain the answer"
    }},
    {{
      "title": "Third unrelated title",
      "text": "Third text that doesn't answer the question"
    }},
    {{
      "title": "Fourth unrelated title",
      "text": "Fourth text that doesn't contain the answer"
    }}
  ]
}}

Requirements:
- Generate exactly 3 supporting contexts (has_answer: True) based on the source text
- Generate exactly 4 non-supporting contexts (has_answer: False) that are plausible but don't answer the question
- Supporting contexts should contain relevant information from the source text
- Non-supporting contexts should be realistic but not helpful for answering the question
- All contexts should seem like real document excerpts

Return ONLY the JSON object, no additional text."""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that generates structured synthetic training data for question-answering systems. Always respond with valid JSON only."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            response_text = response.choices[0].message.content.strip()
            # Try to extract JSON if wrapped in markdown code blocks
            if response_text.startswith("```"):
                response_text = response_text.split("```")[1]
                if response_text.startswith("json"):
                    response_text = response_text[4:]
                response_text = response_text.strip()
            data = json.loads(response_text)
            # Format the data to match the desired structure.
            # Use the original question and answer from the CSV.
            formatted_data = {
                "question": question,
                "answers": [answer],
                "ctxs": []
            }
            # Add metadata from the CSV row if available
            if row_metadata:
                formatted_data["metadata"] = row_metadata
            # Add supporting contexts
            for i, ctx in enumerate(data["supporting_contexts"]):
                formatted_data["ctxs"].append({
                    "id": f"{chunk_id}_sup_{i}",
                    "title": ctx["title"],
                    "text": ctx["text"],
                    "score": f"{random.uniform(0.85, 0.98):.2f}",
                    "has_answer": True
                })
            # Add non-supporting contexts
            for i, ctx in enumerate(data["non_supporting_contexts"]):
                formatted_data["ctxs"].append({
                    "id": f"{chunk_id}_nonsup_{i}",
                    "title": ctx["title"],
                    "text": ctx["text"],
                    "score": f"{random.uniform(0.35, 0.65):.2f}",
                    "has_answer": False
                })
            # Shuffle contexts to mix supporting and non-supporting
            random.shuffle(formatted_data["ctxs"])
            return formatted_data
        except Exception as e:
            print(f"Error generating QnA for chunk {chunk_id}: {e}")
            return None

    def generate_synthetic_data_from_csv(self, csv_path: str, output_path: str,
                                         max_rows: int = None,
                                         include_metadata: bool = True) -> List[Dict]:
        """
        Main pipeline: Load CSV → Process each row's context → Split → Generate contexts → Save JSON.

        Args:
            csv_path: Path to input CSV file
            output_path: Path to save output JSON file
            max_rows: Maximum number of rows to process (None for all)
            include_metadata: Whether to include CSV metadata (ticker, filing, etc.) in output

        Returns:
            List of generated QnA data
        """
        print(f"Loading CSV: {csv_path}")
        df = self.load_csv(csv_path)
        if max_rows:
            df = df.head(max_rows)
            print(f"Processing first {max_rows} rows")
        synthetic_data = []
        for idx, row in df.iterrows():
            context_text = str(row['context'])
            original_question = str(row.get('question', ''))
            original_answer = str(row.get('answer', ''))
            # Validate that question and answer exist
            if not original_question or not original_answer:
                print(f"Warning: Row {idx} missing question or answer, skipping...")
                continue
            # Prepare metadata from other columns (excluding question, answer, context)
            metadata = {}
            if include_metadata:
                for col in df.columns:
                    if col not in ['context', 'question', 'answer']:
                        metadata[col] = row[col]
            print(f"\nProcessing row {idx + 1}/{len(df)}")
            print(f"Context length: {len(context_text)} characters")
            # Split the context into chunks
            chunks = self.split_text(context_text)
            print(f"Split into {len(chunks)} chunks")
            # Generate contexts for each chunk
            for chunk_idx, chunk in enumerate(chunks):
                chunk_id = f"row{idx}_chunk{chunk_idx}"
                print(f"  Generating contexts for chunk {chunk_idx + 1}/{len(chunks)}...")
                qna_data = self.generate_contexts_from_chunk(
                    chunk, original_question, original_answer, metadata, chunk_id
                )
                if qna_data:
                    synthetic_data.append(qna_data)
        # Save to JSON file
        print(f"\nSaving synthetic data to: {output_path}")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(synthetic_data, f, indent=2, ensure_ascii=False)
        print(f"Successfully generated {len(synthetic_data)} QnA pairs from {len(df)} rows")
        return synthetic_data


# Example usage
if __name__ == "__main__":
    # Initialize generator
    generator = SyntheticDataGenerator(
        api_key=None,   # Uses OPENAI_API_KEY env var
        model="gpt-4"   # or "gpt-3.5-turbo" for faster/cheaper generation
    )
    # Generate synthetic data from CSV
    csv_path = "your_data.csv"  # Replace with your CSV path
    output_path = "synthetic_data.json"
    # Process CSV and generate data
    synthetic_data = generator.generate_synthetic_data_from_csv(
        csv_path=csv_path,
        output_path=output_path,
        max_rows=5,             # Process first 5 rows for testing; set to None for all rows
        include_metadata=True   # Include ticker, filing, and other columns in output
    )
    # Print sample
    if synthetic_data:
        print("\n" + "="*50)
        print("Sample generated data:")
        print("="*50)
        print(json.dumps(synthetic_data[0], indent=2))
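# --- Optional bridge: a minimal sketch converting "ctxs"-style records produced above
# into the supporting_contexts / non_supporting_contexts layout expected by
# prepare_reranking_dataset_from_json() in the JSON-based reranker script further down.
# The input and output file names are placeholders.
import json

def ctxs_to_context_lists(in_path: str = "synthetic_data.json",
                          out_path: str = "your_training_data.json") -> None:
    with open(in_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    converted = []
    for rec in records:
        converted.append({
            "question": rec["question"],
            "answers": rec["answers"],
            "supporting_contexts": [
                {"title": c["title"], "text": c["text"]} for c in rec["ctxs"] if c["has_answer"]
            ],
            "non_supporting_contexts": [
                {"title": c["title"], "text": c["text"]} for c in rec["ctxs"] if not c["has_answer"]
            ],
        })
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(converted, f, indent=2, ensure_ascii=False)

# ctxs_to_context_lists()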
# !pip install transformers datasets peft trl bitsandbytes accelerate pandas -q
import torch
import pandas as pd
import random
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Mxfp4Config
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTTrainer

# --- 1. Data Preparation with Instruction Tuning ---
# This is an enhancement. Using multiple prompt templates makes the model more robust.
def create_instructional_prompt(query, document, label=None):
    """Creates a formatted prompt from a template for the reranking task."""
    # Truncate the document to a manageable size for training/inference
    max_doc_tokens = 512
    doc_tokens = document.split()
    truncated_doc = " ".join(doc_tokens[:max_doc_tokens])
    templates = [
        f"Query: {query}\n\nDocument: {truncated_doc}\n\nIs this document relevant to the query? Answer:",
        f"Assess the relevance of the following document for the query provided.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nRelevance:",
        f"Given the query, determine if the document is relevant.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nDecision:",
    ]
    # Choose a random template for each example to improve model generalization
    prompt = random.choice(templates)
    if label is not None:
        return f"{prompt} {label}"
    return prompt

def prepare_reranking_dataset_from_json(json_data):
    """
    Processes the proposed rich JSON structure into pointwise training instances.
    This is the ideal approach: a well-structured source format unrolled into
    effective training examples.

    Args:
        json_data (list): A list of dictionaries, each matching the proposed structure.

    Returns:
        datasets.Dataset: A Hugging Face Dataset ready for SFTTrainer.
    """
    training_examples = []
    for item in json_data:
        query = item["question"]
        for ctx in item["ctxs"]:
            document = ctx["text"]
            # The 'has_answer' boolean is our relevance signal.
            label = "Relevant" if ctx["has_answer"] else "Not Relevant"
            # We reuse the robust prompt creation function from before.
            # This unrolls the listwise data into pointwise examples.
            full_text = create_instructional_prompt(query, document, label)
            training_examples.append({"text": full_text})
    return Dataset.from_list(training_examples)

def create_instructional_prompt_structured(query, document, label=None):
    """Creates a formatted prompt from a template for the reranking task."""
    max_doc_tokens = 512
    doc_tokens = document.split()
    truncated_doc = " ".join(doc_tokens[:max_doc_tokens])
    templates = [
        f"Query: {query}\n\nDocument: {truncated_doc}\n\nIs this document relevant to the query? Answer:",
        f"Assess the relevance of the following document for the query provided.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nRelevance:",
        f"Given the query, determine if the document is relevant.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nDecision:",
    ]
    prompt = random.choice(templates)
    if label is not None:
        return f"{prompt} {label}"
    return prompt

# --- Example of your data structure ---
# Assume this is loaded from a JSON file, e.g., with json.load()
source_data = [
    {
        "question": "what is the color of the sky?",
        "answers": ["The sky is blue due to Rayleigh scattering."],
        "ctxs": [
            {
                "id": "doc1",
                "title": "Atmospheric Optics",
                "text": "The sky is blue and vast. This phenomenon is caused by the scattering of sunlight by the Earth's atmosphere.",
                "score": "0.88",
                "has_answer": True
            },
            {
                "id": "doc2",
                "title": "Photosynthesis",
                "text": "Photosynthesis in plants converts light to energy, often using green chlorophyll.",
                "score": "0.45",
                "has_answer": False
            }
        ]
    },
    {
        "question": "tourism in Paris",
        "answers": ["The Eiffel Tower is a famous landmark in Paris."],
        "ctxs": [
            {
                "id": "doc3",
                "title": "Eiffel Tower",
                "text": "Paris, the capital of France, is known for the Eiffel Tower, a popular tourist destination.",
                "score": "0.92",
                "has_answer": True
            }
        ]
    }
]

print("Preparing dataset from new JSON structure...")
train_dataset = prepare_reranking_dataset_from_json(source_data)
print(f"Dataset prepared. Total pointwise examples: {len(train_dataset)}")
print(f"Example data point:\n{train_dataset[0]}")

def prepare_reranking_dataset(queries, corpus, qrels):
    """
    Prepares an instruction-tuned dataset for pointwise reranking.
    """
    query_map = dict(zip(queries['id'], queries['query_text']))
    corpus_map = dict(zip(corpus['id'], corpus['text']))
    training_examples = []
    for _, row in qrels.iterrows():
        query_id, corpus_id, score = row['query_id'], row['corpus_id'], row['score']
        if query_id in query_map and corpus_id in corpus_map:
            query_text = query_map[query_id]
            doc_text = corpus_map[corpus_id]
            label = "Relevant" if score == 1 else "Not Relevant"
            # This pointwise format avoids context length and positional bias issues
            # found in listwise approaches.
            full_text = create_instructional_prompt(query_text, doc_text, label)
            training_examples.append({"text": full_text})
    return Dataset.from_list(training_examples)

# --- Assume your data is loaded into pandas DataFrames ---
corpus_subset = pd.DataFrame({'id': ['doc1', 'doc2', 'doc3'], 'text': ['The sky is blue and vast.', 'Photosynthesis in plants converts light to energy.', 'Paris, the capital of France, is known for the Eiffel Tower.']})
queries_subset = pd.DataFrame({'id': ['q1', 'q2'], 'query_text': ['what is the color of the sky?', 'tourism in Paris']})
qrels_subset = pd.DataFrame({'query_id': ['q1', 'q1', 'q2', 'q2'], 'corpus_id': ['doc1', 'doc2', 'doc1', 'doc3'], 'score': [1, 0, 0, 1]})

print("Preparing dataset...")
full_dataset = prepare_reranking_dataset(queries_subset, corpus_subset, qrels_subset)
train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"]
print(f"Example data point:\n{train_dataset[0]['text']}")

# --- 2. Model and Tokenizer Setup ---
model_name = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({'additional_special_tokens': ['Relevant', 'Not Relevant']})
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

quant_config = Mxfp4Config(dequantize=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, quantization_config=quant_config,
    device_map="auto", use_cache=False
)
model.resize_token_embeddings(len(tokenizer))

peft_config = LoraConfig(
    r=16, lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# --- 3. Training ---
training_args = TrainingArguments(
    output_dir="gpt-oss-20b-reranker-v2",
    per_device_train_batch_size=1, gradient_accumulation_steps=8,
    num_train_epochs=1, learning_rate=1e-5, logging_steps=10,
    bf16=True, save_strategy="epoch", report_to="none"
)
trainer = SFTTrainer(
    model=model, args=training_args, train_dataset=train_dataset,
    tokenizer=tokenizer, dataset_text_field="text", max_seq_length=1024
)
trainer.train()
trainer.save_model()

# --- 4. BATCHED Inference for Reranking ---
# This is the production-ready, highly efficient way to do inference.
adapter_path = "gpt-oss-20b-reranker-v2"
base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()
model.eval()
# Left-pad so the last position of every sequence in a batch is the real final token,
# not a pad token; otherwise the "Relevant" logit would be read from padding.
tokenizer.padding_side = "left"
relevant_token_id = tokenizer.convert_tokens_to_ids("Relevant")

def rerank_documents_batched(query_text, doc_ids, corpus_map, model, tokenizer, batch_size=4):
    """
    Reranks documents in batches for a given query to maximize GPU utilization.
    """
    model.eval()
    scores = {}
    # Prepare all prompts first
    prompts = []
    for doc_id in doc_ids:
        doc_text = corpus_map.get(doc_id, "")
        if not doc_text:
            scores[doc_id] = -float('inf')  # Penalize missing docs heavily
        else:
            # Use the same prompt creation logic as in training, but without the label
            prompts.append({'doc_id': doc_id, 'prompt': create_instructional_prompt(query_text, doc_text)})
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            batch_prompts = [item['prompt'] for item in batch]
            # Tokenize the batch of prompts
            inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(model.device)
            # Get logits
            outputs = model(**inputs)
            logits = outputs.logits
            # Get the logits for the last token of each sequence in the batch
            last_token_logits = logits[:, -1, :]
            # Extract the "Relevant" token's logit for each item in the batch
            relevance_scores = last_token_logits[:, relevant_token_id].cpu().tolist()
            for j, item in enumerate(batch):
                scores[item['doc_id']] = relevance_scores[j]
    # Sort doc IDs based on the collected scores
    reranked_doc_ids = sorted(scores, key=scores.get, reverse=True)
    return [{'doc_id': doc_id, 'score': scores[doc_id]} for doc_id in reranked_doc_ids]
# --- Example Inference Usage ---
top_ranked_docs_subset = pd.DataFrame({'query_id': ['q2'], 'corpus_ids': [['doc1', 'doc2', 'doc3']]})
corpus_map = dict(zip(corpus_subset['id'], corpus_subset['text']))
query_id_to_test = 'q2'
query_text_to_test = queries_subset[queries_subset['id'] == query_id_to_test]['query_text'].iloc[0]
candidate_docs = top_ranked_docs_subset[top_ranked_docs_subset['query_id'] == query_id_to_test]['corpus_ids'].iloc[0]

print("\n--- Batched Inference Example ---")
print(f"Query: '{query_text_to_test}'")
print(f"Initial candidate docs: {candidate_docs}")
reranked_results = rerank_documents_batched(query_text_to_test, candidate_docs, corpus_map, model, tokenizer, batch_size=4)
print("\nReranked documents (higher score is better):")
for result in reranked_results:
    print(f"  Doc ID: {result['doc_id']}, Score: {result['score']:.4f}")
# !pip install transformers datasets peft trl bitsandbytes accelerate pandas -q
import torch
import json
import random
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Mxfp4Config
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTTrainer

# --- 1. Data Preparation with Instruction Tuning ---
def create_instructional_prompt(query, document, label=None):
    """Creates a formatted prompt from a template for the reranking task."""
    # Truncate the document to a manageable size for training/inference
    max_doc_tokens = 512
    doc_tokens = document.split()
    truncated_doc = " ".join(doc_tokens[:max_doc_tokens])
    templates = [
        f"Query: {query}\n\nDocument: {truncated_doc}\n\nIs this document relevant to the query? Answer:",
        f"Assess the relevance of the following document for the query provided.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nRelevance:",
        f"Given the query, determine if the document is relevant.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nDecision:",
    ]
    # Choose a random template for each example to improve model generalization
    prompt = random.choice(templates)
    if label is not None:
        return f"{prompt} {label}"
    return prompt

def prepare_reranking_dataset_from_json(json_path):
    """
    Processes a JSON file with pre-labeled supporting and non-supporting contexts.

    Expected JSON structure (list of objects):
    [
        {
            "question": "A specific question",
            "answers": ["A detailed answer"],
            "supporting_contexts": [
                {"title": "...", "text": "..."},
                ...
            ],
            "non_supporting_contexts": [
                {"title": "...", "text": "..."},
                ...
            ]
        },
        ...
    ]

    Args:
        json_path (str): Path to the JSON file

    Returns:
        datasets.Dataset: A Hugging Face Dataset ready for SFTTrainer.
    """
    # Load the JSON file
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Handle both a single object and a list of objects
    if isinstance(data, dict):
        data = [data]
    training_examples = []
    print(f"Processing {len(data)} questions from JSON...")
    for item in data:
        question = item.get("question", "")
        supporting_contexts = item.get("supporting_contexts", [])
        non_supporting_contexts = item.get("non_supporting_contexts", [])
        if not question.strip():
            continue
        # Create positive examples from supporting contexts
        for ctx in supporting_contexts:
            text = ctx.get("text", "")
            if text.strip():
                full_text = create_instructional_prompt(question, text, label="Relevant")
                training_examples.append({"text": full_text})
        # Create negative examples from non-supporting contexts
        for ctx in non_supporting_contexts:
            text = ctx.get("text", "")
            if text.strip():
                full_text = create_instructional_prompt(question, text, label="Not Relevant")
                training_examples.append({"text": full_text})
    # Shuffle the combined dataset
    random.shuffle(training_examples)
    print(f"Created {len(training_examples)} total training examples")
    # Count positives and negatives
    positives = sum(1 for ex in training_examples if "Relevant" in ex["text"] and "Not Relevant" not in ex["text"])
    negatives = len(training_examples) - positives
    print(f"  - Positive examples (Relevant): {positives}")
    print(f"  - Negative examples (Not Relevant): {negatives}")
    print(f"  - Ratio (Pos:Neg): 1:{negatives / positives:.2f}" if positives > 0 else "")
    return Dataset.from_list(training_examples)
# --- Load and prepare dataset from JSON ---
json_file_path = "your_training_data.json"  # Update this path
print("Preparing dataset from JSON...")
train_dataset = prepare_reranking_dataset_from_json(json_file_path)
print(f"\nDataset prepared. Total examples: {len(train_dataset)}")
print(f"\n--- Example data point 1 ---\n{train_dataset[0]['text']}")
print(f"\n--- Example data point 2 ---\n{train_dataset[1]['text']}")

# Split into train and validation
train_val_split = train_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
print(f"\nTraining examples: {len(train_dataset)}")
print(f"Validation examples: {len(eval_dataset)}")

# --- 2. Model and Tokenizer Setup ---
model_name = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({'additional_special_tokens': ['Relevant', 'Not Relevant']})
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

quant_config = Mxfp4Config(dequantize=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
    device_map="auto",
    use_cache=False
)
model.resize_token_embeddings(len(tokenizer))

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# --- 3. Training ---
training_args = TrainingArguments(
    output_dir="gpt-oss-20b-reranker-json",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=1e-5,
    logging_steps=10,
    bf16=True,
    save_strategy="epoch",
    evaluation_strategy="steps",
    eval_steps=100,
    save_total_limit=2,
    report_to="none"
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=1024
)
trainer.train()
trainer.save_model()

# --- 4. BATCHED Inference for Reranking ---
adapter_path = "gpt-oss-20b-reranker-json"
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()
model.eval()
# Left-pad so the last position of every sequence in a batch is the real final token,
# not a pad token; otherwise the "Relevant" logit would be read from padding.
tokenizer.padding_side = "left"
relevant_token_id = tokenizer.convert_tokens_to_ids("Relevant")

def rerank_documents_batched(query_text, documents, model, tokenizer, batch_size=4):
    """
    Reranks documents in batches for a given query to maximize GPU utilization.

    Args:
        query_text (str): The search query
        documents (list): List of document texts or dicts with a 'text' key to rerank
        model: The fine-tuned model
        tokenizer: The tokenizer
        batch_size (int): Batch size for inference

    Returns:
        list: Documents ranked by relevance score
    """
    model.eval()
    scores = {}
    # Prepare all prompts first
    prompts = []
    for idx, doc in enumerate(documents):
        # Handle both string documents and dict documents
        if isinstance(doc, dict):
            doc_text = doc.get('text', '')
            doc_title = doc.get('title', f'Document {idx + 1}')
        else:
            doc_text = doc
            doc_title = f'Document {idx + 1}'
        if not doc_text or not doc_text.strip():
            scores[idx] = -float('inf')  # Penalize empty docs
        else:
            prompts.append({
                'idx': idx,
                'title': doc_title,
                'text': doc_text,
                'prompt': create_instructional_prompt(query_text, doc_text)
            })
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            batch_prompts = [item['prompt'] for item in batch]
            # Tokenize the batch of prompts
            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(model.device)
            # Get logits
            outputs = model(**inputs)
            logits = outputs.logits
            # Get the logits for the last token of each sequence
            last_token_logits = logits[:, -1, :]
            # Extract the "Relevant" token's logit for each item
            relevance_scores = last_token_logits[:, relevant_token_id].cpu().tolist()
            for j, item in enumerate(batch):
                scores[item['idx']] = relevance_scores[j]
    # Sort documents by score; look prompts up by their original document index so
    # titles and texts stay aligned even when some documents were skipped as empty.
    prompt_by_idx = {item['idx']: item for item in prompts}
    ranked_indices = sorted(scores.keys(), key=lambda x: scores[x], reverse=True)
    return [
        {
            'title': prompt_by_idx[idx]['title'] if idx in prompt_by_idx else f'Document {idx + 1}',
            'text': prompt_by_idx[idx]['text'] if idx in prompt_by_idx else documents[idx],
            'score': scores[idx],
            'rank': rank + 1
        }
        for rank, idx in enumerate(ranked_indices)
    ]
# --- Example Inference Usage ---
print("\n--- Batched Inference Example ---")
# Example: Load test data in the JSON format
test_query = "What area did NVIDIA initially focus on before expanding to other fields?"
test_documents = [
    {
        "title": "NVIDIA Company History",
        "text": "Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields."
    },
    {
        "title": "GPU Deep Learning Applications",
        "text": "Some of the most recent applications of GPU-powered deep learning include recommendation systems, large language models, and generative AI."
    },
    {
        "title": "Consumer Electronics",
        "text": "The company manufactures various consumer electronics products including smartphones and tablets."
    },
    {
        "title": "NVIDIA Origins",
        "text": "NVIDIA initially focused on PC graphics before expanding to AI, data centers, and autonomous vehicles."
    }
]
print(f"Query: '{test_query}'")
print(f"Number of candidate documents: {len(test_documents)}\n")
reranked_results = rerank_documents_batched(
    test_query,
    test_documents,
    model,
    tokenizer,
    batch_size=4
)
print("Reranked documents (higher score is better):")
for result in reranked_results:
    doc_preview = result['text'][:100] + "..." if len(result['text']) > 100 else result['text']
    print(f"\n{'='*80}")
    print(f"Rank {result['rank']}: Score: {result['score']:.4f}")
    print(f"Title: {result['title']}")
    print(f"Text: {doc_preview}")
# --- Additional Helper: Load and test with the actual JSON structure ---
def test_with_json_file(json_path, model, tokenizer, question_index=0, top_k=5):
    """
    Test the reranker with actual data from your JSON file.

    Args:
        json_path: Path to your JSON file
        model: Trained model
        tokenizer: Tokenizer
        question_index: Which question to test (default: the first one)
        top_k: Number of top results to show
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if isinstance(data, dict):
        data = [data]
    if question_index >= len(data):
        print(f"Question index {question_index} out of range. Only {len(data)} questions available.")
        return
    item = data[question_index]
    query = item["question"]
    # Combine all contexts for reranking
    all_contexts = []
    all_contexts.extend([{"text": ctx["text"], "title": ctx["title"], "label": "SUPPORTING"}
                         for ctx in item.get("supporting_contexts", [])])
    all_contexts.extend([{"text": ctx["text"], "title": ctx["title"], "label": "NON-SUPPORTING"}
                         for ctx in item.get("non_supporting_contexts", [])])
    print(f"\n{'='*80}")
    print("TESTING WITH REAL DATA")
    print(f"{'='*80}")
    print(f"Question: {query}")
    print(f"Total contexts to rerank: {len(all_contexts)}")
    print(f"  - Supporting: {len(item.get('supporting_contexts', []))}")
    print(f"  - Non-supporting: {len(item.get('non_supporting_contexts', []))}")
    results = rerank_documents_batched(query, all_contexts, model, tokenizer, batch_size=4)
    # Look up each result's original label by its text so the label shown matches the
    # reranked document rather than the pre-rerank ordering.
    label_by_text = {ctx["text"]: ctx["label"] for ctx in all_contexts}
    print(f"\n{'='*80}")
    print(f"TOP {top_k} RERANKED RESULTS:")
    print(f"{'='*80}")
    for result in results[:top_k]:
        original_label = label_by_text.get(result["text"], "UNKNOWN")
        print(f"\nRank {result['rank']}: Score: {result['score']:.4f} | Original Label: {original_label}")
        print(f"Title: {result['title']}")
        print(f"Text: {result['text'][:150]}...")

# Uncomment to test with your actual JSON file:
# test_with_json_file("your_training_data.json", model, tokenizer, question_index=0, top_k=5)
# !pip install transformers==4.55.0 datasets peft trl bitsandbytes openai-harmony tiktoken accelerate

# Assume data_list is a list of dicts, each with a "messages" field (list of messages as above)
# train_dataset = Dataset.from_list([{"messages": conv["messages"]} for conv in training_conversations])
# val_dataset = Dataset.from_list([{"messages": conv["messages"]} for conv in validation_conversations])

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config

model_name = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)  # needed by the SFTTrainer below
quant_config = Mxfp4Config(dequantize=True)  # dequantize 4-bit weights to float for training
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
    device_map="auto",   # automatically put layers on available GPU(s)
    use_cache=False      # disable cache for training (gradient checkpointing will be used)
)

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear",  # auto-detect linear layers (supported in recent peft versions)
    # target_parameters=[
    #     "7.mlp.experts.gate_up_proj",
    #     "7.mlp.experts.down_proj",
    #     "15.mlp.experts.gate_up_proj",
    #     "15.mlp.experts.down_proj",
    #     "23.mlp.experts.gate_up_proj",
    #     "23.mlp.experts.down_proj",
    # ],
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments

# Define training hyperparameters
training_args = TrainingArguments(
    output_dir="gpt-oss-20b-healthcare-checkpoints",
    per_device_train_batch_size=1,   # batch size 1 per GPU to maximize sequence length usage
    gradient_accumulation_steps=4,   # accumulate to simulate batch size 4 if needed
    num_train_epochs=2,
    learning_rate=2e-4,
    logging_steps=50,
    bf16=True,                       # use bfloat16 mixed precision
    save_strategy="epoch",
    report_to="none"                 # (or "wandb" if using Weights & Biases for logging)
)
# training_args = SFTConfig(
#     learning_rate=2e-4,
#     gradient_checkpointing=True,
#     num_train_epochs=1,
#     logging_steps=10,
#     bf16=True,
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     gradient_accumulation_steps=2,
#     max_length=2048,
#     warmup_ratio=0.03,
#     eval_strategy="steps",
#     eval_steps=10,
#     lr_scheduler_type="cosine_with_min_lr",
#     lr_scheduler_kwargs={"min_lr_rate": 0.1},
#     output_dir=SAVED_MODEL_ID,
#     report_to="tensorboard",
#     push_to_hub=True,
# )
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer
)
trainer.train()
trainer.save_model()  # saves to output_dir
# Inference with the fine-tuned model
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model (optionally back in 4-bit quantized form for inference to save memory)
model_name = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # or "auto"
    device_map="auto"
)
# Load the LoRA adapter
adapter_path = "gpt-oss-20b-healthcare-checkpoints"  # where we saved our LoRA fine-tuned model
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()  # merges LoRA into base_model and returns a standard model

prompt_messages = [
    {"role": "developer", "content": "You are a compassionate doctor AI assistant. Answer medical questions with accurate, helpful information."},
    {"role": "user", "content": "What are the symptoms of diabetes?"}
]
# return_dict=True makes apply_chat_template return a dict of tensors that can be unpacked into generate()
inputs = tokenizer.apply_chat_template(
    prompt_messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=200, temperature=0.7)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(response)
## Streaming the output
from transformers import TextIteratorStreamer
from threading import Thread

# Prepare the streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Generate in a background thread so the main thread is free to consume the stream
gen_kwargs = dict(**inputs, max_new_tokens=200, temperature=0.7, streamer=streamer)
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
print("Assistant:", end=" ", flush=True)
# Read the generated text as it streams
for token_text in streamer:
    print(token_text, end="", flush=True)
thread.join()
print()
@dhruvilp (Author) commented:
ml.p4de.24xlarge - for 20B
ml.p5en.48xlarge - for 120B
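
As a rough sanity check (a back-of-the-envelope sketch, not a measurement; actual usage also depends on sequence length, activations, LoRA rank, and optimizer state):

# bf16 weights alone take ~2 bytes per parameter
print(f"gpt-oss-20b base weights:  ~{20e9 * 2 / 1e9:.0f} GB")    # ~40 GB
print(f"gpt-oss-120b base weights: ~{120e9 * 2 / 1e9:.0f} GB")   # ~240 GB
# ml.p4de.24xlarge provides 8x A100 80 GB (~640 GB total GPU memory), enough headroom for the 20B model;
# ml.p5en.48xlarge provides 8 larger H200-class GPUs, which is why it is suggested for the 120B model.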

dhruvilp commented Oct 12, 2025

Dataset Format:

[
    {
        "question": "...",
        "answers": ["...", "...", ...],
        "ctxs": [
            {
                "id": "...",         // Passage ID from database TSV file
                "title": "...",      // Passage title
                "text": "...",       // Passage full text
                "score": "...",      // Retriever score
                "has_answer": true|false  // Whether the passage contains the answer
            }
        ]
    }
]
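
A minimal loader sketch for this format (field names as in the spec above; the path is a placeholder):

import json

def load_rag_dataset(path: str):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    for item in data:
        assert "question" in item and "answers" in item and "ctxs" in item
        for ctx in item["ctxs"]:
            # has_answer marks whether the passage contains the answer
            assert isinstance(ctx["has_answer"], bool)
    return data

# data = load_rag_dataset("synthetic_data.json")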
