import getpass
import os
import logging

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
    context_precision,
    context_recall,
    faithfulness,
    answer_relevancy,
    answer_correctness,
)
# Ragas v0.1 introduced breaking changes and added metrics such as context_utilization and answer_similarity.
# If using Ragas >= 0.1, you may want a different metric set:
# from ragas.metrics import context_utilization, answer_similarity
# DEFAULT_METRICS = [context_precision, faithfulness, answer_relevancy, context_recall, context_utilization, answer_correctness, answer_similarity]
from ragas.llms import LangchainLLMWrapper
from ragas.embeddings import LangchainEmbeddingsWrapper
# Assuming necessary LLM and embedding models are configured elsewhere or passed in
from langchain_together import TogetherEmbeddings
from langchain.chat_models import init_chat_model

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Constants ---
DEFAULT_METRICS = [
    context_precision,
    context_recall,
    faithfulness,
    answer_relevancy,
    answer_correctness,
]
DEFAULT_OUTPUT_DIR = "."  # Default output directory: the current working directory
HISTORY_FILE_NAME = "evaluation_history.csv"


# --- Helper Functions ---
def ensure_dir_exists(directory_path):
    """Creates a directory if it doesn't exist."""
    if not os.path.exists(directory_path):
        try:
            os.makedirs(directory_path)
            logging.info(f"Created directory: {directory_path}")
        except OSError as e:
            logging.error(f"Error creating directory {directory_path}: {e}")
            raise


def calculate_iqr_thresholds(series, sensitivity=1.5):
    """Calculates warning and critical thresholds based on the IQR."""
    try:
        # Drop NaN values before calculating quantiles to avoid errors
        series_cleaned = series.dropna()
        if series_cleaned.empty:
            logging.warning(f"Cannot calculate IQR thresholds for {series.name}: Series is empty after dropping NaN.")
            return 0.6, 0.4  # Default fallback thresholds
        q1 = series_cleaned.quantile(0.25)
        q3 = series_cleaned.quantile(0.75)
        iqr = q3 - q1
        # Handle cases where IQR is zero (e.g., all values are the same)
        if iqr == 0:
            critical_threshold = q1 - 0.1  # Arbitrary small deviation
            warning_threshold = q1
        else:
            critical_threshold = q1 - sensitivity * iqr
            warning_threshold = q1  # Using Q1 as a simple warning threshold
        # Keep thresholds within reasonable bounds (scores range from 0 to 1)
        critical_threshold = max(0, critical_threshold)
        warning_threshold = max(critical_threshold, min(1, warning_threshold))  # Warning shouldn't be lower than critical
        return warning_threshold, critical_threshold
    except Exception as e:
        logging.warning(f"Could not calculate IQR thresholds for {series.name}: {e}. Returning defaults.")
        return 0.6, 0.4  # Default fallback thresholds
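
# Worked example for the IQR rule above (hypothetical numbers, kept as a comment
# so the script's behaviour is unchanged): for faithfulness scores
# pd.Series([0.6, 0.7, 0.8, 0.9, 1.0]), Q1 = 0.7, Q3 = 0.9, IQR = 0.2, so with
# sensitivity=1.5 the critical threshold is max(0, 0.7 - 1.5 * 0.2) = 0.4 and the
# warning threshold is Q1 = 0.7, i.e. the function returns (0.7, 0.4).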

# --- Core Functions ---
def load_and_validate_data(data_input, is_multi_turn=False):
    """
    Loads data from a file path or uses a pre-loaded Dataset object.
    Validates required columns. Returns a Hugging Face Dataset object.
    """
    if isinstance(data_input, str):  # It's a file path
        file_path = data_input
        logging.info(f"Loading data from file: {file_path}")
        # Loading and validation for CSV and JSON/JSONL; extend this branch if other file types are needed.
        try:
            # Determine the file type (simple extension check)
            if file_path.lower().endswith('.csv'):
                df = pd.read_csv(file_path)
            elif file_path.lower().endswith('.json') or file_path.lower().endswith('.jsonl'):
                df = pd.read_json(file_path, lines=file_path.lower().endswith('.jsonl'))
            else:
                raise ValueError(f"Unsupported file type: {file_path}. Please use CSV or JSON/JSONL.")

            # Basic validation (adjust based on the actual format)
            required_cols = ['question', 'answer', 'contexts', 'ground_truth']
            if is_multi_turn:
                required_cols.append('chat_history')  # Add chat_history for multi-turn
                logging.info("Multi-turn mode: Validating 'chat_history' column.")
            missing_cols = [col for col in required_cols if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns in file {file_path}: {missing_cols}")

            # Convert 'contexts' if it is stored as a string representation of a list
            if not df.empty and isinstance(df['contexts'].iloc[0], str):
                try:
                    # Attempt to evaluate the string representation safely
                    import ast
                    df['contexts'] = df['contexts'].apply(ast.literal_eval)
                except (ValueError, SyntaxError) as e:
                    logging.error(f"Error parsing 'contexts' column in file {file_path}. Ensure it's a valid list representation: {e}")
                    raise ValueError("Invalid format for 'contexts' column.")
            # Ensure contexts is a list of strings
            if not df.empty and not all(isinstance(ctx, list) and all(isinstance(s, str) for s in ctx) for ctx in df['contexts']):
                raise ValueError("'contexts' column must contain lists of strings.")

            # Validate chat_history format if multi-turn
            if is_multi_turn and not df.empty:
                if 'chat_history' not in df.columns:
                    raise ValueError("Missing 'chat_history' column for multi-turn data.")

                # Basic format check: is it a list of dicts with role/content?
                def validate_history(history):
                    # Allow NaN/None for rows that might not have history (though ideally filtered before)
                    if pd.isna(history):
                        return True
                    if not isinstance(history, list):
                        return False
                    return all(isinstance(turn, dict) and 'role' in turn and 'content' in turn for turn in history)

                # Handle a potential string representation of the list
                if isinstance(df['chat_history'].iloc[0], str):
                    try:
                        import ast
                        # Use apply with a lambda to handle potential NaN values gracefully
                        df['chat_history'] = df['chat_history'].apply(
                            lambda x: ast.literal_eval(x) if pd.notna(x) and isinstance(x, str) else x
                        )
                    except (ValueError, SyntaxError) as e:
                        logging.error(f"Error parsing 'chat_history' column in file {file_path}. Ensure it's a valid list-of-dicts representation: {e}")
                        raise ValueError("Invalid format for 'chat_history' column.")
                if not all(validate_history(hist) for hist in df['chat_history']):
                    raise ValueError("'chat_history' column must contain lists of dictionaries, each with 'role' and 'content' keys.")

            dataset = Dataset.from_pandas(df)
            logging.info(f"Data loaded successfully from file. Rows: {len(dataset)}")
            return dataset
        except FileNotFoundError:
            logging.error(f"Data file not found: {file_path}")
            raise
        except Exception as e:
            logging.error(f"Error loading or validating data from file {file_path}: {e}")
            raise
    elif isinstance(data_input, Dataset):
        logging.info("Using pre-loaded Dataset object.")
        dataset = data_input
        # Validate the pre-loaded dataset as well
        required_cols = ['question', 'answer', 'contexts', 'ground_truth']
        if is_multi_turn:
            required_cols.append('chat_history')  # Add chat_history for multi-turn
            logging.info("Multi-turn mode: Validating 'chat_history' column in provided Dataset.")
        missing_cols = [col for col in required_cols if col not in dataset.column_names]
        if missing_cols:
            raise ValueError(f"Missing required columns in provided Dataset: {missing_cols}")

        # Validate chat_history format if multi-turn
        if is_multi_turn:
            if 'chat_history' not in dataset.column_names:
                raise ValueError("Missing 'chat_history' column for multi-turn data in provided Dataset.")

            # Basic format check: is it a list of dicts with role/content?
            def validate_history(history):
                # Allow None/empty lists if the dataset was constructed that way
                if history is None or history == []:
                    return True
                if not isinstance(history, list):
                    return False
                return all(isinstance(turn, dict) and 'role' in turn and 'content' in turn for turn in history)

            # Check the first element to infer the type, assuming consistency;
            # handle a potentially empty dataset.
            if len(dataset) > 0:
                first_history = dataset[0]['chat_history']
                if isinstance(first_history, str):
                    raise ValueError("Provided Dataset has 'chat_history' as string. Please parse it into a list of dicts before creating the Dataset.")
                if not all(validate_history(hist) for hist in dataset['chat_history']):
                    raise ValueError("Provided Dataset 'chat_history' column must contain lists of dictionaries, each with 'role' and 'content' keys.")
            else:
                logging.warning("Provided Dataset is empty, skipping chat_history validation.")

        logging.info(f"Using provided Dataset. Rows: {len(dataset)}")
        return dataset
    else:
        raise TypeError("data_input must be a file path (str) or a datasets.Dataset object.")
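
# Usage sketch for the file-based path (illustrative only; "eval.csv" and its
# contents are hypothetical, not part of this gist). A minimal single-turn CSV
# needs the four required columns, with 'contexts' stored as a list literal:
#
#   question,answer,contexts,ground_truth
#   "Who painted the Mona Lisa?","Leonardo da Vinci.","['The Mona Lisa is a portrait by Leonardo da Vinci.']","Leonardo da Vinci painted the Mona Lisa."
#
# dataset = load_and_validate_data("eval.csv", is_multi_turn=False)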

def run_evaluation(dataset, metrics=None, llm=None, embeddings=None):
    """Runs the RAGAS evaluation on the dataset using Ragas wrappers."""
    if metrics is None:
        metrics = DEFAULT_METRICS

    # Wrap the provided Langchain models using Ragas wrappers
    ragas_llm = None
    ragas_embeddings = None
    if llm:
        ragas_llm = LangchainLLMWrapper(llm)
        logging.info(f"Wrapped LLM: {llm.__class__.__name__}")
    else:
        # Ragas metrics often require an LLM, so raise an error if not provided.
        raise ValueError("LLM must be provided for Ragas evaluation.")
    if embeddings:
        ragas_embeddings = LangchainEmbeddingsWrapper(embeddings)
        logging.info(f"Wrapped Embeddings: {embeddings.__class__.__name__}")
    else:
        # Ragas metrics often require embeddings, so raise an error if not provided.
        raise ValueError("Embeddings must be provided for Ragas evaluation.")

    logging.info(f"Running RAGAS evaluation with metrics: {[m.name for m in metrics]}")
    try:
        # Note: Ragas evaluate might handle multi-turn implicitly if a 'chat_history' column exists.
        # Check the Ragas documentation for specific multi-turn metric requirements if issues arise.
        result = evaluate(
            dataset=dataset,
            metrics=metrics,
            llm=ragas_llm,  # Pass the wrapped LLM
            embeddings=ragas_embeddings,  # Pass the wrapped embeddings
            raise_exceptions=False  # Set to True to debug metric errors more easily
        )
        logging.info("RAGAS evaluation completed.")
        scores_df = result.to_pandas()
        return scores_df
    except Exception as e:
        logging.error(f"RAGAS evaluation failed: {e}", exc_info=True)  # Log traceback
        # Consider adding more specific error handling based on Ragas exceptions
        raise
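
# Sketch of the expected result shape (values are illustrative and exact column
# names depend on the installed Ragas version): result.to_pandas() yields one row
# per sample with a numeric column per metric, roughly
#
#   question                      context_precision  context_recall  faithfulness  answer_relevancy  answer_correctness
#   Who painted the Mona Lisa?                 1.00            1.00          1.00              0.95                0.90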

def calculate_all_thresholds(scores_df, sensitivity=1.5):
    """Calculates thresholds for all numeric metric columns in the DataFrame."""
    thresholds = {}
    metric_cols = scores_df.select_dtypes(include=np.number).columns
    for col in metric_cols:
        # Exclude non-metric cols if they somehow got included (shouldn't happen with current logic)
        if col not in ['question', 'answer', 'contexts', 'ground_truth', 'chat_history']:
            thresholds[col] = calculate_iqr_thresholds(scores_df[col], sensitivity)
    logging.info(f"Calculated thresholds: {thresholds}")
    return thresholds


def generate_diagnostic_report(scores_df, thresholds):
    """Analyzes scores against thresholds and generates a diagnostic report."""
    report = {}
    # Ensure the index aligns if the scores_df index is not a default range
    scores_df_reset = scores_df.reset_index()
    for index, row in scores_df_reset.iterrows():
        diagnostics = []
        for metric, (warn_thr, crit_thr) in thresholds.items():
            # Check that the metric exists in the row before accessing it
            if metric not in row:
                logging.warning(f"Metric '{metric}' not found in evaluation results row {index}. Skipping for diagnostics.")
                continue
            score = row.get(metric)
            if score is None or pd.isna(score):
                diagnostics.append(f"⚪️ SKIPPED {metric}: No score")
                continue  # Skip if the score is missing
            if score < crit_thr:
                diagnostics.append(f"🚨 CRITICAL {metric}: {score:.2f} (< {crit_thr:.2f})")
            elif score < warn_thr:
                diagnostics.append(f"⚠️ WARNING {metric}: {score:.2f} (< {warn_thr:.2f})")
            else:
                diagnostics.append(f"✅ OK {metric}: {score:.2f}")
        report[index] = {
            "question": row.get('question', 'N/A'),
            "diagnostics": "; ".join(diagnostics)
        }
    logging.info("Generated diagnostic report.")
    return report
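
# Illustrative report entry (made-up values) showing the structure consumed by
# visualize_results below:
#
# {0: {"question": "Who painted the Mona Lisa?",
#      "diagnostics": "✅ OK faithfulness: 0.95; ⚠️ WARNING answer_correctness: 0.55 (< 0.70)"}}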

def load_historical_data(output_dir=DEFAULT_OUTPUT_DIR):
    """Loads historical evaluation data from a CSV file."""
    history_file = os.path.join(output_dir, HISTORY_FILE_NAME)
    if os.path.exists(history_file):
        try:
            history_df = pd.read_csv(history_file)
            # Basic validation; could also check for essential columns here
            logging.info(f"Loaded historical data from: {history_file}")
            return history_df
        except Exception as e:
            logging.warning(f"Could not load or parse history file {history_file}: {e}. Proceeding without history.")
            return None
    else:
        logging.info("No historical data file found.")
        return None


def save_results(scores_df, output_dir=DEFAULT_OUTPUT_DIR):
    """Saves the current evaluation results, appending to history if it exists."""
    ensure_dir_exists(output_dir)
    history_file = os.path.join(output_dir, HISTORY_FILE_NAME)
    timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
    current_scores_file = os.path.join(output_dir, f"eval_{timestamp}.csv")
    try:
        # Ensure chat_history is saved correctly if it exists (e.g., as a JSON string)
        if 'chat_history' in scores_df.columns:
            scores_df_to_save = scores_df.copy()
            # Convert list of dicts to a JSON string for CSV compatibility
            import json
            scores_df_to_save['chat_history'] = scores_df_to_save['chat_history'].apply(
                lambda x: json.dumps(x) if isinstance(x, list) else x
            )
        else:
            scores_df_to_save = scores_df

        scores_df_to_save.to_csv(current_scores_file, index=False)
        logging.info(f"Saved current evaluation results to: {current_scores_file}")

        # Append to history
        history_exists = os.path.exists(history_file)
        scores_df_to_save.to_csv(history_file, mode='a', header=not history_exists, index=False)
        if history_exists:
            logging.info(f"Appended results to history file: {history_file}")
        else:
            logging.info(f"Created new history file: {history_file}")
    except Exception as e:
        logging.error(f"Error saving results: {e}")
        # Don't raise here; a failed save is worth logging but shouldn't stop execution.


def visualize_results(scores_df, history_df=None, thresholds=None, report=None, output_dir=DEFAULT_OUTPUT_DIR):
    """Generates and saves an interactive Plotly grouped bar chart of the results."""
    ensure_dir_exists(output_dir)
    fig = go.Figure()
    metric_cols = scores_df.select_dtypes(include=np.number).columns
    # Ensure we only plot actual metric scores
    metric_cols = [col for col in metric_cols if col not in ['question', 'answer', 'contexts', 'ground_truth', 'chat_history']]

    # Use the question index plus a shortened question text for x-axis labels
    question_labels = [f"Q{i+1}: {q[:30]}..." for i, q in enumerate(scores_df['question'])]

    # Add bars for each metric per question
    for metric in metric_cols:
        hover_texts = []
        # Use the index from the original scores_df for report lookup
        original_indices = scores_df.index
        if report:
            hover_texts = [
                f"<b>Q:</b> {report[original_indices[i]]['question'][:100]}...<br><b>{metric}:</b> {score:.3f}<br><b>Report:</b> {report[original_indices[i]]['diagnostics']}"
                if original_indices[i] in report else f"<b>Q:</b> {scores_df.loc[original_indices[i], 'question'][:100]}...<br><b>{metric}:</b> {score:.3f}<br><b>Report:</b> N/A"
                for i, score in enumerate(scores_df[metric])
            ]
        else:
            hover_texts = [f"<b>Q:</b> {scores_df.loc[original_indices[i], 'question'][:100]}...<br><b>{metric}:</b> {score:.3f}" for i, score in enumerate(scores_df[metric])]

        fig.add_trace(go.Bar(
            x=question_labels,
            y=scores_df[metric],
            name=metric,
            hoverinfo='text',
            hovertext=hover_texts
        ))

        # Add threshold lines for each metric
        if thresholds and metric in thresholds:
            warn_thr, crit_thr = thresholds[metric]
            fig.add_hline(y=warn_thr, line_dash="dash", line_color="orange", opacity=0.6,
                          annotation_text=f"{metric} Warn ({warn_thr:.2f})",
                          annotation_position="bottom right",
                          annotation_font_size=10)
            fig.add_hline(y=crit_thr, line_dash="dot", line_color="red", opacity=0.6,
                          annotation_text=f"{metric} Crit ({crit_thr:.2f})",
                          annotation_position="bottom left",
                          annotation_font_size=10)

    # Historical average lines would clutter the bar chart; consider showing
    # history differently, e.g. in tooltips or a separate chart.
    if history_df is not None and not history_df.empty:
        logging.info("Historical averages not plotted on bar chart to reduce clutter.")
        # Example: calculate averages for potential use in hover text or a summary
        # hist_avg_scores = {}
        # for metric in metric_cols:
        #     if metric in history_df.columns:
        #         hist_avg_scores[metric] = history_df[metric].mean()

    fig.update_layout(
        title="RAGAS Evaluation Results per Question",
        xaxis_title="Question",
        yaxis_title="Score (0-1)",
        barmode='group',  # Group bars for each question
        hovermode="closest",
        legend_title="Metrics",
        xaxis={'tickangle': -45},  # Angle labels to prevent overlap
        margin=dict(b=150)  # Increase bottom margin for angled labels
    )

    output_file = os.path.join(output_dir, "evaluation_results.html")
    try:
        fig.write_html(output_file)
        logging.info(f"Saved visualization to: {output_file}")
        return output_file
    except Exception as e:
        logging.error(f"Error saving visualization: {e}")
        return None


# --- Main Execution Logic ---
def main_evaluation_pipeline(data_input, output_dir=DEFAULT_OUTPUT_DIR, sensitivity=1.5, llm=None, embeddings=None, metrics=None, is_multi_turn=False):
    """Orchestrates the full RAGAS evaluation pipeline."""
    try:
        # 1. Load Data (handles both a file path and a Dataset object)
        dataset = load_and_validate_data(data_input, is_multi_turn=is_multi_turn)
        if dataset is None or len(dataset) == 0:
            logging.error("Dataset is empty or failed to load. Aborting pipeline.")
            return None, None, None

        # 2. Load History (before evaluation)
        history_df = load_historical_data(output_dir)

        # 3. Run Evaluation
        scores_df = run_evaluation(dataset, metrics=metrics, llm=llm, embeddings=embeddings)
        if scores_df is None or scores_df.empty:
            logging.error("Evaluation failed to produce scores. Aborting pipeline.")
            return None, None, None

        # Add the original questions back to the scores DataFrame for reporting/visualization.
        # Indices should align because Ragas evaluate preserves sample order.
        if 'question' in dataset.column_names and len(dataset) == len(scores_df):
            scores_df['question'] = dataset['question']
            logging.info("Added 'question' column back to scores DataFrame.")
        else:
            logging.warning("Could not add 'question' column back to scores DataFrame. Length mismatch or column missing in original dataset.")
            # Add a placeholder if needed for downstream functions, though they might fail
            if 'question' not in scores_df.columns:
                scores_df['question'] = [f"Question {i+1}" for i in range(len(scores_df))]

        # 4. Calculate Thresholds (using combined data if available, for stability).
        # Ensure history_df columns align with scores_df before concatenating.
        if history_df is not None:
            # Select only columns present in both, prioritizing scores_df columns
            common_cols = [col for col in scores_df.columns if col in history_df.columns and scores_df[col].dtype == history_df[col].dtype]
            combined_df = pd.concat([history_df[common_cols], scores_df[common_cols]], ignore_index=True)
        else:
            combined_df = scores_df
        thresholds = calculate_all_thresholds(combined_df, sensitivity)

        # 5. Generate Diagnostics for the current run
        report = generate_diagnostic_report(scores_df, thresholds)

        # 6. Visualize Results
        viz_path = visualize_results(scores_df, history_df=history_df, thresholds=thresholds, report=report, output_dir=output_dir)

        # 7. Save Results (after a successful run)
        save_results(scores_df, output_dir)

        logging.info("Evaluation pipeline completed successfully.")
        if viz_path:
            logging.info(f"View interactive results: {viz_path}")
        return scores_df, report, viz_path
    except Exception as e:
        logging.error(f"Evaluation pipeline failed: {e}", exc_info=True)  # Log traceback
        return None, None, None
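
# Alternative usage sketch (commented out so the demo below runs unchanged; the
# file name "my_rag_eval.csv" and the output directory are hypothetical): the
# pipeline also accepts a file path instead of a Dataset object.
#
# scores, report, viz = main_evaluation_pipeline(
#     data_input="my_rag_eval.csv",   # CSV with question/answer/contexts/ground_truth columns
#     output_dir="./ragas_reports",
#     sensitivity=1.5,
#     llm=my_llm,                     # a Langchain chat model, e.g. the one initialised in __main__ below
#     embeddings=my_embeddings,       # a Langchain embeddings model
#     metrics=None,                   # falls back to DEFAULT_METRICS
#     is_multi_turn=False,
# )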

if __name__ == "__main__":
    # --- Configuration ---
    OUTPUT_DIRECTORY = "."  # Output to the current directory by default
    SENSITIVITY_FACTOR = 1.5  # Adjust for stricter/looser anomaly detection
    EVALUATE_MULTI_TURN = False  # <<< SET THIS FLAG: True to evaluate multi-turn, False for single-turn

    # Fill in your API keys here. If TOGETHER_API_KEY is left empty, you will be
    # prompted for it below; MISTRAL_API_KEY is not used by the Together AI setup in this script.
    os.environ["MISTRAL_API_KEY"] = ""
    os.environ["TOGETHER_API_KEY"] = ""

    # --- Define Sample Data ---
    # Single-Turn Examples
    single_turn_q1 = {
        'question': ["Who painted the Mona Lisa?"],
        'answer': ["The Mona Lisa was painted by Leonardo da Vinci."],
        'contexts': [
            ["Leonardo da Vinci was an Italian polymath of the High Renaissance.", "The Mona Lisa is a half-length portrait painting by Leonardo da Vinci.", "It is housed in the Louvre Museum in Paris."]
        ],
        'ground_truth': ["The Mona Lisa was painted by the Italian Renaissance artist Leonardo da Vinci between 1503 and 1506."]
    }
    single_turn_q2 = {
        'question': ["How do you make a simple pasta sauce?"],
        'answer': ["To make a simple pasta sauce, sauté garlic and onions, add crushed tomatoes, season with salt, pepper, and herbs like basil or oregano, and simmer."],
        'contexts': [
            ["A basic tomato sauce involves cooking down tomatoes with aromatics.", "Common aromatics include garlic, onions, and herbs.", "Simmering allows flavors to meld."]
        ],
        'ground_truth': ["A simple pasta sauce typically involves sautéing minced garlic and diced onions in olive oil, adding canned crushed tomatoes (like San Marzano), seasoning with salt, black pepper, and dried oregano or fresh basil, and letting it simmer for at least 15-20 minutes to develop flavor."]
    }
    single_turn_q3 = {
        'question': ["What is the main idea behind quantum entanglement?"],
        'answer': ["Quantum entanglement is a phenomenon where two or more quantum particles become linked in such a way that they share the same fate, regardless of the distance separating them. Measuring a property of one particle instantaneously influences the correlated property of the other(s)."],
        'contexts': [
            ["Quantum entanglement is a key principle of quantum mechanics.", "Einstein famously called it 'spooky action at a distance'.", "Entangled particles exhibit correlations that cannot be explained by classical physics."]
        ],
        'ground_truth': ["Quantum entanglement describes a state where multiple quantum particles are linked, sharing the same quantum state. Measuring a property (like spin) of one particle instantly determines the corresponding property of the other entangled particle(s), no matter how far apart they are. This correlation is stronger than classical correlations."]
    }

    # Multi-Turn Example(s)
    multi_turn_q1 = {
        'question': ["What about its famous clock tower?"],  # Follow-up question
        'answer': ["The clock tower, officially named Elizabeth Tower, houses the Great Bell known as Big Ben."],
        'contexts': [
            ["The Palace of Westminster is the meeting place of the Parliament of the United Kingdom.", "Elizabeth Tower is the name of the clock tower of the Palace of Westminster in London.", "Big Ben is the nickname for the Great Bell of the striking clock at the north end of the Palace of Westminster."]
        ],
        'ground_truth': ["The famous clock tower at the Palace of Westminster is called Elizabeth Tower. Its Great Bell is nicknamed Big Ben."],
        'chat_history': [  # One history (a list of turns) per sample, so column lengths match in Dataset.from_dict
            [
                {'role': 'user', 'content': 'Tell me about the Palace of Westminster.'},
                {'role': 'assistant', 'content': 'The Palace of Westminster serves as the meeting place for both the House of Commons and the House of Lords, the two houses of the Parliament of the United Kingdom.'}
            ]
        ]
    }
    # Add more multi-turn examples if needed
    # multi_turn_q2 = { ... }

    # --- Select and Prepare Data Based on Flag ---
    IS_MULTI_TURN_DATA = EVALUATE_MULTI_TURN
    combined_data = {
        'question': [], 'answer': [], 'contexts': [], 'ground_truth': []
    }
    data_to_process = []
    if IS_MULTI_TURN_DATA:
        logging.info("Preparing multi-turn evaluation dataset.")
        data_to_process = [multi_turn_q1]  # Add other multi_turn_qX here
        combined_data['chat_history'] = []  # Initialize the multi-turn specific column
        for q_data in data_to_process:
            combined_data['question'].extend(q_data['question'])
            combined_data['answer'].extend(q_data['answer'])
            combined_data['contexts'].extend(q_data['contexts'])
            combined_data['ground_truth'].extend(q_data['ground_truth'])
            combined_data['chat_history'].extend(q_data['chat_history'])  # Each element is one full conversation history
    else:
        logging.info("Preparing single-turn evaluation dataset.")
        data_to_process = [single_turn_q1, single_turn_q2, single_turn_q3]  # Add other single_turn_qX here
        for q_data in data_to_process:
            combined_data['question'].extend(q_data['question'])
            combined_data['answer'].extend(q_data['answer'])
            combined_data['contexts'].extend(q_data['contexts'])
            combined_data['ground_truth'].extend(q_data['ground_truth'])
            # No chat_history column for single-turn

    if not data_to_process:
        logging.error("No data selected for evaluation based on EVALUATE_MULTI_TURN flag. Exiting.")
        exit()

    evaluation_dataset = Dataset.from_dict(combined_data)
    logging.info(f"Created evaluation dataset with {len(evaluation_dataset)} samples. Multi-turn: {IS_MULTI_TURN_DATA}")

    # --- Configure LLM and Embeddings ---
    try:
        # Using Together AI as the model provider
        if not os.environ.get("TOGETHER_API_KEY"):
            os.environ["TOGETHER_API_KEY"] = getpass.getpass("Enter API key for Together AI: ")
        my_llm = init_chat_model("mistralai/Mixtral-8x7B-Instruct-v0.1", model_provider="together")
        my_embeddings = TogetherEmbeddings(
            model="togethercomputer/m2-bert-80M-8k-retrieval",
        )
        logging.info("Initialized Together AI models.")
    except ImportError:
        logging.error("Required Langchain packages (e.g., langchain-together) not installed. Please install them.")
        exit()
    except Exception as e:
        logging.error(f"Failed to initialize LLM/Embedding models. Check API keys and packages. Error: {e}")
        exit()

    # --- Optional: Select specific metrics ---
    # Ragas metrics may differ slightly for multi-turn; check the documentation.
    # Example: context_recall might behave differently or have specific multi-turn versions.
    # my_metrics = [faithfulness, answer_relevancy]  # Example subset
    my_metrics = None  # Uses the defaults defined in DEFAULT_METRICS

    # --- Run Pipeline ---
    logging.info(f"Starting evaluation pipeline (Multi-turn: {IS_MULTI_TURN_DATA})...")
    scores, diagnostic_report, viz_file_path = main_evaluation_pipeline(
        data_input=evaluation_dataset,  # Pass the Dataset object
        output_dir=OUTPUT_DIRECTORY,
        sensitivity=SENSITIVITY_FACTOR,
        llm=my_llm,  # Pass the Langchain LLM
        embeddings=my_embeddings,  # Pass the Langchain Embeddings
        metrics=my_metrics,
        is_multi_turn=IS_MULTI_TURN_DATA  # Pass the flag
    )

    if scores is not None:
        print("\n--- Metric Explanations ---")
        print("Context Precision: Signal-to-noise ratio of the retrieved contexts.")
        print("Context Recall: Ability to retrieve all necessary information.")
        print("Faithfulness: Factual consistency of the answer with the contexts.")
        print("Answer Relevancy: Relevance of the answer to the question.")
        print("Answer Correctness: Factual correctness compared to the ground truth.")

        print("\n--- Evaluation Scores per Question ---")
        # Select and display the relevant columns
        display_cols = ['question'] + [m.name for m in (my_metrics or DEFAULT_METRICS)]
        # Ensure the columns exist before trying to display them
        display_cols = [col for col in display_cols if col in scores.columns]
        print(scores[display_cols].to_string())  # Use to_string for better console formatting

        # Optionally print the full diagnostic report
        # print("\n--- Full Diagnostic Report ---")
        # for i, item_report in diagnostic_report.items():
        #     print(f"Item {i} - Q: {item_report['question']}")
        #     print(f"  Diagnostics: {item_report['diagnostics']}")

        if viz_file_path:
            print(f"\nInteractive bar chart report saved to: {viz_file_path}")