RAG on Read the Docs
# -*- coding: utf-8 -*-
"""
doc_agent.py: A Python script implementing a Retrieval-Augmented Generation (RAG)
agent specifically designed for querying technical documentation (e.g., from
Read the Docs source files like .rst or .txt).
Purpose: To provide more accurate and context-aware answers to questions
about a specific software package or project than traditional keyword search,
by leveraging an LLM grounded with retrieved documentation snippets.
Features:
- Searches local documentation files (.rst, .txt) for relevant context.
- Uses keyword extraction and scoring, with optional LLM-guided refinement.
- Queries an LLM (local or remote OpenAI-compatible API) with retrieved
context to generate answers based *only* on the documentation.
- Streams LLM responses for better interactivity.
- Configurable via command-line arguments.
"""
import os
import openai # OpenAI client library
import argparse
import re
import sys
import time
import logging
from pathlib import Path
from typing import List, Optional, Set, Tuple # Improved typing clarity
# --- Configuration & Constants ---
# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__) # Use a dedicated logger
# Default path for documentation source files (e.g., Sphinx source directory)
DEFAULT_DOCS_PATH = Path('./source')
# Default URL for the OpenAI-compatible LLM endpoint
DEFAULT_LLM_BASE_URL = 'http://localhost:8080/v1'
# Default model name (can be overridden)
DEFAULT_MODEL_NAME = 'local-model'
# --- RAG Parameters ---
# Max characters of documentation context to send to the LLM
MAX_CONTEXT_CHARS = 3800
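# (~950 tokens at the rough 4-chars-per-token heuristic; adjust to fit your model's context window)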
# Max character length for a single documentation snippet (paragraph)
MAX_SNIPPET_LEN = 500
# Max snippets to retrieve in the *initial* search
MAX_INITIAL_SNIPPETS = 10
# Max snippets to retrieve in the *refined* search (after LLM guidance)
MAX_REFINED_SNIPPETS = 8
# Timeout for LLM API requests
REQUEST_TIMEOUT = 180.0
# Max tokens for the LLM guidance response (keep it short)
GUIDANCE_MAX_TOKENS = 100
# Max tokens for the final generated answer
FINAL_ANSWER_MAX_TOKENS = 1500
# --- Keyword Extraction ---
# Stop words to ignore in user queries (customize for your domain if needed)
STOP_WORDS = {
"a", "about", "above", "after", "again", "against", "all", "am", "an", "and",
"any", "are", "as", "at", "be", "because", "been", "before", "being",
"below", "between", "both", "but", "by", "can", "cannot", "could",
"did", "do", "does", "doing", "down", "during", "each", "few", "for",
"from", "further", "had", "has", "have", "having", "he", "her", "here",
"hers", "herself", "him", "himself", "his", "how", "i", "if", "in", "into",
"is", "it", "its", "itself", "let", "me", "more", "most", "my", "myself",
"no", "nor", "not", "of", "off", "on", "once", "only", "or", "other",
"ought", "our", "ours", "ourselves", "out", "over", "own", "same", "she",
"should", "so", "some", "such", "than", "that", "the", "their", "theirs",
"them", "themselves", "then", "there", "these", "they", "this", "those",
"through", "to", "too", "under", "until", "up", "very", "was", "we", "were",
"what", "when", "where", "which", "while", "who", "whom", "why", "with",
"would", "you", "your", "yours", "yourself", "yourselves",
# Common query words that are often noise for documentation search
"using", "get", "set", "configure", "run", "install", "find", "tell", "me", "about",
"what's", "whats", "how's", "hows", "show", "example", "examples"
}
# --- Helper Functions ---
def extract_keywords(query: str, existing_keywords: Optional[List[str]] = None) -> List[str]:
"""
Extracts potential keywords from a user query about software documentation.
Filters stop words and short words, combines with existing keywords.
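
    Illustrative example (query verbs like "configure" count as stop words,
    so only domain terms survive):

        >>> sorted(extract_keywords("How do I configure the cache size?"))
        ['cache', 'size']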
"""
words = re.findall(r'\b\w+\b', query.lower())
new_keywords = {word for word in words if word not in STOP_WORDS and len(word) > 2}
combined = set(existing_keywords or []) | new_keywords
# Consider adding stemming or lemmatization here for more advanced matching
return list(combined)
def _handle_openai_error(e: Exception, start_time: float, context_msg: str = "LLM request") -> bool:
"""Handles common OpenAI API errors, logs them, and returns status."""
duration = time.time() - start_time
sys.stdout.write("\r" + " " * 80 + "\r") # Clear console line
sys.stdout.flush()
# Log specific error types for better diagnosis
if isinstance(e, openai.APITimeoutError):
logger.error(f"{context_msg} timed out after {duration:.2f}s.")
elif isinstance(e, openai.APIConnectionError):
logger.error(f"Connection error during {context_msg} ({duration:.2f}s): {e}")
elif isinstance(e, openai.RateLimitError):
logger.error(f"Rate limit exceeded during {context_msg} ({duration:.2f}s): {e}")
elif isinstance(e, openai.APIStatusError):
logger.error(f"API status error during {context_msg} ({duration:.2f}s): Status={e.status_code}, Response={e.response}")
elif isinstance(e, openai.APIError):
logger.error(f"Generic API error during {context_msg} ({duration:.2f}s): {e}")
else:
logger.error(f"Unexpected error during {context_msg} ({duration:.2f}s): {e}", exc_info=True)
return False # Unexpected failure
return True # Expected/handled API failure
# --- Document Agent Class ---
class DocAgent:
"""
Agent orchestrating the RAG process for querying documentation.
Searches local files, interacts with LLM for guidance and final answers.
"""
def __init__(self, docs_path: Path, base_url: str, api_key: str, model_name: str):
"""
Initializes the documentation Q&A agent.
Args:
docs_path: Path to the root directory containing documentation
source files (e.g., .rst, .txt).
base_url: The base URL of the OpenAI-compatible API endpoint.
api_key: The API key for the endpoint.
model_name: The name of the model to use.
"""
self.docs_path = docs_path
self.model_name = model_name
self.doc_files = self._find_doc_files() # Find documentation files
if not self.doc_files:
raise ValueError(f"No documentation files (.rst, .txt) found in {self.docs_path}. Check the --docs path.")
logger.info(f"Initializing OpenAI client: URL='{base_url}', Model='{model_name}'")
try:
self.client = openai.OpenAI(
base_url=base_url,
api_key=api_key,
timeout=REQUEST_TIMEOUT,
)
logger.info("OpenAI client initialized successfully.")
except Exception as e:
logger.error(f"Error initializing OpenAI client: {e}", exc_info=True)
raise RuntimeError(f"Failed to initialize OpenAI client. Check API URL ('{base_url}') and ensure the server is running.") from e
def _find_doc_files(self) -> List[Path]:
"""Recursively finds documentation files (.rst, .txt)"""
if not self.docs_path.is_dir():
logger.error(f"Documentation directory not found: {self.docs_path}")
return []
logger.info(f"Scanning for .rst and .txt files in '{self.docs_path}'...")
rst_files = list(self.docs_path.rglob('*.rst'))
txt_files = list(self.docs_path.rglob('*.txt'))
# Add more file types here if needed (e.g., .md)
doc_files = [f for f in rst_files + txt_files if f.is_file()]
logger.info(f"Found {len(doc_files)} documentation files.")
return doc_files
def _search_single_file(self, file_path: Path, query_keywords: Set[str]) -> List[Tuple[int, str]]:
"""Searches a single documentation file for keywords, returns scored snippets."""
scored_paras = []
try:
# Handle potential encoding issues in documentation files
try:
content = file_path.read_text(encoding='utf-8')
except UnicodeDecodeError:
content = file_path.read_text(encoding='latin-1') # Fallback
logger.warning(f"Used latin-1 fallback encoding for '{file_path.name}'")
# Split by paragraph (common in .rst/.txt docs). Consider more robust chunking for complex docs.
paragraphs = re.split(r'\n\s*\n+', content)
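            # e.g. "First para.\n\nSecond para." splits into ["First para.", "Second para."]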
for para in paragraphs:
para_strip = para.strip()
# Basic filtering of paragraphs
if not para_strip or len(para_strip) < 15: continue # Skip very short lines/paragraphs
para_lower = para_strip.lower()
found_keywords = {kw for kw in query_keywords if kw in para_lower}
score = len(found_keywords)
if score > 0:
# Optional: Add proximity bonus if multiple keywords are close
if score > 1:
indices = [para_lower.find(kw) for kw in found_keywords]
valid_indices = [i for i in indices if i != -1]
if len(valid_indices) > 1:
span = max(valid_indices) - min(valid_indices)
if span < MAX_SNIPPET_LEN * 1.5: score += 1
# Extract snippet centered around the first keyword
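                    # With MAX_SNIPPET_LEN = 500 this keeps roughly 166 chars
                    # before the first keyword and 333 after it (a 1/3 : 2/3 split).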
try:
first_kw_index = min(idx for kw in found_keywords if (idx := para_lower.find(kw)) != -1)
first_kw = next(kw for kw in found_keywords if para_lower.find(kw) == first_kw_index)
start = max(0, first_kw_index - MAX_SNIPPET_LEN // 3)
end = min(len(para_strip), first_kw_index + len(first_kw) + (2 * MAX_SNIPPET_LEN // 3))
snippet = para_strip[start:end]
# Add ellipses if snippet is cut off
prefix = "..." if start > 0 else ""
suffix = "..." if end < len(para_strip) else ""
snippet = prefix + snippet.strip() + suffix
if len(snippet) > MAX_SNIPPET_LEN + 6: # Final trim
snippet = snippet[:MAX_SNIPPET_LEN+3] + "..."
scored_paras.append((score, snippet))
except (StopIteration, ValueError):
logger.debug(f"Minor issue extracting snippet from {file_path.name}")
# Fallback: use paragraph if short enough
if len(para_strip) <= MAX_SNIPPET_LEN: scored_paras.append((score, para_strip))
except Exception as e_snip:
logger.warning(f"Unexpected error during snippet extraction in {file_path.name}: {e_snip}")
except Exception as e_read:
logger.warning(f"Error reading/processing file {file_path}: {e_read}")
return scored_paras
def _jaccard_similarity(self, set1: Set[str], set2: Set[str]) -> float:
"""Calculates Jaccard similarity between two sets of words (for deduplication)."""
intersection = len(set1.intersection(set2))
union = len(set1.union(set2))
        if union == 0: return 1.0  # Union empty implies both sets empty; treat as identical
return intersection / union
def search_files(self, query_keywords: List[str], max_snippets_to_return: int) -> str:
"""
Searches documentation files, ranks paragraphs/snippets, deduplicates,
and returns a formatted context string.
"""
keyword_set = set(query_keywords)
logger.info(f"Searching {len(self.doc_files)} files for keywords: {keyword_set}")
all_scored_snippets: List[Tuple[int, Path, str]] = []
total_files = len(self.doc_files)
# --- File Iteration & Snippet Collection ---
for i, file_path in enumerate(self.doc_files):
# Progress indicator
if (i + 1) % 20 == 0 or (i + 1) == total_files:
progress = int(100 * (i + 1) / total_files)
sys.stdout.write(f"\rSearching... [{i+1}/{total_files} files ({progress}%)]")
sys.stdout.flush()
file_snippets = self._search_single_file(file_path, keyword_set)
for score, snippet in file_snippets:
all_scored_snippets.append((score, file_path, snippet))
sys.stdout.write("\r" + " " * 80 + "\r") # Clear progress line
sys.stdout.flush()
logger.info(f"Finished searching. Found {len(all_scored_snippets)} potential snippets.")
# --- Ranking & Deduplication ---
all_scored_snippets.sort(key=lambda x: (-x[0], x[1].as_posix())) # Sort by score desc, path asc
final_context = ""
total_chars = 0
snippet_count = 0
added_snippet_norm_hashes: Set[int] = set() # For exact duplicates
added_snippet_word_sets: List[Set[str]] = [] # For near-duplicates (Jaccard)
included_files: Set[Path] = set() # Track files mentioned in context
for score, file_path, snippet in all_scored_snippets:
if snippet_count >= max_snippets_to_return or total_chars >= MAX_CONTEXT_CHARS:
break # Stop if limits reached
# Normalize for robust duplicate checking
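            # e.g. "Cache size: 10MB." -> "cache size 10mb"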
normalized_snippet = ' '.join(re.findall(r'\b\w+\b', snippet.lower()))
if not normalized_snippet: continue
# Check exact duplicate (hash)
norm_hash = hash(normalized_snippet)
if norm_hash in added_snippet_norm_hashes: continue
# Check near-duplicate (Jaccard)
current_word_set = set(normalized_snippet.split())
is_near_duplicate = any(self._jaccard_similarity(current_word_set, existing_set) > 0.8 for existing_set in added_snippet_word_sets)
if is_near_duplicate:
logger.debug(f"Skipping near-duplicate snippet from {file_path.name}")
continue
# --- Assemble Context ---
try:
relative_path = os.path.relpath(file_path, self.docs_path.parent)
except ValueError:
relative_path = file_path.as_posix() # Fallback path
file_marker = f"\n--- Context from: {relative_path} ---\n"
context_to_add = ""
if file_path not in included_files:
context_to_add += file_marker
included_files.add(file_path)
context_to_add += snippet.strip() + "\n"
# Add if within character limits
estimated_len = len(context_to_add)
if total_chars + estimated_len <= MAX_CONTEXT_CHARS:
final_context += context_to_add
total_chars += estimated_len
snippet_count += 1
added_snippet_norm_hashes.add(norm_hash)
added_snippet_word_sets.append(current_word_set)
else:
logger.debug(f"Next snippet exceeds context limit. Stopping.")
break
logger.info(f"Selected {snippet_count} unique snippets ({total_chars} chars) for context.")
return final_context.strip()
def ask_llm_for_guidance(self, user_query: str, context: str) -> Optional[str]:
"""
Asks LLM if context is sufficient or if search needs refinement
using more specific keywords from the documentation context.
"""
# --- System Prompt: Focus on Documentation Search Assistance ---
system_prompt = (
"You are an expert search assistant for software documentation. Analyze the user's QUESTION "
"and the provided CONTEXT snippets retrieved from the documentation.\n"
"Determine if the CONTEXT likely contains enough information to answer the QUESTION, "
"or if a more focused search using different keywords (like specific function names, parameters, config keys mentioned in the question but maybe missing in context) would be better.\n"
"Your response MUST be in one of the following two formats ONLY:\n"
"1. If context seems sufficient: CONTEXT_SUFFICIENT\n"
"2. If context is insufficient or needs refinement: SEARCH_FOR: keyword1 keyword2 keyword3\n"
"Provide 1 to 5 specific, relevant keywords if refinement is needed. Focus on technical terms.\n"
"Do NOT provide explanations or any other text."
)
prompt_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"CONTEXT:\n```\n{context or '[No context found]'}\n```\n\nQUESTION:\n{user_query}"}
]
logger.info("Asking LLM for Search Guidance...")
sys.stdout.write("Waiting for guidance LLM...")
sys.stdout.flush()
accumulated_content = ""
stream_started = False
start_time = time.time()
try:
stream = self.client.chat.completions.create(
model=self.model_name,
messages=prompt_messages,
temperature=0.1, # Low temp for structured output
max_tokens=GUIDANCE_MAX_TOKENS,
stream=True,
stop=["\n"] # Encourage single-line response
)
# Accumulate streamed response
for chunk in stream:
if not stream_started:
duration = time.time() - start_time
sys.stdout.write(f"\rGuidance stream started ({duration:.2f}s). Receiving... ")
sys.stdout.flush()
stream_started = True
content_piece = chunk.choices[0].delta.content
if content_piece:
accumulated_content += content_piece
sys.stdout.write(".")
sys.stdout.flush()
# --- Process Guidance Response ---
duration = time.time() - start_time
sys.stdout.write(f"\rGuidance stream finished ({duration:.2f}s). Processing...{' '*20}\n")
sys.stdout.flush()
if not accumulated_content:
logger.warning("LLM guidance stream provided no content.")
return None
cleaned_content = accumulated_content.strip()
logger.debug(f"LLM Raw Guidance Response: '{cleaned_content}'")
# Parse the structured response
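            # Two accepted shapes (keywords below are purely illustrative):
            #   "CONTEXT_SUFFICIENT"                  -> None (keep initial context)
            #   "SEARCH_FOR: cache_size eviction ttl" -> "cache_size eviction ttl"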
if cleaned_content.startswith("SEARCH_FOR:"):
keywords_str = cleaned_content[len("SEARCH_FOR:"):].strip()
keywords_str = re.sub(r'[.,;!?]$', '', keywords_str).strip()
keywords_str = ' '.join(keywords_str.split()) # Normalize spaces
if keywords_str:
logger.info(f"LLM suggests refining search with keywords: '{keywords_str}'")
return keywords_str
else:
logger.warning("LLM guidance 'SEARCH_FOR:' but no keywords provided.")
return None
elif "CONTEXT_SUFFICIENT" in cleaned_content:
logger.info("LLM guidance indicates context is likely sufficient.")
return None
else:
logger.warning(f"LLM guidance response format unexpected: '{cleaned_content}'. Proceeding with initial context.")
return None
except Exception as e:
_handle_openai_error(e, start_time, "LLM guidance request")
return None # Indicate failure
def ask_llm_for_final_answer(self, user_query: str, context: str) -> bool:
"""
Queries LLM with final context to generate an answer grounded
*only* in the provided documentation snippets. Streams the response.
"""
# --- System Prompt: Grounding the LLM as a Documentation Specialist ---
system_prompt = (
"You are a specialist assistant knowledgeable about a specific software package, based *only* on the documentation provided.\n"
"Carefully analyze the user's QUESTION and the provided CONTEXT snippets from the documentation.\n"
"Answer the user's QUESTION accurately and concisely using *only* the information present in the CONTEXT.\n"
"If the CONTEXT does not contain information to answer the question or is empty, you MUST explicitly state that the information is not available in the provided snippets.\n"
"Do not make assumptions or use external knowledge.\n"
"Be direct. Do not include conversational filler like 'Based on the provided context...'.\n"
"If providing code examples from the context, ensure they are formatted correctly using markdown."
)
prompt_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"CONTEXT:\n```\n{context or '[No relevant context found after search]'}\n```\n\nQUESTION:\n{user_query}"}
]
if not context.strip():
logger.warning("Final context is empty. LLM should state info is missing.")
logger.info("Sending Final Query to LLM for Answer Generation...")
sys.stdout.write("Waiting for final answer LLM stream...")
sys.stdout.flush()
stream_started = False
start_time = time.time()
first_chunk_received = False
error_occurred = False
print("\n--- LLM Answer ---") # Print header before stream starts
try:
stream = self.client.chat.completions.create(
model=self.model_name,
messages=prompt_messages,
temperature=0.3, # Lower temp for fact-based generation
max_tokens=FINAL_ANSWER_MAX_TOKENS,
stream=True
)
# Print response chunks as they arrive
for chunk in stream:
if not stream_started:
duration = time.time() - start_time
sys.stdout.write(f"\rStream started ({duration:.2f}s). Receiving data...\n")
sys.stdout.flush()
stream_started = True
content_piece = chunk.choices[0].delta.content
if content_piece:
sys.stdout.write(content_piece)
sys.stdout.flush()
                    first_chunk_received = True
# --- Final Stream Cleanup ---
if stream_started:
if not first_chunk_received:
sys.stdout.write("[No content received from LLM stream]\n")
sys.stdout.flush()
else:
print() # Ensure final newline
else:
sys.stdout.write("\r" + " " * 80 + "\r") # Clear waiting message
logger.error("LLM stream did not start.")
error_occurred = True
        except Exception as e:
            _handle_openai_error(e, start_time, "LLM final answer request")
            error_occurred = True  # Any exception means no complete answer was produced
            if stream_started and first_chunk_received: print() # Newline after partial output on error
return not error_occurred # True if successful
def run_interactive(self):
"""Runs the main interactive Q&A loop for documentation."""
print("\n--- Documentation Q&A Agent ---")
print(f"Model: {self.model_name} ({self.client.base_url})")
print(f"Docs Path: {self.docs_path.resolve()}")
print("Enter your question about the documentation (or type 'quit'):")
while True:
try:
user_query = input("> ")
if user_query.lower().strip() == 'quit': break
if not user_query.strip(): continue
start_time_turn = time.time()
logger.info(f"User Query: '{user_query}'")
# === RAG Pipeline Steps ===
# 1. Keyword Extraction
initial_keywords = extract_keywords(user_query)
if not initial_keywords:
print("Could not extract keywords. Please rephrase.", file=sys.stderr)
continue
logger.info(f"Initial keywords: {initial_keywords}")
# 2. Initial Retrieval
logger.info("Performing initial documentation search...")
context = self.search_files(initial_keywords, MAX_INITIAL_SNIPPETS)
# 3. LLM Guidance (Optional Refinement)
refined_keywords_str = self.ask_llm_for_guidance(user_query, context)
# 4. Refined Retrieval (if guided)
if refined_keywords_str:
refined_keywords_list = extract_keywords(refined_keywords_str)
if refined_keywords_list:
logger.info("Performing refined documentation search...")
context = self.search_files(refined_keywords_list, MAX_REFINED_SNIPPETS)
else:
logger.warning("LLM suggested refinement, but no keywords extracted. Using initial context.")
# 5. Grounded Generation
success = self.ask_llm_for_final_answer(user_query, context)
# === Loop Cleanup ===
end_time_turn = time.time()
print("\n-----------------------------")
if not success:
print("Error generating final answer.", file=sys.stderr)
logger.info(f"Turn completed in {end_time_turn - start_time_turn:.2f} seconds.")
print("\nEnter your next question (or type 'quit'):")
            except EOFError:
                print("\nExiting agent (EOF).")
                break
            except KeyboardInterrupt:
                print("\nExiting agent (Interrupted).")
                break
print("\nDocumentation Q&A session finished.")
# --- Main Execution ---
def main():
"""Parses args and runs the Documentation Q&A Agent."""
parser = argparse.ArgumentParser(
description="Query technical documentation using RAG and an LLM.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-d", "--docs", type=Path, default=DEFAULT_DOCS_PATH,
help="Path to the documentation source directory (containing .rst/.txt files)."
)
parser.add_argument(
"-u", "--url", default=DEFAULT_LLM_BASE_URL,
help="Base URL of the OpenAI-compatible LLM API endpoint."
)
parser.add_argument(
"-k", "--key", default=os.getenv('OPENAI_API_KEY', 'not-needed'),
help="API key for the LLM endpoint (uses OPENAI_API_KEY env var if set)."
)
parser.add_argument(
"-m", "--model", default=DEFAULT_MODEL_NAME,
help="Model name to use for LLM API requests."
)
parser.add_argument(
"-v", "--verbose", action="store_true",
help="Enable DEBUG level logging."
)
args = parser.parse_args()
# Setup logging level
logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.INFO)
logger.debug("Debug logging enabled.")
if args.key == 'not-needed':
logger.info("Using default placeholder API key. Set OPENAI_API_KEY or use --key if required.")
try:
# Create and run the agent
agent = DocAgent(
docs_path=args.docs,
base_url=args.url,
api_key=args.key,
model_name=args.model
)
agent.run_interactive()
except (ValueError, RuntimeError, FileNotFoundError) as e:
logger.error(f"Initialization/Configuration Error: {e}", exc_info=args.verbose)
sys.exit(1)
except Exception as e:
logger.error(f"An unexpected critical error occurred: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()