autogen agents
# --- config.py ---
import os
from pydantic_settings import BaseSettings
from dotenv import load_dotenv
import datetime
load_dotenv()
class Settings(BaseSettings):
"""
Configuration settings for the enhanced application.
Reads values from environment variables.
"""
# AWS Bedrock Settings
AWS_REGION_NAME: str = os.getenv("AWS_REGION_NAME", "us-east-1")
# Consider a more powerful model for complex reasoning like Sonnet or Opus
BEDROCK_MODEL_ID: str = os.getenv("BEDROCK_MODEL_ID", "anthropic.claude-3-sonnet-20240229-v1:0")
BEDROCK_EMBEDDING_MODEL_ID: str = os.getenv("BEDROCK_EMBEDDING_MODEL_ID", "amazon.titan-embed-text-v1")
# PGVector (PostgreSQL) Settings - Unified Store
PG_HOST: str = os.getenv("PG_HOST", "localhost")
PG_PORT: int = int(os.getenv("PG_PORT", "5432"))
PG_USER: str = os.getenv("PG_USER", "user")
PG_PASSWORD: str = os.getenv("PG_PASSWORD", "password")
PG_DBNAME: str = os.getenv("PG_DBNAME", "knowledgebase")
PG_COLLECTION_NAME: str = os.getenv("PG_COLLECTION_NAME", "unified_embeddings") # LangChain uses collection name
PG_VECTOR_DIMENSION: int = int(os.getenv("PG_VECTOR_DIMENSION", "1536")) # Must match embedding model
# Databricks Settings
DATABRICKS_SERVER_HOSTNAME: str = os.getenv("DATABRICKS_SERVER_HOSTNAME", "")
DATABRICKS_HTTP_PATH: str = os.getenv("DATABRICKS_HTTP_PATH", "")
DATABRICKS_TOKEN: str = os.getenv("DATABRICKS_TOKEN", "")
# GraphQL API Settings (Example)
GRAPHQL_API_ENDPOINT: str = os.getenv("GRAPHQL_API_ENDPOINT", "http://localhost:4000/graphql") # Example endpoint
# Add API keys/auth headers if needed
# GRAPHQL_AUTH_HEADER: str = os.getenv("GRAPHQL_AUTH_HEADER", "")
# Permissions Configuration (Example: path to a mapping file or DB connection)
# This is highly dependent on your actual permission system implementation
PERMISSION_CONFIG_PATH: str = os.getenv("PERMISSION_CONFIG_PATH", "permissions.json") # Example path
# Autogen Settings
    AUTOGEN_TIMEOUT: int = int(os.getenv("AUTOGEN_TIMEOUT", "180"))  # Generous timeout for longer tool chains
# API Settings
API_TITLE: str = "Enhanced Conversational Analytics Chatbot"
API_VERSION: str = "0.2.0"
CURRENT_DATE: str = datetime.date.today().isoformat()
# LangChain Retriever Settings
RETRIEVER_SEARCH_TYPE: str = "similarity" # or "mmr"
RETRIEVER_K: int = 4 # Number of documents to retrieve
class Config:
case_sensitive = False
env_file = '.env'
env_file_encoding = 'utf-8'
settings = Settings()
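# Optional sanity check (a sketch; `model_dump()` assumes pydantic v2): running
# `python config.py` prints the non-secret values picked up from the environment.
# if __name__ == "__main__":
#     safe = {k: v for k, v in settings.model_dump().items()
#             if "PASSWORD" not in k and "TOKEN" not in k}
#     print(safe)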
# --- tools.py ---
import boto3
import psycopg2
from psycopg2.extras import RealDictCursor
from databricks import sql as databricks_sql
import logging
import json
import re  # Word-boundary checks for the SQL keyword guard
import requests # For GraphQL client
from config import settings
# LangChain specific imports
try:
from langchain_community.vectorstores import PGVector
from langchain_aws import BedrockEmbeddings # Use new package structure
from langchain_core.documents import Document
except ImportError:
raise ImportError("LangChain packages not found. Please install 'langchain-community', 'langchain-aws', 'langchain-core'.")
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Bedrock Client (Shared by Embeddings and potentially LLM if not using BedrockProvider) ---
_bedrock_client = None
def get_bedrock_client():
global _bedrock_client
if _bedrock_client is None:
try:
_bedrock_client = boto3.client(
service_name='bedrock-runtime', # Use runtime for invoke, embeddings
region_name=settings.AWS_REGION_NAME
)
logger.info(f"Bedrock client initialized successfully in region {settings.AWS_REGION_NAME}.")
except Exception as e:
logger.error(f"Error initializing Bedrock client: {e}", exc_info=True)
return None
return _bedrock_client
# --- LangChain Components Initialization ---
_langchain_retriever = None
def get_langchain_retriever():
"""Initializes and returns a LangChain PGVector retriever."""
global _langchain_retriever
if _langchain_retriever is None:
try:
logger.info("Initializing LangChain components...")
bedrock_client = get_bedrock_client()
if not bedrock_client:
raise ValueError("Bedrock client failed to initialize.")
# 1. Initialize Bedrock Embeddings
embeddings = BedrockEmbeddings(
client=bedrock_client,
model_id=settings.BEDROCK_EMBEDDING_MODEL_ID,
# region_name=settings.AWS_REGION_NAME # Often inferred from client
)
logger.info(f"LangChain BedrockEmbeddings initialized with model: {settings.BEDROCK_EMBEDDING_MODEL_ID}")
# 2. Define PGVector Connection String
# Ensure PGVector is installed: pip install pgvector
connection_string = PGVector.connection_string_from_db_params(
driver="psycopg2", # Ensure psycopg2 is installed
host=settings.PG_HOST,
port=settings.PG_PORT,
database=settings.PG_DBNAME,
user=settings.PG_USER,
password=settings.PG_PASSWORD,
)
logger.info(f"PGVector connection string prepared for collection: {settings.PG_COLLECTION_NAME}")
# 3. Initialize PGVector Vector Store
# Assumes the table/collection already exists and is populated
vector_store = PGVector(
connection_string=connection_string,
embedding_function=embeddings,
collection_name=settings.PG_COLLECTION_NAME,
# use_jsonb=True # If using JSONB for metadata storage
)
logger.info("PGVector vector store initialized.")
# 4. Create Retriever
_langchain_retriever = vector_store.as_retriever(
search_type=settings.RETRIEVER_SEARCH_TYPE,
search_kwargs={"k": settings.RETRIEVER_K}
)
logger.info(f"LangChain retriever created with search_type='{settings.RETRIEVER_SEARCH_TYPE}' and k={settings.RETRIEVER_K}.")
except Exception as e:
logger.error(f"Failed to initialize LangChain retriever: {e}", exc_info=True)
_langchain_retriever = None # Ensure it's None on failure
# Depending on desired behavior, could raise the exception
# raise e
return _langchain_retriever
# --- Tool: LangChain Context Retriever ---
def retrieve_context(query: str) -> str:
"""
Retrieves relevant context (Confluence docs or DB schemas) from the unified
knowledge base in PGVector using a LangChain retriever based on the user query.
"""
logger.info(f"Received context retrieval request for query: {query}")
retriever = get_langchain_retriever()
if not retriever:
return "Error: Context retriever is not available due to initialization failure."
try:
# Invoke the retriever
# Langchain >= 0.1.0 uses invoke
retrieved_docs: list[Document] = retriever.invoke(query)
if not retrieved_docs:
logger.warning(f"No relevant context found in knowledge base for query: {query}")
return "No relevant context found in the knowledge base matching your query."
# Format results
formatted_context = "Retrieved Context:\n\n" + "\n\n---\n\n".join([
f"Source: {doc.metadata.get('source', 'Unknown')}\n" # Example: Include source if stored in metadata
# f"Type: {doc.metadata.get('type', 'Unknown')}\n" # Example: Include type (doc/schema) if stored
f"Content: {doc.page_content}"
for doc in retrieved_docs
])
logger.info(f"Retrieved {len(retrieved_docs)} context documents for query: {query}")
return formatted_context
except Exception as e:
logger.error(f"Error during context retrieval for query '{query}': {e}", exc_info=True)
return f"An unexpected error occurred during context retrieval."
# --- Tool: Databricks Permission Check ---
def check_databricks_permissions(user_id: str, required_tables: list[str]) -> bool:
"""
Checks if the given user_id has permission to access all the required Databricks tables.
**Placeholder Implementation:** This needs to be replaced with actual logic
to query your permission system (e.g., check roles in a database, lookup in a file).
"""
logger.info(f"Checking Databricks permissions for user '{user_id}' on tables: {required_tables}")
if not user_id or not required_tables:
logger.warning("Permission check called with missing user_id or required_tables.")
return False # Or raise an error
# --- Placeholder Logic ---
# Replace this section with your actual permission checking mechanism.
# Example: Load permissions from a JSON file specified in config
try:
with open(settings.PERMISSION_CONFIG_PATH, 'r') as f:
# Example format: {"user_roles": {"user1": "admin", "user2": "viewer"}, "role_permissions": {"admin": ["*"], "viewer": ["sales_table", "marketing_data"]}}
permissions_data = json.load(f)
user_roles = permissions_data.get("user_roles", {})
role_permissions = permissions_data.get("role_permissions", {})
user_role = user_roles.get(user_id)
if not user_role:
logger.warning(f"User '{user_id}' not found in permission system.")
return False
allowed_tables = role_permissions.get(user_role, [])
if "*" in allowed_tables: # Wildcard for all tables
logger.info(f"User '{user_id}' (role: {user_role}) has wildcard access. Permissions granted.")
return True
# Check if all required tables are in the allowed list
for table in required_tables:
# Normalize table names if necessary (e.g., remove schema prefix if stored differently)
normalized_table = table.lower()
if normalized_table not in [t.lower() for t in allowed_tables]:
logger.warning(f"User '{user_id}' (role: {user_role}) denied access to table: '{table}'. Allowed: {allowed_tables}")
return False
logger.info(f"User '{user_id}' (role: {user_role}) has permissions for all required tables: {required_tables}. Permissions granted.")
return True
except FileNotFoundError:
logger.error(f"Permission configuration file not found at: {settings.PERMISSION_CONFIG_PATH}")
return False # Default to deny if config is missing
except Exception as e:
logger.error(f"Error checking permissions for user '{user_id}': {e}", exc_info=True)
return False # Default to deny on error
# --- End Placeholder Logic ---
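# --- Example: database-backed permission lookup (hedged sketch) ---
# If permissions live in Postgres rather than a JSON file, something along these
# lines could replace the placeholder above. The table and column names
# ("user_roles", "role_permissions", "table_name") are illustrative assumptions.
# def check_databricks_permissions_from_db(user_id: str, required_tables: list[str]) -> bool:
#     conn = psycopg2.connect(host=settings.PG_HOST, port=settings.PG_PORT,
#                             user=settings.PG_USER, password=settings.PG_PASSWORD,
#                             dbname=settings.PG_DBNAME)
#     try:
#         with conn.cursor(cursor_factory=RealDictCursor) as cur:
#             cur.execute(
#                 "SELECT rp.table_name FROM user_roles ur "
#                 "JOIN role_permissions rp ON rp.role = ur.role WHERE ur.user_id = %s",
#                 (user_id,),
#             )
#             allowed = {row["table_name"].lower() for row in cur.fetchall()}
#         return "*" in allowed or all(t.lower() in allowed for t in required_tables)
#     finally:
#         conn.close()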
# --- Tool: Databricks SQL Execution ---
# (Largely unchanged, but now called conditionally after permission check)
def execute_databricks_sql(sql_query: str) -> str:
"""
Executes a **read-only** SQL query against the configured Databricks SQL Warehouse.
Prevents execution of potentially harmful SQL commands.
Should only be called after check_databricks_permissions returns True for the target tables.
"""
logger.info(f"Received Databricks SQL query execution request: {sql_query}")
# **Security Enhancement: Basic check for disallowed keywords**
# (Keep this check as a safeguard)
disallowed_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "GRANT", "REVOKE", "TRUNCATE"]
    # Match whole keywords only so column names such as CREATED_AT or UPDATED_AT
    # do not trigger false positives on CREATE / UPDATE.
    normalized_query = ' '.join(sql_query.upper().split())
    if any(re.search(rf"\b{keyword}\b", normalized_query) for keyword in disallowed_keywords):
        logger.warning(f"Disallowed SQL keyword detected in query: {sql_query}")
        return "Error: Only read-only SELECT queries are allowed."
if not normalized_query.startswith("SELECT"):
logger.warning(f"Query does not start with SELECT: {sql_query}")
return "Error: Only SELECT queries are allowed."
# Check configuration
if not all([settings.DATABRICKS_SERVER_HOSTNAME, settings.DATABRICKS_HTTP_PATH, settings.DATABRICKS_TOKEN]):
logger.error("Databricks connection details are missing.")
return "Error: Databricks connection details not configured."
connection = None
try:
connection = databricks_sql.connect(
server_hostname=settings.DATABRICKS_SERVER_HOSTNAME,
http_path=settings.DATABRICKS_HTTP_PATH,
access_token=settings.DATABRICKS_TOKEN
)
cursor = connection.cursor()
logger.info(f"Connected to Databricks SQL Warehouse {settings.DATABRICKS_SERVER_HOSTNAME}.")
cursor.execute(sql_query)
results = cursor.fetchall()
logger.info(f"Successfully executed Databricks SQL query. Fetched {len(results)} rows.")
if not results:
return "Query executed successfully and returned no results."
column_names = [desc[0] for desc in cursor.description]
max_rows_to_return = 50
results_to_return = results[:max_rows_to_return]
header = " | ".join(column_names)
separator = "-|-".join(["-" * len(name) for name in column_names])
rows = [" | ".join([str(item) for item in row]) for row in results_to_return]
formatted_results = f"Query Results:\n{header}\n{separator}\n" + "\n".join(rows)
if len(results) > max_rows_to_return:
formatted_results += f"\n\n(Result truncated to first {max_rows_to_return} rows. Total rows: {len(results)})"
cursor.close()
return formatted_results
except databricks_sql.exc.Error as e:
logger.error(f"Databricks SQL Error executing query on {settings.DATABRICKS_SERVER_HOSTNAME}: {e}", exc_info=True)
error_message = f"Error executing Databricks SQL query: {getattr(e, 'message', str(e))}"
# Add specific error checks if needed
return error_message
except Exception as e:
logger.error(f"Unexpected error executing Databricks SQL: {e}", exc_info=True)
return f"An unexpected error occurred while executing the Databricks query."
finally:
if connection:
connection.close()
logger.info("Databricks connection closed.")
# --- Tool: GraphQL Execution ---
def execute_graphql_query(graphql_query: str, api_endpoint: str | None = None) -> str:
"""
Executes a GraphQL query against the specified API endpoint.
Uses the endpoint from config if none is provided.
"""
target_endpoint = api_endpoint or settings.GRAPHQL_API_ENDPOINT
logger.info(f"Executing GraphQL query against endpoint: {target_endpoint}")
logger.debug(f"GraphQL Query: {graphql_query}") # Log query only at debug level
if not target_endpoint:
return "Error: GraphQL API endpoint is not configured."
    headers = {
        'Content-Type': 'application/json',
    }
    # Add an authentication header if configured. Expects "Header-Name:value", e.g.
    # GRAPHQL_AUTH_HEADER="Authorization:Bearer <token>" in the environment.
    # if settings.GRAPHQL_AUTH_HEADER:
    #     auth_parts = settings.GRAPHQL_AUTH_HEADER.split(":", 1)
    #     if len(auth_parts) == 2:
    #         headers[auth_parts[0]] = auth_parts[1].strip()
try:
response = requests.post(
target_endpoint,
headers=headers,
json={'query': graphql_query} # Standard way to send GraphQL query
)
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
response_data = response.json()
if "errors" in response_data:
logger.warning(f"GraphQL query returned errors: {response_data['errors']}")
return f"Error executing GraphQL query: {json.dumps(response_data['errors'])}"
else:
logger.info("GraphQL query executed successfully.")
# Return the data part, formatted as JSON string
return f"GraphQL Query Result:\n{json.dumps(response_data.get('data', {}), indent=2)}"
except requests.exceptions.RequestException as e:
logger.error(f"HTTP Error executing GraphQL query to {target_endpoint}: {e}", exc_info=True)
return f"Error connecting to GraphQL endpoint: {e}"
except json.JSONDecodeError as e:
logger.error(f"Error decoding JSON response from GraphQL endpoint {target_endpoint}: {e}", exc_info=True)
return f"Received invalid JSON response from GraphQL endpoint."
except Exception as e:
logger.error(f"Unexpected error executing GraphQL query: {e}", exc_info=True)
return f"An unexpected error occurred while executing the GraphQL query."
# --- agents.py ---
import autogen
from autogen.agentchat.contrib.llm_provider import BedrockProvider
import logging
import re # For extracting table names
from config import settings
# Import all necessary tools
from tools import (
retrieve_context,
check_databricks_permissions,
execute_databricks_sql,
execute_graphql_query
)
logger = logging.getLogger(__name__)
# --- LLM Configuration ---
# NOTE: The BedrockProvider import path can differ between pyautogen releases; newer
# versions configure Bedrock directly in the config_list (e.g. {"api_type": "bedrock", ...}).
# Verify against the pyautogen version pinned in requirements.txt.
# Ensure the selected Bedrock model (e.g., Claude 3 Sonnet) is strong at reasoning and tool use.
llm_config = None
try:
model_kwargs = {"temperature": 0.1, "max_tokens": 3000} # Increased max_tokens potentially needed
bedrock_provider = BedrockProvider(
region_name=settings.AWS_REGION_NAME,
model=settings.BEDROCK_MODEL_ID,
model_kwargs=model_kwargs
)
llm_config = {
"config_list": [{"model": settings.BEDROCK_MODEL_ID, "provider": bedrock_provider}],
"cache_seed": None, # Disable cache for dynamic behavior, or use 42 for testing
"timeout": settings.AUTOGEN_TIMEOUT,
}
logger.info(f"LLM Config prepared for Bedrock model: {settings.BEDROCK_MODEL_ID} via BedrockProvider.")
except Exception as e:
logger.error(f"Error setting up BedrockProvider: {e}", exc_info=True)
# --- Helper Function to Extract Table Names ---
def extract_table_names_from_sql(sql: str) -> list[str]:
"""
Simple heuristic to extract potential table names from a SQL query.
Looks for patterns like FROM table_name, JOIN table_name.
This is basic and might need improvement for complex queries with aliases, CTEs, etc.
"""
# Normalize whitespace and case
sql_normalized = ' '.join(sql.lower().split())
# Regex to find words following FROM or JOIN clauses
# This is a simplified regex and might capture aliases or subqueries incorrectly.
# A proper SQL parser would be more robust but adds complexity.
potential_tables = re.findall(r'(?:from|join)\s+([\w\.]+)', sql_normalized)
# Further filtering might be needed (e.g., remove common keywords if captured)
# Handle schema.table format if necessary
table_names = [name.split('.')[-1] for name in potential_tables] # Get last part if schema.table
return list(set(table_names)) # Return unique names
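# Illustrative behavior of the heuristic above (order not guaranteed because of set()):
# extract_table_names_from_sql(
#     "SELECT p.name FROM dim_product p JOIN fact_sales s ON s.product_id = p.id"
# )  # -> ['dim_product', 'fact_sales']; aliases and CTEs are not resolved.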
# --- Assistant Agent (Enhanced Orchestrator) ---
# This agent now handles a more complex workflow including context retrieval,
# permission checks, and multiple execution tools (SQL, GraphQL).
assistant_system_message = f"""You are 'NexusBot', an advanced AI assistant for accessing information and analytics.
Today's date is {settings.CURRENT_DATE}.
Your capabilities involve using the following tools:
1. `retrieve_context(query: str)`: Fetches relevant documents from a knowledge base containing Confluence articles AND database/API schema information. Use this first for almost all queries to get context.
2. `check_databricks_permissions(user_id: str, required_tables: list[str])`: Checks if the user has permission for specific Databricks tables BEFORE executing SQL. Requires the user's ID and a list of table names involved in the SQL query.
3. `execute_databricks_sql(sql_query: str)`: Executes a **read-only** SELECT SQL query against Databricks. Only use this AFTER `check_databricks_permissions` returns True.
4. `execute_graphql_query(graphql_query: str, api_endpoint: str | None = None)`: Executes a GraphQL query against a specified API endpoint (defaults to {settings.GRAPHQL_API_ENDPOINT}).
**Your Core Workflow:**
1. **Understand Query & Retrieve Context:** Analyze the user's query (`user_query`) and the user's ID (`user_id` will be provided in the initial message). Call `retrieve_context` with the `user_query` to get relevant information (could be Confluence docs, DB table schemas, GraphQL API info, etc.).
2. **Determine Intent & Plan:** Based on the `user_query` and the retrieved context:
* **General Knowledge:** If the query seems answerable from the retrieved Confluence-like documents, synthesize the answer directly.
* **Databricks Analytics (SQL):**
a. Use the retrieved DB schema context to formulate a precise, read-only SELECT SQL query.
b. **Extract Table Names:** Identify the table names involved in your generated SQL query.
c. **Check Permissions:** Call `check_databricks_permissions` with the `user_id` and the extracted table names.
d. **Execute SQL (If Permitted):** If permission check passes (returns True), call `execute_databricks_sql` with your generated SQL query. If permission check fails, inform the user they lack access.
* **API Interaction (GraphQL):**
a. Use the retrieved API schema context (if available) or general knowledge to formulate a GraphQL query.
b. Call `execute_graphql_query` with the generated query. (Note: Add permission checks here too if the API requires user-specific access control).
* **Ambiguous/Clarification:** If the query is unclear, lacks details (e.g., needs specific table names, date ranges, API parameters), or if needed context/schema wasn't retrieved, ASK the user clarifying questions. Do NOT guess table names, column names, or API structures.
3. **Synthesize & Respond:** Based on the results from the tools (or direct synthesis for knowledge questions), provide a clear, concise answer to the user. If a tool returns an error, report it clearly.
**CRITICAL Considerations:**
* **Always retrieve context first** using `retrieve_context` unless the query is trivial conversation.
* **Never execute SQL without checking permissions first.** Extract table names accurately from your *intended* SQL before calling the check.
* **Generate only SELECT SQL queries.** No INSERT, UPDATE, DELETE, etc.
* **Be precise.** Ask for clarification if the user's request is vague. Do not invent table/column/API details.
* The user's ID will be provided in the format "User ID: [user_id]". Extract this ID for permission checks.
**Example SQL Flow:**
- User: "User ID: user123. What are the top 5 products by sales in Q1 from the `fact_sales` table?"
- Assistant: (Calls `retrieve_context` with the query) -> (Gets schema for `fact_sales`) -> (Generates SQL: `SELECT product_name, SUM(sales_amount) AS total_sales FROM fact_sales WHERE quarter = 'Q1' GROUP BY product_name ORDER BY total_sales DESC LIMIT 5`) -> (Extracts table: `['fact_sales']`) -> (Calls `check_databricks_permissions` with `user_id='user123'`, `required_tables=['fact_sales']`) -> (Gets True) -> (Calls `execute_databricks_sql` with the generated SQL) -> (Gets results) -> (Synthesizes response for user).
**Example Permission Denied Flow:**
- User: "User ID: user456. Show me everything in `sensitive_hr_data`."
- Assistant: (Calls `retrieve_context`) -> (Gets schema) -> (Generates SQL: `SELECT * FROM sensitive_hr_data`) -> (Extracts table: `['sensitive_hr_data']`) -> (Calls `check_databricks_permissions` with `user_id='user456'`, `required_tables=['sensitive_hr_data']`) -> (Gets False) -> (Responds: "Sorry, you do not have permission to access the `sensitive_hr_data` table.")
Respond helpfully and accurately based on the defined workflow and tool outputs.
"""
if llm_config:
assistant = autogen.AssistantAgent(
name="NexusBot_Assistant",
llm_config=llm_config,
system_message=assistant_system_message,
)
logger.info("Enhanced Assistant Agent created.")
else:
assistant = None
logger.error("Assistant Agent could not be created because LLM config failed.")
# --- User Proxy Agent (Executor) ---
# Updated to include all tools in the function map
user_proxy = autogen.UserProxyAgent(
name="User_Proxy_Executor",
human_input_mode="NEVER",
max_consecutive_auto_reply=8, # Increased slightly for potentially longer chains (retrieve->check->execute)
    is_termination_msg=lambda x: isinstance(x, dict) and "TERMINATE" in (x.get("content") or "").upper(),
code_execution_config=False,
function_map={
"retrieve_context": retrieve_context,
"check_databricks_permissions": check_databricks_permissions,
"execute_databricks_sql": execute_databricks_sql,
"execute_graphql_query": execute_graphql_query,
# Add the helper function here if needed by the assistant directly,
# though it's better if the assistant performs extraction internally
# "extract_table_names_from_sql": extract_table_names_from_sql
}
)
logger.info("User Proxy Agent created with updated tool function map.")
# --- Register Functions with Assistant ---
if assistant:
# Register all functions the User Proxy can execute
assistant.register_function(
function_map={
"retrieve_context": retrieve_context,
"check_databricks_permissions": check_databricks_permissions,
"execute_databricks_sql": execute_databricks_sql,
"execute_graphql_query": execute_graphql_query,
# "extract_table_names_from_sql": extract_table_names_from_sql # Only if assistant needs to call it
}
)
logger.info("Registered all tool functions with the Assistant Agent.")
# --- main.py ---
from fastapi import FastAPI, HTTPException, Body, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import autogen
import asyncio
import logging
from contextlib import asynccontextmanager
import datetime
# Import settings, agents, and tools
from config import settings
from agents import user_proxy, assistant, llm_config  # Import configured agents
import tools  # Needed for readiness checks via tools.get_langchain_retriever()
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- Pydantic Models ---
class ChatRequest(BaseModel):
query: str = Field(..., description="The user's query for the chatbot.", min_length=1)
user_id: str = Field(..., description="User identifier required for permission checks.")
class ChatResponse(BaseModel):
response: str = Field(..., description="The chatbot's response to the query.")
    tool_used: list[str] = Field(default_factory=list, description="List of tools used by the agent during processing.")
error: str | None = Field(None, description="Error message if processing failed.")
# --- Application Lifecycle (similar to previous, checking new components) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Enhanced application startup sequence initiated...")
start_time = datetime.datetime.now()
# Check critical configurations
llm_ok = llm_config is not None and assistant is not None and user_proxy is not None
retriever_ok = tools.get_langchain_retriever() is not None # Trigger retriever init early
if not llm_ok:
logger.critical("LLM Configuration or Agent initialization failed.")
if not retriever_ok:
logger.critical("LangChain Retriever initialization failed.")
app_ready = llm_ok and retriever_ok
if not app_ready:
logger.critical("Chat service will be unavailable due to initialization failures.")
# Log key settings
logger.info(f"Chatbot API '{settings.API_TITLE}' v{settings.API_VERSION} starting.")
logger.info(f"Using Bedrock Model: {settings.BEDROCK_MODEL_ID}")
logger.info(f"Knowledge Base: LangChain PGVector on {settings.PG_HOST}/{settings.PG_DBNAME}, Collection: {settings.PG_COLLECTION_NAME}")
logger.info(f"SQL Tool: Databricks on {settings.DATABRICKS_SERVER_HOSTNAME}")
logger.info(f"GraphQL Tool Endpoint: {settings.GRAPHQL_API_ENDPOINT}")
logger.info(f"Permission Config: {settings.PERMISSION_CONFIG_PATH}")
logger.info(f"Retriever Status: {'OK' if retriever_ok else 'Failed'}")
logger.info(f"Agent Status: {'OK' if llm_ok else 'Failed'}")
end_time = datetime.datetime.now()
logger.info(f"Application startup complete in {(end_time - start_time).total_seconds():.2f} seconds. Ready: {app_ready}")
yield
logger.info("Application shutdown sequence initiated...")
# Cleanup resources if needed
logger.info("Application shutdown complete.")
# --- FastAPI App ---
app = FastAPI(
title=settings.API_TITLE,
version=settings.API_VERSION,
description="Enhanced conversational analytics chatbot with RAG, SQL (w/ Permissions), and GraphQL capabilities.",
lifespan=lifespan,
)
# --- Error Handling (similar to previous) ---
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=ChatResponse(response="", error=f"Internal server error: {type(exc).__name__}").model_dump(exclude_none=True),
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
logger.warning(f"HTTP Exception: Status {exc.status_code}, Detail: {exc.detail}")
return JSONResponse(
status_code=exc.status_code,
content=ChatResponse(response="", error=str(exc.detail)).model_dump(exclude_none=True),
)
# --- Chat Endpoint ---
@app.post("/chat",
response_model=ChatResponse,
summary="Process complex user queries",
description="Handles user queries involving RAG, SQL (with permissions), and GraphQL via AutoGen agents.",
tags=["Chatbot"])
async def chat_endpoint(request: ChatRequest = Body(...)):
"""
Handles incoming chat requests, injecting user_id for agent use.
"""
logger.info(f"Received chat request from user '{request.user_id}'. Query: '{request.query[:100]}...'")
# Check readiness (LLM, Agents, Retriever)
if not llm_config or not assistant or not user_proxy or tools.get_langchain_retriever() is None:
logger.error("Chat service unavailable due to initialization failure.")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="The chat service is currently unavailable. Please try again later."
)
try:
# Format initial message to include user_id for the agent
initial_message = f"User ID: {request.user_id}. User Query: {request.query}"
logger.debug(f"Formatted initial message for agent: {initial_message}")
# Initiate chat
chat_result = await asyncio.to_thread(
user_proxy.initiate_chat,
recipient=assistant,
message=initial_message,
max_turns=10, # May need more turns for complex flows (retrieve->check->execute)
)
# --- Process Chat Result ---
final_response = "Sorry, I wasn't able to generate a response."
tools_used_list = []
error_message = None
if chat_result:
            chat_error = getattr(chat_result, "error", None)  # Not every autogen version exposes an error attribute
            if chat_error:
                error_message = f"Agent chat failed: {chat_error}"
                logger.error(error_message)
if chat_result.chat_history and isinstance(chat_result.chat_history, list):
# Extract last assistant message
for msg in reversed(chat_result.chat_history):
if msg.get('role') == 'assistant':
final_response = msg.get('content', final_response).strip()
break
# Heuristic to find all tools used from history
history_str = str(chat_result.chat_history).lower()
possible_tools = [
"retrieve_context",
"check_databricks_permissions",
"execute_databricks_sql",
"execute_graphql_query"
]
# Look for function call patterns (adapt based on actual output format)
for tool in possible_tools:
if f"'function_call': {{'name': '{tool.lower()}'" in history_str or \
f"'tool_calls': [{{'function': {{'name': '{tool.lower()}'" in history_str:
tools_used_list.append(tool)
if error_message:
logger.error(f"Chat for user '{request.user_id}' failed. Error: {error_message}")
final_response = f"I encountered an error: {error_message}"
else:
logger.info(f"Chat for user '{request.user_id}' completed. Tools used: {tools_used_list}. Response: '{final_response[:100]}...'")
else:
logger.warning(f"Chat initiation for user '{request.user_id}' returned no result.")
error_message = "Chat returned no result."
final_response = "Sorry, something went wrong."
return ChatResponse(response=final_response, tool_used=list(set(tools_used_list)), error=error_message) # Return unique tools used
except Exception as e:
logger.error(f"Unexpected error during chat processing for user '{request.user_id}': {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Unexpected error: {type(e).__name__}"
)
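# Example of exercising the endpoint from Python instead of curl (illustrative):
# import requests
# resp = requests.post(
#     "http://localhost:8000/chat",
#     json={"query": "Explain the process for submitting expense reports.", "user_id": "user456"},
#     timeout=settings.AUTOGEN_TIMEOUT + 30,
# )
# print(resp.json())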
# --- Health Check Endpoint (Updated) ---
@app.get("/health", response_model=dict, tags=["Management"])
async def health_check():
"""Basic health check including component status."""
llm_ok = llm_config is not None and assistant is not None and user_proxy is not None
retriever_ok = tools.get_langchain_retriever() is not None # Checks if retriever initialized ok
status_overall = "ok" if llm_ok and retriever_ok else "error"
return {
"status": status_overall,
"timestamp": datetime.datetime.now().isoformat(),
"service": settings.API_TITLE,
"version": settings.API_VERSION,
"components": {
"llm_service": {"status": "ok" if llm_ok else "error", "model_id": settings.BEDROCK_MODEL_ID if llm_ok else None},
"knowledge_retriever": {"status": "ok" if retriever_ok else "error", "type": "LangChain PGVector"},
"databricks_connection": {"status": "configured" if all([settings.DATABRICKS_SERVER_HOSTNAME, settings.DATABRICKS_HTTP_PATH, settings.DATABRICKS_TOKEN]) else "missing_config"},
"graphql_connection": {"status": "configured" if settings.GRAPHQL_API_ENDPOINT else "missing_config"},
"permission_system": {"status": "configured", "path": settings.PERMISSION_CONFIG_PATH} # Basic config check
}
}
# --- requirements.txt ---
"""
# Core FastAPI and Server
fastapi>=0.100.0
uvicorn[standard]>=0.20.0
# Configuration and Environment
python-dotenv>=1.0.0
pydantic-settings>=2.0.0
# AutoGen and LLM Integration
pyautogen>=0.2.18 # Or check for latest stable version compatible with providers
# AWS SDK
boto3>=1.28.0
# LangChain Core and Integrations
langchain-core>=0.1.0
langchain-community>=0.0.20 # For PGVector
langchain-aws>=0.1.0 # For BedrockEmbeddings
# Database Connectors
psycopg2-binary>=2.9.0 # Or psycopg if building from source
pgvector>=0.2.0 # Ensure compatibility with psycopg2
# Databricks Connector
databricks-sql-connector>=2.9.0
# HTTP Client (for GraphQL)
requests>=2.28.0 # Or httpx for async requests if needed later
# Regular Expressions (Built-in)
# JSON (Built-in)
"""
# --- .env ---
"""
# --- Sample .env file for Enhanced Chatbot ---
# ** IMPORTANT: Never commit secrets to version control! Use secrets management. **
# AWS Credentials (Best Practice: Use IAM roles or ~/.aws/credentials)
# AWS_ACCESS_KEY_ID=...
# AWS_SECRET_ACCESS_KEY=...
# AWS_SESSION_TOKEN=... # Optional
# AWS Bedrock Settings
AWS_REGION_NAME="us-east-1"
BEDROCK_MODEL_ID="anthropic.claude-3-sonnet-20240229-v1:0" # Recommend Sonnet or Opus for complex reasoning
BEDROCK_EMBEDDING_MODEL_ID="amazon.titan-embed-text-v1"
# PGVector (PostgreSQL) Settings - Unified Store
PG_HOST="your_pg_instance_endpoint"
PG_PORT="5432"
PG_USER="your_db_user"
PG_PASSWORD='your_db_password'
PG_DBNAME="your_knowledgebase_db"
PG_COLLECTION_NAME="unified_embeddings" # LangChain collection name (maps to table)
PG_VECTOR_DIMENSION="1536" # MUST match BEDROCK_EMBEDDING_MODEL_ID
# Databricks Settings
DATABRICKS_SERVER_HOSTNAME="your_workspace_id.cloud.databricks.com"
DATABRICKS_HTTP_PATH="/sql/1.0/warehouses/your_sql_warehouse_id"
DATABRICKS_TOKEN="dapi_your_databricks_token"
# GraphQL API Settings
GRAPHQL_API_ENDPOINT="http://your_graphql_api.com/graphql"
# GRAPHQL_AUTH_HEADER="Authorization:Bearer your_api_token" # Example if auth needed
# Permissions Configuration (Example: Path to a local JSON file)
PERMISSION_CONFIG_PATH="permissions.json"
# Autogen Settings
# AUTOGEN_TIMEOUT=180
"""
# --- permissions.json (Example Structure) ---
"""
{
"user_roles": {
"user123": "admin",
"user456": "viewer",
"analyst007": "analyst"
},
"role_permissions": {
"admin": [
"*"
],
"analyst": [
"fact_sales",
"dim_product",
"user_activity",
"marketing_campaigns"
],
"viewer": [
"dim_product",
"marketing_campaigns"
]
}
}
"""
# --- README.md (Instructions - Enhanced) ---
"""
# Enhanced Conversational Analytics Chatbot
This project implements an advanced chatbot using AutoGen, Bedrock, LangChain, PGVector, Databricks, and FastAPI. It integrates Retrieval-Augmented Generation (RAG) from a unified knowledge base, permission-aware Databricks SQL execution, and GraphQL API interaction.
## Features
* **Sophisticated Agentic Workflow:** Uses AutoGen Assistant and UserProxy agents for complex task orchestration.
* **Powerful LLM:** Leverages AWS Bedrock models (configurable, Sonnet/Opus recommended).
* **Unified RAG:** Uses LangChain to retrieve context (Confluence docs, DB/API schemas) from a single PGVector store.
* **Enhanced Text-to-SQL:** Generates SQL informed by retrieved database schema context.
* **Databricks Access Control:** Includes a tool to check user permissions *before* executing Databricks SQL queries (requires custom permission system integration).
* **Text-to-GraphQL:** Generates and executes GraphQL queries against a configured API endpoint.
* **Robust API:** Served via FastAPI with improved error handling and health checks.
* **Configuration:** Managed via `.env` file and Pydantic settings.
## Prerequisites
* Python 3.10+ (the code uses `str | None` union syntax in type annotations).
* AWS Account with Bedrock enabled (Sonnet/Opus recommended) and necessary IAM permissions.
* PostgreSQL Database (v12+) with `pgvector` extension enabled.
* **Unified PGVector Collection:** A single table/collection (specified by `PG_COLLECTION_NAME`) in PostgreSQL containing:
* Embeddings for Confluence documents.
* Embeddings for Databricks table/column metadata (e.g., text descriptions of schemas).
* Embeddings for GraphQL schema information (optional, if needed for generation).
    * Consider adding a metadata field (e.g., `doc_type: 'confluence' | 'db_schema' | 'gql_schema'`) to documents during ingestion (see the ingestion sketch after this list).
* Databricks Workspace with SQL Warehouse and access token.
* **Permission System:** A way to define and check user permissions for Databricks tables (the code provides a placeholder using a JSON file `permissions.json`). You **must** adapt the `check_databricks_permissions` tool in `tools.py` to match your actual system.
* GraphQL API endpoint (if using the GraphQL feature).
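A minimal ingestion sketch for the unified collection (illustrative only; it assumes Confluence pages and schema descriptions have already been extracted as text, and the `doc_type` values mirror the suggestion above):
```python
# ingest_example.py -- hedged sketch for populating the unified PGVector collection
from langchain_aws import BedrockEmbeddings
from langchain_community.vectorstores import PGVector
from langchain_core.documents import Document

from config import settings

embeddings = BedrockEmbeddings(model_id=settings.BEDROCK_EMBEDDING_MODEL_ID,
                               region_name=settings.AWS_REGION_NAME)

connection_string = PGVector.connection_string_from_db_params(
    driver="psycopg2", host=settings.PG_HOST, port=settings.PG_PORT,
    database=settings.PG_DBNAME, user=settings.PG_USER, password=settings.PG_PASSWORD,
)

docs = [
    Document(page_content="Expense reports are submitted via the Finance portal ...",
             metadata={"source": "confluence/finance/expenses", "doc_type": "confluence"}),
    Document(page_content="Table fact_sales(product_name STRING, sales_amount DOUBLE, quarter STRING)",
             metadata={"source": "databricks/fact_sales", "doc_type": "db_schema"}),
]

PGVector.from_documents(
    documents=docs,
    embedding=embeddings,
    collection_name=settings.PG_COLLECTION_NAME,
    connection_string=connection_string,
)
```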
## Setup
1. **Clone/Setup Project Files.**
2. **Create/Activate Virtual Environment:** `python -m venv venv && source venv/bin/activate` (or equivalent).
3. **Install Dependencies:** `pip install -r requirements.txt`.
4. **Configure `.env`:**
* Copy `.env.example` (or the sample from the code) to `.env`.
* Fill in **all** required credentials and endpoints (AWS, PG, Databricks, GraphQL).
* Ensure `PG_VECTOR_DIMENSION` matches your Bedrock embedding model.
* Set `PG_COLLECTION_NAME` to your LangChain-compatible collection name.
* Configure `PERMISSION_CONFIG_PATH` (or update tool logic).
* **Secure secrets properly in production!**
5. **Prepare PGVector:**
* Ensure the `vector` extension exists (`CREATE EXTENSION IF NOT EXISTS vector;`).
* Ensure your `PG_COLLECTION_NAME` table exists and is populated with embeddings and content (including DB/API schemas as text documents). Index the embedding column (`CREATE INDEX ON ...`).
6. **Prepare Permissions:**
* Create your `permissions.json` file (or equivalent) based on the example structure.
* **Crucially, modify the placeholder logic** in `tools.py -> check_databricks_permissions` to integrate with your real permission source.
7. **Prepare Databricks:** Ensure warehouse is running and token permissions are correct.
## Running the Application
1. **Start FastAPI:** `uvicorn main:app --reload --host 0.0.0.0 --port 8000`
2. **Interact:**
* Docs: `http://localhost:8000/docs`
* Health: `http://localhost:8000/health`
* Chat: Send POST requests to `http://localhost:8000/chat` with JSON body:
```json
{
"query": "Your natural language query here",
"user_id": "the_actual_user_id"
}
```
**Example `curl`:**
```bash
# Query requiring SQL after permission check
curl -X POST "http://localhost:8000/chat" \
-H "Content-Type: application/json" \
-d '{
"query": "What is the total count of users in the user_activity table from last week?",
"user_id": "analyst007"
}'
# Query requiring GraphQL
curl -X POST "http://localhost:8000/chat" \
-H "Content-Type: application/json" \
-d '{
"query": "Find the project details for project ID P123 using the projects API.",
"user_id": "user123"
}'
# Query requiring RAG from Confluence
curl -X POST "http://localhost:8000/chat" \
-H "Content-Type: application/json" \
-d '{
"query": "Explain the process for submitting expense reports.",
"user_id": "user456"
}'
```
## Key Implementation Notes
* **`tools.py`:** Contains the core logic for interacting with external systems (LangChain RAG, Permissions, Databricks, GraphQL). **Adapt `check_databricks_permissions`!**
* **`agents.py`:** Defines the AutoGen agents. The `assistant_system_message` is critical for guiding the complex workflow. The helper `extract_table_names_from_sql` is basic; a real SQL parser may be needed for complex queries (see the parser-based sketch below).
* **`main.py`:** Handles API requests, injects `user_id` into the initial agent message, and orchestrates the `initiate_chat` call.
* **Unified Knowledge:** Success depends heavily on populating `pgvector` with both unstructured text (Confluence) and structured metadata (DB/API schemas) in a form the retriever can use effectively.
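If the regex heuristic proves too fragile, parser-based extraction is one option. A hedged sketch using `sqlglot` (an assumption; it is not in `requirements.txt` and would need to be added):
```python
import sqlglot
from sqlglot import exp

def extract_table_names_with_parser(sql: str) -> list[str]:
    """Return unique table names from FROM/JOIN clauses, ignoring CTE aliases."""
    parsed = sqlglot.parse_one(sql, read="databricks")
    cte_names = {cte.alias_or_name for cte in parsed.find_all(exp.CTE)}
    tables = {t.name for t in parsed.find_all(exp.Table) if t.name not in cte_names}
    return sorted(tables)
```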
Refer to the Production Considerations in the previous README for deployment best practices.
"""