gpt-oss-20b-ft-lora-sample-code
| import json | |
| import pandas as pd | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from openai import OpenAI | |
| import os | |
| from typing import List, Dict | |
| import random | |
| class SyntheticDataGenerator: | |
| def __init__(self, api_key: str = None, model: str = "gpt-4"): | |
| """ | |
| Initialize the synthetic data generator. | |
| Args: | |
| api_key: OpenAI API key (if None, reads from OPENAI_API_KEY env var) | |
| model: LLM model to use for generation | |
| """ | |
| self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY")) | |
| self.model = model | |
| self.text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=1000, | |
| chunk_overlap=200, | |
| length_function=len, | |
| ) | |
| def load_csv(self, csv_path: str) -> pd.DataFrame: | |
| """ | |
| Load CSV file into DataFrame. | |
| Args: | |
| csv_path: Path to the CSV file | |
| Returns: | |
| DataFrame containing the CSV data | |
| """ | |
| df = pd.read_csv(csv_path) | |
| # Validate required columns | |
| required_columns = ['context'] | |
| missing_columns = [col for col in required_columns if col not in df.columns] | |
| if missing_columns: | |
| raise ValueError(f"CSV missing required columns: {missing_columns}") | |
| print(f"Loaded CSV with {len(df)} rows") | |
| print(f"Columns: {list(df.columns)}") | |
| return df | |
| def split_text(self, text: str) -> List[str]: | |
| """ | |
| Split text into chunks using RecursiveCharacterTextSplitter. | |
| Args: | |
| text: Input text to split | |
| Returns: | |
| List of text chunks | |
| """ | |
| chunks = self.text_splitter.split_text(text) | |
| return chunks | |
| def generate_qna_from_chunk(self, chunk: str, row_metadata: Dict, chunk_id: str) -> Dict: | |
| """ | |
| Generate QnA pair with contexts from a text chunk using LLM. | |
| Args: | |
| chunk: Text chunk to generate QnA from | |
| row_metadata: Metadata from the CSV row (ticker, filing, etc.) | |
| chunk_id: Unique ID for this chunk | |
| Returns: | |
| Dictionary containing question, answers, and contexts | |
| """ | |
| prompt = f"""Based on the following text, generate a synthetic question-answer pair with supporting and non-supporting context documents. | |
| Text: | |
| {chunk} | |
| Generate a JSON object with this EXACT structure: | |
| {{ | |
| "question": "A specific question based on the text", | |
| "answers": ["A detailed answer to the question"], | |
| "supporting_contexts": [ | |
| {{ | |
| "title": "Relevant title", | |
| "text": "Text snippet that contains the answer (extract or paraphrase from source)" | |
| }}, | |
| {{ | |
| "title": "Another relevant title", | |
| "text": "Another text snippet that supports the answer" | |
| }}, | |
| {{ | |
| "title": "Third relevant title", | |
| "text": "Third text snippet that supports the answer" | |
| }} | |
| ], | |
| "non_supporting_contexts": [ | |
| {{ | |
| "title": "Related but different topic title", | |
| "text": "Text about a related but different topic that doesn't answer the question" | |
| }}, | |
| {{ | |
| "title": "Another unrelated title", | |
| "text": "Another text that doesn't contain the answer" | |
| }}, | |
| {{ | |
| "title": "Third unrelated title", | |
| "text": "Third text that doesn't answer the question" | |
| }}, | |
| {{ | |
| "title": "Fourth unrelated title", | |
| "text": "Fourth text that doesn't contain the answer" | |
| }} | |
| ] | |
| }} | |
| Requirements: | |
| - Question should be specific and answerable from the text | |
| - Generate exactly 3 supporting contexts (has_answer: True) | |
| - Generate exactly 4 non-supporting contexts (has_answer: False) | |
| - Non-supporting contexts should be plausible but not contain the answer | |
| - All contexts should seem realistic | |
| Return ONLY the JSON object, no additional text.""" | |
| try: | |
| response = self.client.chat.completions.create( | |
| model=self.model, | |
| messages=[ | |
| {"role": "system", "content": "You are a helpful assistant that generates structured synthetic training data for question-answering systems. Always respond with valid JSON only."}, | |
| {"role": "user", "content": prompt} | |
| ], | |
| temperature=0.7, | |
| max_tokens=2000 | |
| ) | |
| response_text = response.choices[0].message.content.strip() | |
| # Try to extract JSON if wrapped in markdown code blocks | |
| if response_text.startswith("```"): | |
| response_text = response_text.split("```")[1] | |
| if response_text.startswith("json"): | |
| response_text = response_text[4:] | |
| response_text = response_text.strip() | |
| data = json.loads(response_text) | |
| # Format the data to match the desired structure | |
| formatted_data = { | |
| "question": data["question"], | |
| "answers": data["answers"], | |
| "ctxs": [] | |
| } | |
| # Add metadata from CSV row if available | |
| if row_metadata: | |
| formatted_data["metadata"] = row_metadata | |
| # Add supporting contexts | |
| for i, ctx in enumerate(data["supporting_contexts"]): | |
| formatted_data["ctxs"].append({ | |
| "id": f"{chunk_id}_sup_{i}", | |
| "title": ctx["title"], | |
| "text": ctx["text"], | |
| "score": f"{random.uniform(0.85, 0.98):.2f}", | |
| "has_answer": True | |
| }) | |
| # Add non-supporting contexts | |
| for i, ctx in enumerate(data["non_supporting_contexts"]): | |
| formatted_data["ctxs"].append({ | |
| "id": f"{chunk_id}_nonsup_{i}", | |
| "title": ctx["title"], | |
| "text": ctx["text"], | |
| "score": f"{random.uniform(0.35, 0.65):.2f}", | |
| "has_answer": False | |
| }) | |
| # Shuffle contexts to mix supporting and non-supporting | |
| random.shuffle(formatted_data["ctxs"]) | |
| return formatted_data | |
| except Exception as e: | |
| print(f"Error generating QnA for chunk {chunk_id}: {e}") | |
| return None | |
| def generate_synthetic_data_from_csv(self, csv_path: str, output_path: str, | |
| max_rows: int = None, | |
| include_metadata: bool = True) -> List[Dict]: | |
| """ | |
| Main pipeline: Load CSV → Process each row's context → Split → Generate QnA → Save JSON. | |
| Args: | |
| csv_path: Path to input CSV file | |
| output_path: Path to save output JSON file | |
| max_rows: Maximum number of rows to process (None for all) | |
| include_metadata: Whether to include CSV metadata (ticker, filing, etc.) in output | |
| Returns: | |
| List of generated QnA data | |
| """ | |
| print(f"Loading CSV: {csv_path}") | |
| df = self.load_csv(csv_path) | |
| if max_rows: | |
| df = df.head(max_rows) | |
| print(f"Processing first {max_rows} rows") | |
| synthetic_data = [] | |
| for idx, row in df.iterrows(): | |
| context_text = str(row['context']) | |
| # Prepare metadata from other columns | |
| metadata = {} | |
| if include_metadata: | |
| for col in df.columns: | |
| if col != 'context': | |
| metadata[col] = row[col] | |
| print(f"\nProcessing row {idx + 1}/{len(df)}") | |
| print(f"Context length: {len(context_text)} characters") | |
| # Split the context into chunks | |
| chunks = self.split_text(context_text) | |
| print(f"Split into {len(chunks)} chunks") | |
| # Generate QnA for each chunk | |
| for chunk_idx, chunk in enumerate(chunks): | |
| chunk_id = f"row{idx}_chunk{chunk_idx}" | |
| print(f" Generating QnA for chunk {chunk_idx + 1}/{len(chunks)}...") | |
| qna_data = self.generate_qna_from_chunk(chunk, metadata, chunk_id) | |
| if qna_data: | |
| synthetic_data.append(qna_data) | |
| # Save to JSON file | |
| print(f"\nSaving synthetic data to: {output_path}") | |
| with open(output_path, 'w', encoding='utf-8') as f: | |
| json.dump(synthetic_data, f, indent=2, ensure_ascii=False) | |
| print(f"Successfully generated {len(synthetic_data)} QnA pairs from {len(df)} rows") | |
| return synthetic_data | |
| # Example usage | |
| if __name__ == "__main__": | |
| # Initialize generator | |
| generator = SyntheticDataGenerator( | |
| api_key=None, # Uses OPENAI_API_KEY env var | |
| model="gpt-4" # or "gpt-3.5-turbo" for faster/cheaper generation | |
| ) | |
| # Generate synthetic data from CSV | |
| csv_path = "your_data.csv" # Replace with your CSV path | |
| output_path = "synthetic_data.json" | |
| # Process CSV and generate data | |
| synthetic_data = generator.generate_synthetic_data_from_csv( | |
| csv_path=csv_path, | |
| output_path=output_path, | |
| max_rows=5, # Process first 5 rows for testing; set to None for all rows | |
| include_metadata=True # Include ticker, filing, and other columns in output | |
| ) | |
| # Print sample | |
| if synthetic_data: | |
| print("\n" + "="*50) | |
| print("Sample generated data:") | |
| print("="*50) | |
| print(json.dumps(synthetic_data[0], indent=2)) |
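For reference, a minimal sketch (not part of the gist) of the CSV layout this first script expects: only a 'context' column is required, and any extra columns such as ticker or filing are carried through as metadata. The values below are made-up examples.

import pandas as pd

sample = pd.DataFrame([{
    "ticker": "NVDA",    # hypothetical metadata column
    "filing": "10-K",    # hypothetical metadata column
    "context": "Since our original focus on PC graphics, we have expanded to several "
               "other large and important computationally intensive fields.",
}])
sample.to_csv("your_data.csv", index=False)  # matches csv_path in the example usage above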
| import json | |
| import pandas as pd | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from openai import OpenAI | |
| import os | |
| from typing import List, Dict | |
| import random | |
| class SyntheticDataGenerator: | |
| def __init__(self, api_key: str = None, model: str = "gpt-4"): | |
| """ | |
| Initialize the synthetic data generator. | |
| Args: | |
| api_key: OpenAI API key (if None, reads from OPENAI_API_KEY env var) | |
| model: LLM model to use for generation | |
| """ | |
| self.client = OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY")) | |
| self.model = model | |
| self.text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=1000, | |
| chunk_overlap=200, | |
| length_function=len, | |
| ) | |
| def load_csv(self, csv_path: str) -> pd.DataFrame: | |
| """ | |
| Load CSV file into DataFrame. | |
| Args: | |
| csv_path: Path to the CSV file | |
| Returns: | |
| DataFrame containing the CSV data | |
| """ | |
| df = pd.read_csv(csv_path) | |
| # Validate required columns | |
| required_columns = ['context', 'question', 'answer'] | |
| missing_columns = [col for col in required_columns if col not in df.columns] | |
| if missing_columns: | |
| raise ValueError(f"CSV missing required columns: {missing_columns}") | |
| print(f"Loaded CSV with {len(df)} rows") | |
| print(f"Columns: {list(df.columns)}") | |
| return df | |
| def split_text(self, text: str) -> List[str]: | |
| """ | |
| Split text into chunks using RecursiveCharacterTextSplitter. | |
| Args: | |
| text: Input text to split | |
| Returns: | |
| List of text chunks | |
| """ | |
| chunks = self.text_splitter.split_text(text) | |
| return chunks | |
| def generate_contexts_from_chunk(self, chunk: str, question: str, answer: str, | |
| row_metadata: Dict, chunk_id: str) -> Dict: | |
| """ | |
| Generate supporting and non-supporting contexts for existing QnA pair. | |
| Args: | |
| chunk: Text chunk to generate contexts from | |
| question: Original question from CSV | |
| answer: Original answer from CSV | |
| row_metadata: Metadata from the CSV row (ticker, filing, etc.) | |
| chunk_id: Unique ID for this chunk | |
| Returns: | |
| Dictionary containing question, answers, and contexts | |
| """ | |
| prompt = f"""Given the following question, answer, and source text, generate supporting and non-supporting context documents. | |
| Question: {question} | |
| Answer: {answer} | |
| Source Text: | |
| {chunk} | |
| Generate a JSON object with this EXACT structure: | |
| {{ | |
| "supporting_contexts": [ | |
| {{ | |
| "title": "Relevant title", | |
| "text": "Text snippet that contains information supporting the answer (extract or paraphrase from source)" | |
| }}, | |
| {{ | |
| "title": "Another relevant title", | |
| "text": "Another text snippet that supports the answer" | |
| }}, | |
| {{ | |
| "title": "Third relevant title", | |
| "text": "Third text snippet that supports the answer" | |
| }} | |
| ], | |
| "non_supporting_contexts": [ | |
| {{ | |
| "title": "Related but different topic title", | |
| "text": "Text about a related but different topic that doesn't contain the answer" | |
| }}, | |
| {{ | |
| "title": "Another unrelated title", | |
| "text": "Another text that doesn't contain the answer" | |
| }}, | |
| {{ | |
| "title": "Third unrelated title", | |
| "text": "Third text that doesn't answer the question" | |
| }}, | |
| {{ | |
| "title": "Fourth unrelated title", | |
| "text": "Fourth text that doesn't contain the answer" | |
| }} | |
| ] | |
| }} | |
| Requirements: | |
| - Generate exactly 3 supporting contexts (has_answer: True) based on the source text | |
| - Generate exactly 4 non-supporting contexts (has_answer: False) that are plausible but don't answer the question | |
| - Supporting contexts should contain relevant information from the source text | |
| - Non-supporting contexts should be realistic but not helpful for answering the question | |
| - All contexts should seem like real document excerpts | |
| Return ONLY the JSON object, no additional text.""" | |
| try: | |
| response = self.client.chat.completions.create( | |
| model=self.model, | |
| messages=[ | |
| {"role": "system", "content": "You are a helpful assistant that generates structured synthetic training data for question-answering systems. Always respond with valid JSON only."}, | |
| {"role": "user", "content": prompt} | |
| ], | |
| temperature=0.7, | |
| max_tokens=2000 | |
| ) | |
| response_text = response.choices[0].message.content.strip() | |
| # Try to extract JSON if wrapped in markdown code blocks | |
| if response_text.startswith("```"): | |
| response_text = response_text.split("```")[1] | |
| if response_text.startswith("json"): | |
| response_text = response_text[4:] | |
| response_text = response_text.strip() | |
| data = json.loads(response_text) | |
| # Format the data to match the desired structure | |
| # Use original question and answer from CSV | |
| formatted_data = { | |
| "question": question, | |
| "answers": [answer], | |
| "ctxs": [] | |
| } | |
| # Add metadata from CSV row if available | |
| if row_metadata: | |
| formatted_data["metadata"] = row_metadata | |
| # Add supporting contexts | |
| for i, ctx in enumerate(data["supporting_contexts"]): | |
| formatted_data["ctxs"].append({ | |
| "id": f"{chunk_id}_sup_{i}", | |
| "title": ctx["title"], | |
| "text": ctx["text"], | |
| "score": f"{random.uniform(0.85, 0.98):.2f}", | |
| "has_answer": True | |
| }) | |
| # Add non-supporting contexts | |
| for i, ctx in enumerate(data["non_supporting_contexts"]): | |
| formatted_data["ctxs"].append({ | |
| "id": f"{chunk_id}_nonsup_{i}", | |
| "title": ctx["title"], | |
| "text": ctx["text"], | |
| "score": f"{random.uniform(0.35, 0.65):.2f}", | |
| "has_answer": False | |
| }) | |
| # Shuffle contexts to mix supporting and non-supporting | |
| random.shuffle(formatted_data["ctxs"]) | |
| return formatted_data | |
| except Exception as e: | |
| print(f"Error generating QnA for chunk {chunk_id}: {e}") | |
| return None | |
| def generate_synthetic_data_from_csv(self, csv_path: str, output_path: str, | |
| max_rows: int = None, | |
| include_metadata: bool = True) -> List[Dict]: | |
| """ | |
| Main pipeline: Load CSV → Process each row's context → Split → Generate QnA → Save JSON. | |
| Args: | |
| csv_path: Path to input CSV file | |
| output_path: Path to save output JSON file | |
| max_rows: Maximum number of rows to process (None for all) | |
| include_metadata: Whether to include CSV metadata (ticker, filing, etc.) in output | |
| Returns: | |
| List of generated QnA data | |
| """ | |
| print(f"Loading CSV: {csv_path}") | |
| df = self.load_csv(csv_path) | |
| if max_rows: | |
| df = df.head(max_rows) | |
| print(f"Processing first {max_rows} rows") | |
| synthetic_data = [] | |
| for idx, row in df.iterrows(): | |
| context_text = str(row['context']) | |
| original_question = str(row.get('question', '')) | |
| original_answer = str(row.get('answer', '')) | |
| # Validate that question and answer exist | |
| if not original_question or not original_answer: | |
| print(f"Warning: Row {idx} missing question or answer, skipping...") | |
| continue | |
| # Prepare metadata from other columns (excluding question, answer, context) | |
| metadata = {} | |
| if include_metadata: | |
| for col in df.columns: | |
| if col not in ['context', 'question', 'answer']: | |
| metadata[col] = row[col] | |
| print(f"\nProcessing row {idx + 1}/{len(df)}") | |
| print(f"Context length: {len(context_text)} characters") | |
| # Split the context into chunks | |
| chunks = self.split_text(context_text) | |
| print(f"Split into {len(chunks)} chunks") | |
| # Generate contexts for each chunk | |
| for chunk_idx, chunk in enumerate(chunks): | |
| chunk_id = f"row{idx}_chunk{chunk_idx}" | |
| print(f" Generating contexts for chunk {chunk_idx + 1}/{len(chunks)}...") | |
| qna_data = self.generate_contexts_from_chunk( | |
| chunk, original_question, original_answer, metadata, chunk_id | |
| ) | |
| if qna_data: | |
| synthetic_data.append(qna_data) | |
| # Save to JSON file | |
| print(f"\nSaving synthetic data to: {output_path}") | |
| with open(output_path, 'w', encoding='utf-8') as f: | |
| json.dump(synthetic_data, f, indent=2, ensure_ascii=False) | |
| print(f"Successfully generated {len(synthetic_data)} QnA pairs from {len(df)} rows") | |
| return synthetic_data | |
| # Example usage | |
| if __name__ == "__main__": | |
| # Initialize generator | |
| generator = SyntheticDataGenerator( | |
| api_key=None, # Uses OPENAI_API_KEY env var | |
| model="gpt-4" # or "gpt-3.5-turbo" for faster/cheaper generation | |
| ) | |
| # Generate synthetic data from CSV | |
| csv_path = "your_data.csv" # Replace with your CSV path | |
| output_path = "synthetic_data.json" | |
| # Process CSV and generate data | |
| synthetic_data = generator.generate_synthetic_data_from_csv( | |
| csv_path=csv_path, | |
| output_path=output_path, | |
| max_rows=5, # Process first 5 rows for testing; set to None for all rows | |
| include_metadata=True # Include ticker, filing, and other columns in output | |
| ) | |
| # Print sample | |
| if synthetic_data: | |
| print("\n" + "="*50) | |
| print("Sample generated data:") | |
| print("="*50) | |
| print(json.dumps(synthetic_data[0], indent=2)) |
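This second variant additionally requires existing 'question' and 'answer' columns. A minimal, made-up CSV sketch (not part of the gist):

import pandas as pd

sample = pd.DataFrame([{
    "ticker": "NVDA",    # hypothetical metadata column
    "question": "What area did NVIDIA initially focus on?",
    "answer": "NVIDIA initially focused on PC graphics.",
    "context": "Since our original focus on PC graphics, we have expanded to several "
               "other large and important computationally intensive fields.",
}])
sample.to_csv("your_data.csv", index=False)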
| # !pip install transformers datasets peft trl bitsandbytes accelerate pandas -q | |
| import torch | |
| import pandas as pd | |
| import random | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Mxfp4Config | |
| from peft import LoraConfig, get_peft_model, PeftModel | |
| from trl import SFTTrainer | |
| # --- 1. Data Preparation with Instruction Tuning --- | |
| # This is an enhancement. Using multiple prompt templates makes the model more robust. | |
| def create_instructional_prompt(query, document, label=None): | |
| """Creates a formatted prompt from a template for the reranking task.""" | |
| # Truncate document to a manageable size for training/inference | |
| max_doc_tokens = 512 | |
| doc_tokens = document.split() | |
| truncated_doc = " ".join(doc_tokens[:max_doc_tokens]) | |
| templates = [ | |
| f"Query: {query}\n\nDocument: {truncated_doc}\n\nIs this document relevant to the query? Answer:", | |
| f"Assess the relevance of the following document for the query provided.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nRelevance:", | |
| f"Given the query, determine if the document is relevant.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nDecision:", | |
| ] | |
| # Choose a random template for each example to improve model generalization | |
| prompt = random.choice(templates) | |
| if label is not None: | |
| return f"{prompt} {label}" | |
| return prompt | |
| def prepare_reranking_dataset_from_json(json_data): | |
| """ | |
| Processes the proposed rich JSON structure into pointwise training instances. | |
| This is the ideal approach: a well-structured source format unrolled into | |
| effective training examples. | |
| Args: | |
| json_data (list): A list of dictionaries, with each dict matching your proposed structure. | |
| Returns: | |
| datasets.Dataset: A Hugging Face Dataset ready for SFTTrainer. | |
| """ | |
| training_examples = [] | |
| for item in json_data: | |
| query = item["question"] | |
| for ctx in item["ctxs"]: | |
| document = ctx["text"] | |
| # The 'has_answer' boolean is our relevance signal. | |
| label = "Relevant" if ctx["has_answer"] else "Not Relevant" | |
| # We reuse the robust prompt creation function from before. | |
| # This unrolls the listwise data into pointwise examples. | |
| full_text = create_instructional_prompt(query, document, label) | |
| training_examples.append({"text": full_text}) | |
| return Dataset.from_list(training_examples) | |
| def create_instructional_prompt_structured(query, document, label=None): | |
| """Creates a formatted prompt from a template for the reranking task.""" | |
| max_doc_tokens = 512 | |
| doc_tokens = document.split() | |
| truncated_doc = " ".join(doc_tokens[:max_doc_tokens]) | |
| templates = [ | |
| f"Query: {query}\n\nDocument: {truncated_doc}\n\nIs this document relevant to the query? Answer:", | |
| f"Assess the relevance of the following document for the query provided.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nRelevance:", | |
| f"Given the query, determine if the document is relevant.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nDecision:", | |
| ] | |
| prompt = random.choice(templates) | |
| if label is not None: | |
| return f"{prompt} {label}" | |
| return prompt | |
| # --- Example of your data structure --- | |
| # Assume this is loaded from a JSON file, e.g., with json.load() | |
| source_data = [ | |
| { | |
| "question": "what is the color of the sky?", | |
| "answers": ["The sky is blue due to Rayleigh scattering."], | |
| "ctxs": [ | |
| { | |
| "id": "doc1", | |
| "title": "Atmospheric Optics", | |
| "text": "The sky is blue and vast. This phenomenon is caused by the scattering of sunlight by the Earth's atmosphere.", | |
| "score": "0.88", | |
| "has_answer": True | |
| }, | |
| { | |
| "id": "doc2", | |
| "title": "Photosynthesis", | |
| "text": "Photosynthesis in plants converts light to energy, often using green chlorophyll.", | |
| "score": "0.45", | |
| "has_answer": False | |
| } | |
| ] | |
| }, | |
| { | |
| "question": "tourism in Paris", | |
| "answers": ["The Eiffel Tower is a famous landmark in Paris."], | |
| "ctxs": [ | |
| { | |
| "id": "doc3", | |
| "title": "Eiffel Tower", | |
| "text": "Paris, the capital of France, is known for the Eiffel Tower, a popular tourist destination.", | |
| "score": "0.92", | |
| "has_answer": True | |
| } | |
| ] | |
| } | |
| ] | |
| print("Preparing dataset from new JSON structure...") | |
| train_dataset = prepare_reranking_dataset_from_json(source_data) | |
| print(f"Dataset prepared. Total pointwise examples: {len(train_dataset)}") | |
| print(f"Example data point:\n{train_dataset[0]}") | |
| def prepare_reranking_dataset(queries, corpus, qrels): | |
| """ | |
| Prepares an instruction-tuned dataset for pointwise reranking. | |
| """ | |
| query_map = dict(zip(queries['id'], queries['query_text'])) | |
| corpus_map = dict(zip(corpus['id'], corpus['text'])) | |
| training_examples = [] | |
| for _, row in qrels.iterrows(): | |
| query_id, corpus_id, score = row['query_id'], row['corpus_id'], row['score'] | |
| if query_id in query_map and corpus_id in corpus_map: | |
| query_text = query_map[query_id] | |
| doc_text = corpus_map[corpus_id] | |
| label = "Relevant" if score == 1 else "Not Relevant" | |
| # This pointwise format avoids context length and positional bias issues | |
| # found in listwise approaches. | |
| full_text = create_instructional_prompt(query_text, doc_text, label) | |
| training_examples.append({"text": full_text}) | |
| return Dataset.from_list(training_examples) | |
| # --- Assume your data is loaded into pandas DataFrames --- | |
| corpus_subset = pd.DataFrame({'id': ['doc1', 'doc2', 'doc3'], 'text': ['The sky is blue and vast.', 'Photosynthesis in plants converts light to energy.', 'Paris, the capital of France, is known for the Eiffel Tower.']}) | |
| queries_subset = pd.DataFrame({'id': ['q1', 'q2'], 'query_text': ['what is the color of the sky?', 'tourism in Paris']}) | |
| qrels_subset = pd.DataFrame({'query_id': ['q1', 'q1', 'q2', 'q2'], 'corpus_id': ['doc1', 'doc2', 'doc1', 'doc3'], 'score': [1, 0, 0, 1]}) | |
| print("Preparing dataset...") | |
| full_dataset = prepare_reranking_dataset(queries_subset, corpus_subset, qrels_subset) | |
| train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"] | |
| print(f"Example data point:\n{train_dataset[0]['text']}") | |
| # --- 2. Model and Tokenizer Setup --- | |
| model_name = "openai/gpt-oss-20b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| tokenizer.add_special_tokens({'additional_special_tokens': ['Relevant', 'Not Relevant']}) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| quant_config = Mxfp4Config(dequantize=True) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, torch_dtype=torch.bfloat16, quantization_config=quant_config, | |
| device_map="auto", use_cache=False | |
| ) | |
| model.resize_token_embeddings(len(tokenizer)) | |
| peft_config = LoraConfig( | |
| r=16, lora_alpha=32, | |
| target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], | |
| bias="none", task_type="CAUSAL_LM" | |
| ) | |
| model = get_peft_model(model, peft_config) | |
| model.print_trainable_parameters() | |
| # --- 3. Training --- | |
| training_args = TrainingArguments( | |
| output_dir="gpt-oss-20b-reranker-v2", | |
| per_device_train_batch_size=1, gradient_accumulation_steps=8, | |
| num_train_epochs=1, learning_rate=1e-5, logging_steps=10, | |
| bf16=True, save_strategy="epoch", report_to="none" | |
| ) | |
| trainer = SFTTrainer( | |
| model=model, args=training_args, train_dataset=train_dataset, | |
| tokenizer=tokenizer, dataset_text_field="text", max_seq_length=1024 | |
| ) | |
| trainer.train() | |
| trainer.save_model() | |
| # --- 4. BATCHED Inference for Reranking --- | |
| # This is the production-ready, highly efficient way to do inference. | |
| adapter_path = "gpt-oss-20b-reranker-v2" | |
| base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto") | |
| base_model.resize_token_embeddings(len(tokenizer)) # match the vocabulary size used during training (special tokens were added) | |
| model = PeftModel.from_pretrained(base_model, adapter_path) | |
| model = model.merge_and_unload() | |
| model.eval() | |
| relevant_token_id = tokenizer.convert_tokens_to_ids("Relevant") | |
| tokenizer.padding_side = "left" # left-pad batches so logits[:, -1, :] is the last real token of every sequence | |
| def rerank_documents_batched(query_text, doc_ids, corpus_map, model, tokenizer, batch_size=4): | |
| """ | |
| Reranks documents in batches for a given query to maximize GPU utilization. | |
| """ | |
| model.eval() | |
| scores = {} | |
| # Prepare all prompts first | |
| prompts = [] | |
| for doc_id in doc_ids: | |
| doc_text = corpus_map.get(doc_id, "") | |
| if not doc_text: | |
| scores[doc_id] = -float('inf') # Penalize missing docs heavily | |
| else: | |
| # Use the same prompt creation logic as in training, but without the label | |
| prompts.append({'doc_id': doc_id, 'prompt': create_instructional_prompt(query_text, doc_text)}) | |
| with torch.no_grad(): | |
| for i in range(0, len(prompts), batch_size): | |
| batch = prompts[i:i+batch_size] | |
| batch_prompts = [item['prompt'] for item in batch] | |
| # Tokenize the batch of prompts | |
| inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(model.device) | |
| # Get logits | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| # Get the logits for the last token of each sequence in the batch | |
| last_token_logits = logits[:, -1, :] | |
| # Extract the "Relevant" token's logit for each item in the batch | |
| relevance_scores = last_token_logits[:, relevant_token_id].cpu().tolist() | |
| for j, item in enumerate(batch): | |
| scores[item['doc_id']] = relevance_scores[j] | |
| # Sort doc IDs based on the collected scores | |
| reranked_doc_ids = sorted(scores, key=scores.get, reverse=True) | |
| return [{'doc_id': doc_id, 'score': scores[doc_id]} for doc_id in reranked_doc_ids] | |
| # --- Example Inference Usage --- | |
| top_ranked_docs_subset = pd.DataFrame({'query_id': ['q2'], 'corpus_ids': [['doc1', 'doc2', 'doc3']]}) | |
| corpus_map = dict(zip(corpus_subset['id'], corpus_subset['text'])) | |
| query_id_to_test = 'q2' | |
| query_text_to_test = queries_subset[queries_subset['id'] == query_id_to_test]['query_text'].iloc[0] | |
| candidate_docs = top_ranked_docs_subset[top_ranked_docs_subset['query_id'] == query_id_to_test]['corpus_ids'].iloc[0] | |
| print("\n--- Batched Inference Example ---") | |
| print(f"Query: '{query_text_to_test}'") | |
| print(f"Initial candidate docs: {candidate_docs}") | |
| reranked_results = rerank_documents_batched(query_text_to_test, candidate_docs, corpus_map, model, tokenizer, batch_size=4) | |
| print("\nReranked documents (higher score is better):") | |
| for result in reranked_results: | |
| print(f" Doc ID: {result['doc_id']}, Score: {result['score']:.4f}") | |
| # !pip install transformers datasets peft trl bitsandbytes accelerate pandas -q | |
| import torch | |
| import json | |
| import random | |
| from datasets import Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Mxfp4Config | |
| from peft import LoraConfig, get_peft_model, PeftModel | |
| from trl import SFTTrainer | |
| # --- 1. Data Preparation with Instruction Tuning --- | |
| def create_instructional_prompt(query, document, label=None): | |
| """Creates a formatted prompt from a template for the reranking task.""" | |
| # Truncate document to a manageable size for training/inference | |
| max_doc_tokens = 512 | |
| doc_tokens = document.split() | |
| truncated_doc = " ".join(doc_tokens[:max_doc_tokens]) | |
| templates = [ | |
| f"Query: {query}\n\nDocument: {truncated_doc}\n\nIs this document relevant to the query? Answer:", | |
| f"Assess the relevance of the following document for the query provided.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nRelevance:", | |
| f"Given the query, determine if the document is relevant.\n\nQuery: {query}\n\nDocument: {truncated_doc}\n\nDecision:", | |
| ] | |
| # Choose a random template for each example to improve model generalization | |
| prompt = random.choice(templates) | |
| if label is not None: | |
| return f"{prompt} {label}" | |
| return prompt | |
| def prepare_reranking_dataset_from_json(json_path): | |
| """ | |
| Processes JSON file with pre-labeled supporting and non-supporting contexts. | |
| Expected JSON structure (list of objects): | |
| [ | |
| { | |
| "question": "A specific question", | |
| "answers": ["A detailed answer"], | |
| "supporting_contexts": [ | |
| {"title": "...", "text": "..."}, | |
| ... | |
| ], | |
| "non_supporting_contexts": [ | |
| {"title": "...", "text": "..."}, | |
| ... | |
| ] | |
| }, | |
| ... | |
| ] | |
| Args: | |
| json_path (str): Path to the JSON file | |
| Returns: | |
| datasets.Dataset: A Hugging Face Dataset ready for SFTTrainer. | |
| """ | |
| # Load the JSON file | |
| with open(json_path, 'r', encoding='utf-8') as f: | |
| data = json.load(f) | |
| # Handle both single object and list of objects | |
| if isinstance(data, dict): | |
| data = [data] | |
| training_examples = [] | |
| print(f"Processing {len(data)} questions from JSON...") | |
| for item in data: | |
| question = item.get("question", "") | |
| supporting_contexts = item.get("supporting_contexts", []) | |
| non_supporting_contexts = item.get("non_supporting_contexts", []) | |
| if not question.strip(): | |
| continue | |
| # Create positive examples from supporting contexts | |
| for ctx in supporting_contexts: | |
| text = ctx.get("text", "") | |
| if text.strip(): | |
| full_text = create_instructional_prompt(question, text, label="Relevant") | |
| training_examples.append({"text": full_text}) | |
| # Create negative examples from non-supporting contexts | |
| for ctx in non_supporting_contexts: | |
| text = ctx.get("text", "") | |
| if text.strip(): | |
| full_text = create_instructional_prompt(question, text, label="Not Relevant") | |
| training_examples.append({"text": full_text}) | |
| # Shuffle the combined dataset | |
| random.shuffle(training_examples) | |
| print(f"Created {len(training_examples)} total training examples") | |
| # Count positives and negatives by the appended label | |
| negatives = sum(1 for ex in training_examples if ex["text"].endswith(" Not Relevant")) | |
| positives = len(training_examples) - negatives | |
| print(f" - Positive examples (Relevant): {positives}") | |
| print(f" - Negative examples (Not Relevant): {negatives}") | |
| if positives > 0: | |
| print(f" - Ratio (Pos:Neg): 1:{negatives/positives:.2f}") | |
| return Dataset.from_list(training_examples) | |
| # --- Load and prepare dataset from JSON --- | |
| json_file_path = "your_training_data.json" # Update this path | |
| print("Preparing dataset from JSON...") | |
| train_dataset = prepare_reranking_dataset_from_json(json_file_path) | |
| print(f"\nDataset prepared. Total examples: {len(train_dataset)}") | |
| print(f"\n--- Example data point 1 ---\n{train_dataset[0]['text']}") | |
| print(f"\n--- Example data point 2 ---\n{train_dataset[1]['text']}") | |
| # Split into train and validation | |
| train_val_split = train_dataset.train_test_split(test_size=0.1, seed=42) | |
| train_dataset = train_val_split["train"] | |
| eval_dataset = train_val_split["test"] | |
| print(f"\nTraining examples: {len(train_dataset)}") | |
| print(f"Validation examples: {len(eval_dataset)}") | |
| # --- 2. Model and Tokenizer Setup --- | |
| model_name = "openai/gpt-oss-20b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| tokenizer.add_special_tokens({'additional_special_tokens': ['Relevant', 'Not Relevant']}) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| quant_config = Mxfp4Config(dequantize=True) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.bfloat16, | |
| quantization_config=quant_config, | |
| device_map="auto", | |
| use_cache=False | |
| ) | |
| model.resize_token_embeddings(len(tokenizer)) | |
| peft_config = LoraConfig( | |
| r=16, | |
| lora_alpha=32, | |
| target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], | |
| bias="none", | |
| task_type="CAUSAL_LM" | |
| ) | |
| model = get_peft_model(model, peft_config) | |
| model.print_trainable_parameters() | |
| # --- 3. Training --- | |
| training_args = TrainingArguments( | |
| output_dir="gpt-oss-20b-reranker-json", | |
| per_device_train_batch_size=1, | |
| gradient_accumulation_steps=8, | |
| num_train_epochs=1, | |
| learning_rate=1e-5, | |
| logging_steps=10, | |
| bf16=True, | |
| save_strategy="epoch", | |
| eval_strategy="steps", # "evaluation_strategy" was renamed in recent transformers versions | |
| eval_steps=100, | |
| save_total_limit=2, | |
| report_to="none" | |
| ) | |
| trainer = SFTTrainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, | |
| tokenizer=tokenizer, | |
| dataset_text_field="text", | |
| max_seq_length=1024 | |
| ) | |
| trainer.train() | |
| trainer.save_model() | |
| # --- 4. BATCHED Inference for Reranking --- | |
| adapter_path = "gpt-oss-20b-reranker-json" | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto" | |
| ) | |
| base_model.resize_token_embeddings(len(tokenizer)) # match the vocabulary size used during training (special tokens were added) | |
| model = PeftModel.from_pretrained(base_model, adapter_path) | |
| model = model.merge_and_unload() | |
| model.eval() | |
| relevant_token_id = tokenizer.convert_tokens_to_ids("Relevant") | |
| tokenizer.padding_side = "left" # left-pad batches so logits[:, -1, :] is the last real token of every sequence | |
| def rerank_documents_batched(query_text, documents, model, tokenizer, batch_size=4): | |
| """ | |
| Reranks documents in batches for a given query to maximize GPU utilization. | |
| Args: | |
| query_text (str): The search query | |
| documents (list): List of document texts or dicts with 'text' key to rerank | |
| model: The fine-tuned model | |
| tokenizer: The tokenizer | |
| batch_size (int): Batch size for inference | |
| Returns: | |
| list: Documents ranked by relevance score | |
| """ | |
| model.eval() | |
| scores = {} | |
| # Prepare all prompts first | |
| prompts = [] | |
| for idx, doc in enumerate(documents): | |
| # Handle both string documents and dict documents | |
| if isinstance(doc, dict): | |
| doc_text = doc.get('text', '') | |
| doc_title = doc.get('title', f'Document {idx+1}') | |
| else: | |
| doc_text = doc | |
| doc_title = f'Document {idx+1}' | |
| if not doc_text or not doc_text.strip(): | |
| scores[idx] = -float('inf') # Penalize empty docs | |
| else: | |
| prompts.append({ | |
| 'idx': idx, | |
| 'title': doc_title, | |
| 'text': doc_text, | |
| 'prompt': create_instructional_prompt(query_text, doc_text) | |
| }) | |
| with torch.no_grad(): | |
| for i in range(0, len(prompts), batch_size): | |
| batch = prompts[i:i+batch_size] | |
| batch_prompts = [item['prompt'] for item in batch] | |
| # Tokenize the batch of prompts | |
| inputs = tokenizer( | |
| batch_prompts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=1024 | |
| ).to(model.device) | |
| # Get logits | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| # Get the logits for the last token of each sequence | |
| last_token_logits = logits[:, -1, :] | |
| # Extract the "Relevant" token's logit for each item | |
| relevance_scores = last_token_logits[:, relevant_token_id].cpu().tolist() | |
| for j, item in enumerate(batch): | |
| scores[item['idx']] = relevance_scores[j] | |
| # Sort documents by score | |
| ranked_indices = sorted(scores.keys(), key=lambda x: scores[x], reverse=True) | |
| # Look up prompt entries by original document index; positions in `prompts` shift when empty docs are skipped | |
| prompt_by_idx = {item['idx']: item for item in prompts} | |
| return [ | |
| { | |
| 'title': prompt_by_idx[idx]['title'] if idx in prompt_by_idx else f'Document {idx+1}', | |
| 'text': prompt_by_idx[idx]['text'] if idx in prompt_by_idx else '', | |
| 'score': scores[idx], | |
| 'rank': rank + 1 | |
| } | |
| for rank, idx in enumerate(ranked_indices) | |
| ] | |
| # --- Example Inference Usage --- | |
| print("\n--- Batched Inference Example ---") | |
| # Example: Load test data from JSON format | |
| test_query = "What area did NVIDIA initially focus on before expanding to other fields?" | |
| test_documents = [ | |
| { | |
| "title": "NVIDIA Company History", | |
| "text": "Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields." | |
| }, | |
| { | |
| "title": "GPU Deep Learning Applications", | |
| "text": "Some of the most recent applications of GPU-powered deep learning include recommendation systems, large language models, and generative AI." | |
| }, | |
| { | |
| "title": "Consumer Electronics", | |
| "text": "The company manufactures various consumer electronics products including smartphones and tablets." | |
| }, | |
| { | |
| "title": "NVIDIA Origins", | |
| "text": "NVIDIA initially focused on PC graphics before expanding to AI, data centers, and autonomous vehicles." | |
| } | |
| ] | |
| print(f"Query: '{test_query}'") | |
| print(f"Number of candidate documents: {len(test_documents)}\n") | |
| reranked_results = rerank_documents_batched( | |
| test_query, | |
| test_documents, | |
| model, | |
| tokenizer, | |
| batch_size=4 | |
| ) | |
| print("Reranked documents (higher score is better):") | |
| for result in reranked_results: | |
| doc_preview = result['text'][:100] + "..." if len(result['text']) > 100 else result['text'] | |
| print(f"\n{'='*80}") | |
| print(f"Rank {result['rank']}: Score: {result['score']:.4f}") | |
| print(f"Title: {result['title']}") | |
| print(f"Text: {doc_preview}") | |
| # --- Additional Helper: Load and test with actual JSON structure --- | |
| def test_with_json_file(json_path, model, tokenizer, question_index=0, top_k=5): | |
| """ | |
| Test the reranker with actual data from your JSON file. | |
| Args: | |
| json_path: Path to your JSON file | |
| model: Trained model | |
| tokenizer: Tokenizer | |
| question_index: Which question to test (default first one) | |
| top_k: Number of top results to show | |
| """ | |
| with open(json_path, 'r', encoding='utf-8') as f: | |
| data = json.load(f) | |
| if isinstance(data, dict): | |
| data = [data] | |
| if question_index >= len(data): | |
| print(f"Question index {question_index} out of range. Only {len(data)} questions available.") | |
| return | |
| item = data[question_index] | |
| query = item["question"] | |
| # Combine all contexts for reranking | |
| all_contexts = [] | |
| all_contexts.extend([{"text": ctx["text"], "title": ctx["title"], "label": "SUPPORTING"} | |
| for ctx in item.get("supporting_contexts", [])]) | |
| all_contexts.extend([{"text": ctx["text"], "title": ctx["title"], "label": "NON-SUPPORTING"} | |
| for ctx in item.get("non_supporting_contexts", [])]) | |
| print(f"\n{'='*80}") | |
| print(f"TESTING WITH REAL DATA") | |
| print(f"{'='*80}") | |
| print(f"Question: {query}") | |
| print(f"Total contexts to rerank: {len(all_contexts)}") | |
| print(f" - Supporting: {len(item.get('supporting_contexts', []))}") | |
| print(f" - Non-supporting: {len(item.get('non_supporting_contexts', []))}") | |
| results = rerank_documents_batched(query, all_contexts, model, tokenizer, batch_size=4) | |
| print(f"\n{'='*80}") | |
| print(f"TOP {top_k} RERANKED RESULTS:") | |
| print(f"{'='*80}") | |
| # Map each context's text back to its original label so it stays correct after reranking | |
| label_by_text = {ctx["text"]: ctx["label"] for ctx in all_contexts} | |
| for result in results[:top_k]: | |
| original_label = label_by_text.get(result['text'], "UNKNOWN") | |
| print(f"\nRank {result['rank']}: Score: {result['score']:.4f} | Original Label: {original_label}") | |
| print(f"Title: {result['title']}") | |
| print(f"Text: {result['text'][:150]}...") | |
| # Uncomment to test with your actual JSON file: | |
| # test_with_json_file("your_training_data.json", model, tokenizer, question_index=0, top_k=5) |
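A rough evaluation sketch (not in the original gist): run the reranker over every question in the JSON file and count how often the top-ranked context is one of the labeled supporting contexts. It only uses functions and imports already defined above; the file path is a placeholder.

def evaluate_top1(json_path, model, tokenizer):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if isinstance(data, dict):
        data = [data]
    hits, total = 0, 0
    for item in data:
        supporting = item.get("supporting_contexts", [])
        non_supporting = item.get("non_supporting_contexts", [])
        if not supporting or not non_supporting:
            continue
        contexts = supporting + non_supporting
        supporting_texts = {ctx["text"] for ctx in supporting}
        results = rerank_documents_batched(item["question"], contexts, model, tokenizer)
        if results and results[0]["text"] in supporting_texts:
            hits += 1
        total += 1
    if total:
        print(f"Top-1 supporting-context rate: {hits / total:.2%} over {total} questions")

# evaluate_top1("your_training_data.json", model, tokenizer)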
| # !pip install transformers==4.55.0 datasets peft trl bitsandbytes openai-harmony tiktoken accelerate | |
| # Assume data_list is a list of dicts each with a "messages" field (list of messages as above) | |
| # train_dataset = Dataset.from_list([ {"messages": conv["messages"]} for conv in training_conversations ]) | |
| # val_dataset = Dataset.from_list([ {"messages": conv["messages"]} for conv in validation_conversations ]) | |
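# Illustrative sketch only (an assumption about the expected "messages" format,
# matching the chat roles used at inference further below):
# training_conversations = [
#     {
#         "messages": [
#             {"role": "developer", "content": "You are a compassionate doctor AI assistant."},
#             {"role": "user", "content": "What are the symptoms of diabetes?"},
#             {"role": "assistant", "content": "Common symptoms include increased thirst, frequent urination, and fatigue."},
#         ]
#     },
# ]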
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config | |
| model_name = "openai/gpt-oss-20b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) # needed by SFTTrainer below | |
| quant_config = Mxfp4Config(dequantize=True) # dequantize 4-bit weights to float for training | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.bfloat16, | |
| quantization_config=quant_config, | |
| device_map="auto", # automatically put layers on available GPU(s) | |
| use_cache=False # disable cache for training (gradient checkpointing will be used) | |
| ) | |
| from peft import LoraConfig, get_peft_model | |
| peft_config = LoraConfig( | |
| r=8, | |
| lora_alpha=16, | |
| target_modules="all-linear", # auto-detect all linear layers (requires a recent peft version) | |
| # target_parameters=[ | |
| # "7.mlp.experts.gate_up_proj", | |
| # "7.mlp.experts.down_proj", | |
| # "15.mlp.experts.gate_up_proj", | |
| # "15.mlp.experts.down_proj", | |
| # "23.mlp.experts.gate_up_proj", | |
| # "23.mlp.experts.down_proj", | |
| # ], | |
| bias="none", | |
| task_type="CAUSAL_LM" | |
| ) | |
| model = get_peft_model(model, peft_config) | |
| model.print_trainable_parameters() | |
| from trl import SFTTrainer, SFTConfig | |
| from transformers import TrainingArguments | |
| # Define training hyperparameters | |
| training_args = TrainingArguments( | |
| output_dir="gpt-oss-20b-healthcare-checkpoints", | |
| per_device_train_batch_size=1, # we use batch size 1 per GPU to maximize sequence length usage | |
| gradient_accumulation_steps=4, # accumulate to simulate batch size 4 if needed | |
| num_train_epochs=2, | |
| learning_rate=2e-4, | |
| logging_steps=50, | |
| bf16=True, # use bfloat16 mixed precision | |
| save_strategy="epoch", | |
| report_to="none" # (or "wandb" if using Weights & Biases for logging) | |
| ) | |
| # training_args = SFTConfig( | |
| # learning_rate=2e-4, | |
| # gradient_checkpointing=True, | |
| # num_train_epochs=1, | |
| # logging_steps=10, | |
| # bf16=True, | |
| # per_device_train_batch_size=8, | |
| # per_device_eval_batch_size=8, | |
| # gradient_accumulation_steps=2, | |
| # max_length=2048, | |
| # warmup_ratio=0.03, | |
| # eval_strategy="steps", | |
| # eval_steps=10, | |
| # lr_scheduler_type="cosine_with_min_lr", | |
| # lr_scheduler_kwargs={"min_lr_rate": 0.1}, | |
| # output_dir=SAVED_MODEL_ID, | |
| # report_to="tensorboard", | |
| # push_to_hub=True, | |
| # ) | |
| trainer = SFTTrainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| eval_dataset=val_dataset, | |
| tokenizer=tokenizer | |
| ) | |
| trainer.train() | |
| trainer.save_model() # saves to output_dir | |
| # inference with FT model | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| # Load base model (back in 4-bit quantized form for inference to save memory, if desired) | |
| model_name = "openai/gpt-oss-20b" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| torch_dtype=torch.bfloat16, # or "auto" | |
| device_map="auto" | |
| ) | |
| # Load the LoRA adapter | |
| adapter_path = "gpt-oss-20b-healthcare-checkpoints" # where we saved our LoRA fine-tuned model | |
| model = PeftModel.from_pretrained(base_model, adapter_path) | |
| model = model.merge_and_unload() # merges LoRA into base_model and returns a standard model | |
| prompt_messages = [ | |
| {"role": "developer", "content": "You are a compassionate doctor AI assistant. Answer medical questions with accurate, helpful information."}, | |
| {"role": "user", "content": "What are the symptoms of diabetes?"} | |
| ] | |
| inputs = tokenizer.apply_chat_template(prompt_messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device) | |
| output_ids = model.generate(**inputs, max_new_tokens=200, temperature=0.7) | |
| response = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
| print(response) | |
| ## streaming the output | |
| from transformers import TextIteratorStreamer | |
| from threading import Thread | |
| # Prepare the streamer | |
| streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) | |
| # We will generate in a background thread to not block the main thread | |
| gen_kwargs = dict(**inputs, max_new_tokens=200, temperature=0.7, streamer=streamer) | |
| thread = Thread(target=model.generate, kwargs=gen_kwargs) | |
| thread.start() | |
| print("Assistant:", end=" ", flush=True) | |
| # Read the generated text as it streams | |
| for token_text in streamer: | |
| print(token_text, end="", flush=True) | |
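Optionally (not shown in the gist), the merged model can be saved so later inference does not need to re-apply the LoRA adapter; the directory name below is a placeholder.

save_dir = "gpt-oss-20b-healthcare-merged"  # placeholder path
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)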
Dataset Format:
[
  {
    "question": "...",
    "answers": ["...", "...", ...],
    "ctxs": [
      {
        "id": "...",                // Passage ID from database TSV file
        "title": "...",             // Passage title
        "text": "...",              // Passage full text
        "score": "...",             // Retriever score
        "has_answer": true|false    // Whether the passage contains the answer
      }
    ]
  }
]
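A small validation sketch (assuming the dataset was saved as a single JSON array, e.g. the synthetic_data.json produced above) that loads a file and checks it against this format:

import json

with open("synthetic_data.json", "r", encoding="utf-8") as f:
    records = json.load(f)

for rec in records:
    assert isinstance(rec["question"], str) and rec["question"]
    assert isinstance(rec["answers"], list) and rec["answers"]
    for ctx in rec["ctxs"]:
        assert {"id", "title", "text", "score", "has_answer"} <= set(ctx)
        assert isinstance(ctx["has_answer"], bool)

print(f"{len(records)} records match the expected format")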
Suggested SageMaker training instance types: ml.p4de.24xlarge for the 20B model, ml.p5en.48xlarge for the 120B model.
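A hypothetical SageMaker launch sketch (not from the gist): only the instance types come from the note above; the script name, role, and container image are placeholders to replace with values from your own account.

from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train.py",                      # placeholder: a script wrapping the fine-tuning code above
    source_dir=".",
    role="<your-sagemaker-execution-role>",      # placeholder
    image_uri="<huggingface-training-dlc-uri>",  # placeholder: a DLC with recent transformers/peft/trl
    instance_type="ml.p4de.24xlarge",            # 20B model; use ml.p5en.48xlarge for the 120B model
    instance_count=1,
)
estimator.fit()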