RAG on Read the Docs
# -*- coding: utf-8 -*-
"""
doc_agent.py: A Python script implementing a Retrieval-Augmented Generation (RAG)
agent specifically designed for querying technical documentation (e.g., from
Read the Docs source files like .rst or .txt).

Purpose: To provide more accurate and context-aware answers to questions
about a specific software package or project than traditional keyword search,
by leveraging an LLM grounded with retrieved documentation snippets.

Features:
- Searches local documentation files (.rst, .txt) for relevant context.
- Uses keyword extraction and scoring, with optional LLM-guided refinement.
- Queries an LLM (local or remote OpenAI-compatible API) with retrieved
  context to generate answers based *only* on the documentation.
- Streams LLM responses for better interactivity.
- Configurable via command-line arguments.
"""
import argparse
import logging
import os
import re
import sys
import time
from pathlib import Path
from typing import List, Optional, Set, Tuple

import openai  # OpenAI client library

# --- Configuration & Constants ---

# Set up basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)  # Use a dedicated logger

# Default path for documentation source files (e.g., a Sphinx source directory)
DEFAULT_DOCS_PATH = Path('./source')
# Default URL for the OpenAI-compatible LLM endpoint
DEFAULT_LLM_BASE_URL = 'http://localhost:8080/v1'
# Default model name (can be overridden)
DEFAULT_MODEL_NAME = 'local-model'

# --- RAG Parameters ---
# Max characters of documentation context to send to the LLM
MAX_CONTEXT_CHARS = 3800
# Max character length for a single documentation snippet (paragraph)
MAX_SNIPPET_LEN = 500
# Max snippets to retrieve in the *initial* search
MAX_INITIAL_SNIPPETS = 10
# Max snippets to retrieve in the *refined* search (after LLM guidance)
MAX_REFINED_SNIPPETS = 8
# Timeout for LLM API requests (seconds)
REQUEST_TIMEOUT = 180.0
# Max tokens for the LLM guidance response (keep it short)
GUIDANCE_MAX_TOKENS = 100
# Max tokens for the final generated answer
FINAL_ANSWER_MAX_TOKENS = 1500

# --- Keyword Extraction ---
# Stop words to ignore in user queries (customize for your domain if needed)
STOP_WORDS = {
    "a", "about", "above", "after", "again", "against", "all", "am", "an", "and",
    "any", "are", "as", "at", "be", "because", "been", "before", "being",
    "below", "between", "both", "but", "by", "can", "cannot", "could",
    "did", "do", "does", "doing", "down", "during", "each", "few", "for",
    "from", "further", "had", "has", "have", "having", "he", "her", "here",
    "hers", "herself", "him", "himself", "his", "how", "i", "if", "in", "into",
    "is", "it", "its", "itself", "let", "me", "more", "most", "my", "myself",
    "no", "nor", "not", "of", "off", "on", "once", "only", "or", "other",
    "ought", "our", "ours", "ourselves", "out", "over", "own", "same", "she",
    "should", "so", "some", "such", "than", "that", "the", "their", "theirs",
    "them", "themselves", "then", "there", "these", "they", "this", "those",
    "through", "to", "too", "under", "until", "up", "very", "was", "we", "were",
    "what", "when", "where", "which", "while", "who", "whom", "why", "with",
    "would", "you", "your", "yours", "yourself", "yourselves",
    # Common query words that are often noise for documentation search
    "using", "get", "set", "configure", "run", "install", "find", "tell",
    "what's", "whats", "how's", "hows", "show", "example", "examples"
}


# --- Helper Functions ---
def extract_keywords(query: str, existing_keywords: Optional[List[str]] = None) -> List[str]:
    """
    Extracts potential keywords from a user query about software documentation.
    Filters out stop words and short words, then combines the result with any
    existing keywords.
    """
    words = re.findall(r'\b\w+\b', query.lower())
    new_keywords = {word for word in words if word not in STOP_WORDS and len(word) > 2}
    combined = set(existing_keywords or []) | new_keywords
    # Consider adding stemming or lemmatization here for more advanced matching
    return list(combined)
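# Illustrative behavior (hypothetical query, not executed here): for
# "How do I configure logging?", the words "how", "do", "i", and "configure"
# are all in STOP_WORDS (and words of two characters or fewer are dropped),
# so extract_keywords returns ['logging'].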
def _handle_openai_error(e: Exception, start_time: float, context_msg: str = "LLM request") -> bool:
    """Handles common OpenAI API errors, logs them, and returns whether the error was a recognized API failure."""
    duration = time.time() - start_time
    sys.stdout.write("\r" + " " * 80 + "\r")  # Clear the console line
    sys.stdout.flush()
    # Log specific error types for better diagnosis
    if isinstance(e, openai.APITimeoutError):
        logger.error(f"{context_msg} timed out after {duration:.2f}s.")
    elif isinstance(e, openai.APIConnectionError):
        logger.error(f"Connection error during {context_msg} ({duration:.2f}s): {e}")
    elif isinstance(e, openai.RateLimitError):
        logger.error(f"Rate limit exceeded during {context_msg} ({duration:.2f}s): {e}")
    elif isinstance(e, openai.APIStatusError):
        logger.error(f"API status error during {context_msg} ({duration:.2f}s): Status={e.status_code}, Response={e.response}")
    elif isinstance(e, openai.APIError):
        logger.error(f"Generic API error during {context_msg} ({duration:.2f}s): {e}")
    else:
        logger.error(f"Unexpected error during {context_msg} ({duration:.2f}s): {e}", exc_info=True)
        return False  # Unexpected failure
    return True  # Expected/handled API failure
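# Return-value contract for _handle_openai_error, relied on by callers below:
# True means the failure was a recognized OpenAI API error (timeout, connection,
# rate limit, status, or generic API error) that has been logged; False means an
# unexpected exception was logged with a full traceback.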
# --- Document Agent Class ---
class DocAgent:
    """
    Agent orchestrating the RAG process for querying documentation.
    Searches local files and interacts with the LLM for guidance and final answers.
    """

    def __init__(self, docs_path: Path, base_url: str, api_key: str, model_name: str):
        """
        Initializes the documentation Q&A agent.

        Args:
            docs_path: Path to the root directory containing documentation
                source files (e.g., .rst, .txt).
            base_url: The base URL of the OpenAI-compatible API endpoint.
            api_key: The API key for the endpoint.
            model_name: The name of the model to use.
        """
        self.docs_path = docs_path
        self.model_name = model_name
        self.doc_files = self._find_doc_files()  # Find documentation files
        if not self.doc_files:
            raise ValueError(f"No documentation files (.rst, .txt) found in {self.docs_path}. Check the --docs path.")
        logger.info(f"Initializing OpenAI client: URL='{base_url}', Model='{model_name}'")
        try:
            self.client = openai.OpenAI(
                base_url=base_url,
                api_key=api_key,
                timeout=REQUEST_TIMEOUT,
            )
            logger.info("OpenAI client initialized successfully.")
        except Exception as e:
            logger.error(f"Error initializing OpenAI client: {e}", exc_info=True)
            raise RuntimeError(f"Failed to initialize OpenAI client. Check the API URL ('{base_url}') and ensure the server is running.") from e

    def _find_doc_files(self) -> List[Path]:
        """Recursively finds documentation files (.rst, .txt)."""
        if not self.docs_path.is_dir():
            logger.error(f"Documentation directory not found: {self.docs_path}")
            return []
        logger.info(f"Scanning for .rst and .txt files in '{self.docs_path}'...")
        rst_files = list(self.docs_path.rglob('*.rst'))
        txt_files = list(self.docs_path.rglob('*.txt'))
        # Add more file types here if needed (e.g., .md)
        doc_files = [f for f in rst_files + txt_files if f.is_file()]
        logger.info(f"Found {len(doc_files)} documentation files.")
        return doc_files

    def _search_single_file(self, file_path: Path, query_keywords: Set[str]) -> List[Tuple[int, str]]:
        """Searches a single documentation file for keywords; returns scored snippets."""
        scored_paras = []
        try:
            # Handle potential encoding issues in documentation files
            try:
                content = file_path.read_text(encoding='utf-8')
            except UnicodeDecodeError:
                content = file_path.read_text(encoding='latin-1')  # Fallback
                logger.warning(f"Used latin-1 fallback encoding for '{file_path.name}'")
            # Split by paragraph (common in .rst/.txt docs). Consider more robust chunking for complex docs.
            paragraphs = re.split(r'\n\s*\n+', content)
            for para in paragraphs:
                para_strip = para.strip()
                # Basic filtering of paragraphs
                if not para_strip or len(para_strip) < 15:
                    continue  # Skip very short lines/paragraphs
                para_lower = para_strip.lower()
                found_keywords = {kw for kw in query_keywords if kw in para_lower}
                score = len(found_keywords)
                if score > 0:
                    # Optional: add a proximity bonus if multiple keywords are close together
                    if score > 1:
                        indices = [para_lower.find(kw) for kw in found_keywords]
                        valid_indices = [i for i in indices if i != -1]
                        if len(valid_indices) > 1:
                            span = max(valid_indices) - min(valid_indices)
                            if span < MAX_SNIPPET_LEN * 1.5:
                                score += 1
                    # Extract a snippet centered around the first keyword
                    try:
                        first_kw_index = min(idx for kw in found_keywords if (idx := para_lower.find(kw)) != -1)
                        first_kw = next(kw for kw in found_keywords if para_lower.find(kw) == first_kw_index)
                        start = max(0, first_kw_index - MAX_SNIPPET_LEN // 3)
                        end = min(len(para_strip), first_kw_index + len(first_kw) + (2 * MAX_SNIPPET_LEN // 3))
                        snippet = para_strip[start:end]
                        # Add ellipses if the snippet is cut off
                        prefix = "..." if start > 0 else ""
                        suffix = "..." if end < len(para_strip) else ""
                        snippet = prefix + snippet.strip() + suffix
                        if len(snippet) > MAX_SNIPPET_LEN + 6:  # Final trim
                            snippet = snippet[:MAX_SNIPPET_LEN + 3] + "..."
                        scored_paras.append((score, snippet))
                    except (StopIteration, ValueError):
                        logger.debug(f"Minor issue extracting snippet from {file_path.name}")
                        # Fallback: use the whole paragraph if it is short enough
                        if len(para_strip) <= MAX_SNIPPET_LEN:
                            scored_paras.append((score, para_strip))
                    except Exception as e_snip:
                        logger.warning(f"Unexpected error during snippet extraction in {file_path.name}: {e_snip}")
        except Exception as e_read:
            logger.warning(f"Error reading/processing file {file_path}: {e_read}")
        return scored_paras
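    # Scoring sketch for _search_single_file: a paragraph matching two keywords
    # whose first occurrences fall within MAX_SNIPPET_LEN * 1.5 = 750 characters
    # of each other scores 2 + 1 (proximity bonus) = 3; a single-keyword match
    # scores 1.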
    def _jaccard_similarity(self, set1: Set[str], set2: Set[str]) -> float:
        """Calculates the Jaccard similarity between two sets of words (used for deduplication)."""
        intersection = len(set1.intersection(set2))
        union = len(set1.union(set2))
        if union == 0:
            return 1.0 if not set1 and not set2 else 0.0
        return intersection / union
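    # Worked example: {"install", "the", "package"} and {"install", "a", "package"}
    # share 2 of 4 distinct words, so similarity = 2 / 4 = 0.5 -- safely below
    # the 0.8 near-duplicate threshold applied in search_files below.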
    def search_files(self, query_keywords: List[str], max_snippets_to_return: int) -> str:
        """
        Searches documentation files, ranks paragraphs/snippets, deduplicates,
        and returns a formatted context string.
        """
        keyword_set = set(query_keywords)
        logger.info(f"Searching {len(self.doc_files)} files for keywords: {keyword_set}")
        all_scored_snippets: List[Tuple[int, Path, str]] = []
        total_files = len(self.doc_files)
        # --- File Iteration & Snippet Collection ---
        for i, file_path in enumerate(self.doc_files):
            # Progress indicator
            if (i + 1) % 20 == 0 or (i + 1) == total_files:
                progress = int(100 * (i + 1) / total_files)
                sys.stdout.write(f"\rSearching... [{i+1}/{total_files} files ({progress}%)]")
                sys.stdout.flush()
            file_snippets = self._search_single_file(file_path, keyword_set)
            for score, snippet in file_snippets:
                all_scored_snippets.append((score, file_path, snippet))
        sys.stdout.write("\r" + " " * 80 + "\r")  # Clear the progress line
        sys.stdout.flush()
        logger.info(f"Finished searching. Found {len(all_scored_snippets)} potential snippets.")
        # --- Ranking & Deduplication ---
        all_scored_snippets.sort(key=lambda x: (-x[0], x[1].as_posix()))  # Sort by score desc, then path asc
        final_context = ""
        total_chars = 0
        snippet_count = 0
        added_snippet_norm_hashes: Set[int] = set()  # For exact duplicates
        added_snippet_word_sets: List[Set[str]] = []  # For near-duplicates (Jaccard)
        included_files: Set[Path] = set()  # Track files already cited in the context
        for score, file_path, snippet in all_scored_snippets:
            if snippet_count >= max_snippets_to_return or total_chars >= MAX_CONTEXT_CHARS:
                break  # Stop once limits are reached
            # Normalize for robust duplicate checking
            normalized_snippet = ' '.join(re.findall(r'\b\w+\b', snippet.lower()))
            if not normalized_snippet:
                continue
            # Check for an exact duplicate (hash)
            norm_hash = hash(normalized_snippet)
            if norm_hash in added_snippet_norm_hashes:
                continue
            # Check for a near-duplicate (Jaccard)
            current_word_set = set(normalized_snippet.split())
            is_near_duplicate = any(
                self._jaccard_similarity(current_word_set, existing_set) > 0.8
                for existing_set in added_snippet_word_sets
            )
            if is_near_duplicate:
                logger.debug(f"Skipping near-duplicate snippet from {file_path.name}")
                continue
            # --- Assemble Context ---
            try:
                relative_path = os.path.relpath(file_path, self.docs_path.parent)
            except ValueError:
                relative_path = file_path.as_posix()  # Fallback path
            file_marker = f"\n--- Context from: {relative_path} ---\n"
            context_to_add = ""
            if file_path not in included_files:
                context_to_add += file_marker
                included_files.add(file_path)
            context_to_add += snippet.strip() + "\n"
            # Add the snippet if it fits within the character limit
            estimated_len = len(context_to_add)
            if total_chars + estimated_len <= MAX_CONTEXT_CHARS:
                final_context += context_to_add
                total_chars += estimated_len
                snippet_count += 1
                added_snippet_norm_hashes.add(norm_hash)
                added_snippet_word_sets.append(current_word_set)
            else:
                logger.debug("Next snippet exceeds the context limit. Stopping.")
                break
        logger.info(f"Selected {snippet_count} unique snippets ({total_chars} chars) for context.")
        return final_context.strip()
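    # For reference, the context string assembled above takes this shape
    # (file names are illustrative):
    #
    #   --- Context from: source/installation.rst ---
    #   ...snippet text...
    #   --- Context from: source/configuration.rst ---
    #   ...snippet text...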
    def ask_llm_for_guidance(self, user_query: str, context: str) -> Optional[str]:
        """
        Asks the LLM whether the context is sufficient, or whether the search
        should be refined with more specific keywords from the documentation.
        """
        # --- System Prompt: Focus on Documentation Search Assistance ---
        system_prompt = (
            "You are an expert search assistant for software documentation. Analyze the user's QUESTION "
            "and the provided CONTEXT snippets retrieved from the documentation.\n"
            "Determine if the CONTEXT likely contains enough information to answer the QUESTION, "
            "or if a more focused search using different keywords (like specific function names, parameters, "
            "or config keys mentioned in the question but perhaps missing from the context) would be better.\n"
            "Your response MUST be in one of the following two formats ONLY:\n"
            "1. If context seems sufficient: CONTEXT_SUFFICIENT\n"
            "2. If context is insufficient or needs refinement: SEARCH_FOR: keyword1 keyword2 keyword3\n"
            "Provide 1 to 5 specific, relevant keywords if refinement is needed. Focus on technical terms.\n"
            "Do NOT provide explanations or any other text."
        )
        prompt_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"CONTEXT:\n```\n{context or '[No context found]'}\n```\n\nQUESTION:\n{user_query}"}
        ]
        logger.info("Asking LLM for search guidance...")
        sys.stdout.write("Waiting for guidance LLM...")
        sys.stdout.flush()
        accumulated_content = ""
        stream_started = False
        start_time = time.time()
        try:
            stream = self.client.chat.completions.create(
                model=self.model_name,
                messages=prompt_messages,
                temperature=0.1,  # Low temperature for structured output
                max_tokens=GUIDANCE_MAX_TOKENS,
                stream=True,
                stop=["\n"]  # Encourage a single-line response
            )
            # Accumulate the streamed response
            for chunk in stream:
                if not stream_started:
                    duration = time.time() - start_time
                    sys.stdout.write(f"\rGuidance stream started ({duration:.2f}s). Receiving... ")
                    sys.stdout.flush()
                    stream_started = True
                content_piece = chunk.choices[0].delta.content
                if content_piece:
                    accumulated_content += content_piece
                    sys.stdout.write(".")
                    sys.stdout.flush()
            # --- Process Guidance Response ---
            duration = time.time() - start_time
            sys.stdout.write(f"\rGuidance stream finished ({duration:.2f}s). Processing...{' ' * 20}\n")
            sys.stdout.flush()
            if not accumulated_content:
                logger.warning("LLM guidance stream provided no content.")
                return None
            cleaned_content = accumulated_content.strip()
            logger.debug(f"LLM raw guidance response: '{cleaned_content}'")
            # Parse the structured response
            if cleaned_content.startswith("SEARCH_FOR:"):
                keywords_str = cleaned_content[len("SEARCH_FOR:"):].strip()
                keywords_str = re.sub(r'[.,;!?]$', '', keywords_str).strip()
                keywords_str = ' '.join(keywords_str.split())  # Normalize spaces
                if keywords_str:
                    logger.info(f"LLM suggests refining search with keywords: '{keywords_str}'")
                    return keywords_str
                else:
                    logger.warning("LLM guidance said 'SEARCH_FOR:' but provided no keywords.")
                    return None
            elif "CONTEXT_SUFFICIENT" in cleaned_content:
                logger.info("LLM guidance indicates the context is likely sufficient.")
                return None
            else:
                logger.warning(f"LLM guidance response format unexpected: '{cleaned_content}'. Proceeding with initial context.")
                return None
        except Exception as e:
            _handle_openai_error(e, start_time, "LLM guidance request")
            return None  # Indicate failure
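    # The only two well-formed guidance replies, illustratively (the keywords
    # here are hypothetical):
    #   CONTEXT_SUFFICIENT
    #   SEARCH_FOR: retry_policy timeout connection_pool
    # Anything else is logged as unexpected and the initial context is kept.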
    def ask_llm_for_final_answer(self, user_query: str, context: str) -> bool:
        """
        Queries the LLM with the final context to generate an answer grounded
        *only* in the provided documentation snippets. Streams the response.
        Returns True if a complete answer was streamed, False otherwise.
        """
        # --- System Prompt: Grounding the LLM as a Documentation Specialist ---
        system_prompt = (
            "You are a specialist assistant knowledgeable about a specific software package, based *only* on the documentation provided.\n"
            "Carefully analyze the user's QUESTION and the provided CONTEXT snippets from the documentation.\n"
            "Answer the user's QUESTION accurately and concisely using *only* the information present in the CONTEXT.\n"
            "If the CONTEXT does not contain information to answer the question or is empty, you MUST explicitly state that the information is not available in the provided snippets.\n"
            "Do not make assumptions or use external knowledge.\n"
            "Be direct. Do not include conversational filler like 'Based on the provided context...'.\n"
            "If providing code examples from the context, ensure they are formatted correctly using markdown."
        )
        prompt_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"CONTEXT:\n```\n{context or '[No relevant context found after search]'}\n```\n\nQUESTION:\n{user_query}"}
        ]
        if not context.strip():
            logger.warning("Final context is empty. The LLM should state that the information is missing.")
        logger.info("Sending final query to LLM for answer generation...")
        sys.stdout.write("Waiting for final answer LLM stream...")
        sys.stdout.flush()
        stream_started = False
        start_time = time.time()
        first_chunk_received = False
        error_occurred = False
        print("\n--- LLM Answer ---")  # Print the header before the stream starts
        try:
            stream = self.client.chat.completions.create(
                model=self.model_name,
                messages=prompt_messages,
                temperature=0.3,  # Lower temperature for fact-based generation
                max_tokens=FINAL_ANSWER_MAX_TOKENS,
                stream=True
            )
            # Print response chunks as they arrive
            for chunk in stream:
                if not stream_started:
                    duration = time.time() - start_time
                    sys.stdout.write(f"\rStream started ({duration:.2f}s). Receiving data...\n")
                    sys.stdout.flush()
                    stream_started = True
                content_piece = chunk.choices[0].delta.content
                if content_piece:
                    sys.stdout.write(content_piece)
                    sys.stdout.flush()
                    first_chunk_received = True
            # --- Final Stream Cleanup ---
            if stream_started:
                if not first_chunk_received:
                    sys.stdout.write("[No content received from LLM stream]\n")
                    sys.stdout.flush()
                else:
                    print()  # Ensure a final newline
            else:
                sys.stdout.write("\r" + " " * 80 + "\r")  # Clear the waiting message
                logger.error("LLM stream did not start.")
                error_occurred = True
        except Exception as e:
            _handle_openai_error(e, start_time, "LLM final answer request")
            error_occurred = True  # Any exception means the answer did not complete
            if stream_started and first_chunk_received:
                print()  # Newline after partial output on error
        return not error_occurred  # True if the full answer was streamed
    def run_interactive(self):
        """Runs the main interactive Q&A loop for documentation."""
        print("\n--- Documentation Q&A Agent ---")
        print(f"Model: {self.model_name} ({self.client.base_url})")
        print(f"Docs Path: {self.docs_path.resolve()}")
        print("Enter your question about the documentation (or type 'quit'):")
        while True:
            try:
                user_query = input("> ")
                if user_query.lower().strip() == 'quit':
                    break
                if not user_query.strip():
                    continue
                start_time_turn = time.time()
                logger.info(f"User query: '{user_query}'")
                # === RAG Pipeline Steps ===
                # 1. Keyword Extraction
                initial_keywords = extract_keywords(user_query)
                if not initial_keywords:
                    print("Could not extract keywords. Please rephrase.", file=sys.stderr)
                    continue
                logger.info(f"Initial keywords: {initial_keywords}")
                # 2. Initial Retrieval
                logger.info("Performing initial documentation search...")
                context = self.search_files(initial_keywords, MAX_INITIAL_SNIPPETS)
                # 3. LLM Guidance (Optional Refinement)
                refined_keywords_str = self.ask_llm_for_guidance(user_query, context)
                # 4. Refined Retrieval (if guided)
                if refined_keywords_str:
                    refined_keywords_list = extract_keywords(refined_keywords_str)
                    if refined_keywords_list:
                        logger.info("Performing refined documentation search...")
                        context = self.search_files(refined_keywords_list, MAX_REFINED_SNIPPETS)
                    else:
                        logger.warning("LLM suggested refinement, but no keywords were extracted. Using the initial context.")
                # 5. Grounded Generation
                success = self.ask_llm_for_final_answer(user_query, context)
                # === Loop Cleanup ===
                end_time_turn = time.time()
                print("\n-----------------------------")
                if not success:
                    print("Error generating final answer.", file=sys.stderr)
                logger.info(f"Turn completed in {end_time_turn - start_time_turn:.2f} seconds.")
                print("\nEnter your next question (or type 'quit'):")
            except EOFError:
                print("\nExiting agent (EOF).")
                break
            except KeyboardInterrupt:
                print("\nExiting agent (interrupted).")
                break
        print("\nDocumentation Q&A session finished.")
# --- Main Execution ---
def main():
    """Parses arguments and runs the Documentation Q&A Agent."""
    parser = argparse.ArgumentParser(
        description="Query technical documentation using RAG and an LLM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-d", "--docs", type=Path, default=DEFAULT_DOCS_PATH,
        help="Path to the documentation source directory (containing .rst/.txt files)."
    )
    parser.add_argument(
        "-u", "--url", default=DEFAULT_LLM_BASE_URL,
        help="Base URL of the OpenAI-compatible LLM API endpoint."
    )
    parser.add_argument(
        "-k", "--key", default=os.getenv('OPENAI_API_KEY', 'not-needed'),
        help="API key for the LLM endpoint (uses the OPENAI_API_KEY env var if set)."
    )
    parser.add_argument(
        "-m", "--model", default=DEFAULT_MODEL_NAME,
        help="Model name to use for LLM API requests."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Enable DEBUG level logging."
    )
    args = parser.parse_args()
    # Set the logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.INFO)
    logger.debug("Debug logging enabled.")
    if args.key == 'not-needed':
        logger.info("Using the placeholder API key. Set OPENAI_API_KEY or use --key if required.")
    try:
        # Create and run the agent
        agent = DocAgent(
            docs_path=args.docs,
            base_url=args.url,
            api_key=args.key,
            model_name=args.model
        )
        agent.run_interactive()
    except (ValueError, RuntimeError, FileNotFoundError) as e:
        logger.error(f"Initialization/configuration error: {e}", exc_info=args.verbose)
        sys.exit(1)
    except Exception as e:
        logger.error(f"An unexpected critical error occurred: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()