Created
July 14, 2025 14:58
-
-
Save KennyVaneetvelde/3377e0aefe89c2e1eb9c5d796a73ef11 to your computer and use it in GitHub Desktop.
Atomic Agents DB query generator agent example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import instructor | |
| import openai | |
| from pydantic import Field | |
| from typing import List, Optional | |
| import sqlite3 | |
| from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseIOSchema | |
| from atomic_agents.lib.components.system_prompt_generator import SystemPromptContextProviderBase, SystemPromptGenerator | |
| from atomic_agents.lib.base.base_tool import BaseTool, BaseToolConfig | |
class DatabaseSchemaProvider(SystemPromptContextProviderBase):
    """Provides database schema info to the agent

    Lists every table in a SQLite database together with its columns so the
    text produced by ``get_info`` can be placed in an agent's system prompt.
    """

    def __init__(self, db_path: str):
        """Bind the provider to the SQLite file at ``db_path``.

        The section title is fixed to ``"Database Schema"``.
        """
        super().__init__(title="Database Schema")
        self.db_path = db_path

    def get_info(self) -> str:
        """Return a plain-text listing of every table and its columns.

        The text starts with a ``DATABASE SCHEMA:`` header. It is followed by
        one section per table, giving ``name (declared type)`` for each column.
        A new connection is opened on every call and is always closed, even if
        one of the schema queries fails.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # Materialise the table list before issuing further queries on the
            # same connection.
            table_names = [
                row[0]
                for row in conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table';"
                )
            ]
            parts = ["DATABASE SCHEMA:\n\n"]
            for table_name in table_names:
                # Quote the identifier and double any embedded quotes. Without
                # this, tables named with spaces, punctuation or keywords such
                # as "order" make the PRAGMA a syntax error.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                columns = conn.execute(f"PRAGMA table_info({quoted_name})").fetchall()
                parts.append(f"Table: {table_name}\n")
                parts.append("Columns:\n")
                # table_info rows are (cid, name, type, notnull, dflt_value, pk).
                parts.extend(f" - {col[1]} ({col[2]})\n" for col in columns)
                parts.append("\n")
        finally:
            conn.close()
        return "".join(parts)
class TextToSQLInputSchema(BaseIOSchema):
    """Input schema for text-to-SQL agent"""

    # The user's plain-language question; the agent defined later in this file
    # converts it into a SQL query.
    # NOTE(review): the class docstring and the Field description are
    # presumably surfaced to the LLM through the generated JSON schema, so
    # rewording them changes the prompt — edit with care.
    question: str = Field(..., description="Natural language question about the data")
class TextToSQLOutputSchema(BaseIOSchema):
    """Output schema with SQL query"""

    # NOTE(review): the class docstring and Field descriptions are presumably
    # part of the JSON schema the LLM is asked to fill in, so rewording them
    # changes model behaviour — edit with care.

    # SQL text generated by the agent. The usage code in this file passes it
    # unchanged to SQLExecutorInputSchema for execution.
    sql_query: str = Field(..., description="Generated SQL query")
    # Human-readable summary of the query; the usage code only prints it.
    explanation: str = Field(..., description="Brief explanation of what the query does")
class SQLExecutorInputSchema(BaseIOSchema):
    """Input for SQL executor tool"""

    # The statement SQLExecutorTool.run passes directly to sqlite3's execute().
    # In this file it comes from the LLM, so treat it as untrusted input.
    sql_query: str = Field(..., description="SQL query to execute")
class SQLExecutorOutputSchema(BaseIOSchema):
    """Output from SQL executor"""

    # One dict per fetched row, keyed by column name (SQLExecutorTool builds
    # them from sqlite3.Row objects). Empty when the query failed.
    results: List[dict] = Field(..., description="Query results as list of dictionaries")
    # Always set to len(results) by SQLExecutorTool; 0 on failure.
    row_count: int = Field(..., description="Number of rows returned")
    # None on success; otherwise str() of the exception the query raised.
    error: Optional[str] = Field(None, description="Error message if query failed")
class SQLExecutorTool(BaseTool):
    """Tool that executes SQL queries"""

    input_schema = SQLExecutorInputSchema
    output_schema = SQLExecutorOutputSchema

    def __init__(self, config: BaseToolConfig, db_path: str, *, read_only: bool = True):
        """Create an executor bound to the SQLite file at ``db_path``.

        Args:
            config: Tool configuration forwarded to ``BaseTool``.
            db_path: Path of the SQLite database to query.
            read_only: When True (the default), each connection enables
                ``PRAGMA query_only``. Statements that try to modify the
                database then fail and are reported through the ``error``
                field instead of being applied. Pass False only if write
                access is genuinely required.
        """
        super().__init__(config)
        self.db_path = db_path
        self.read_only = read_only

    def run(self, params: SQLExecutorInputSchema) -> SQLExecutorOutputSchema:
        """Execute ``params.sql_query`` and return its rows as dictionaries.

        Any exception raised while connecting or querying is not propagated.
        It is reported as ``error=str(exc)`` with empty results, so a bad
        LLM-generated query cannot crash the caller. The connection is always
        closed, whether or not the query succeeds.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            if self.read_only:
                # The system prompt only *asks* the model for SELECT statements.
                # This pragma is the actual enforcement: writes are rejected.
                conn.execute("PRAGMA query_only = ON")
            conn.row_factory = sqlite3.Row  # lets dict(row) map column name -> value
            rows = conn.execute(params.sql_query).fetchall()
            results = [dict(row) for row in rows]
            return SQLExecutorOutputSchema(
                results=results,
                row_count=len(results),
                error=None,
            )
        except Exception as e:  # tool boundary: report the failure, don't raise
            return SQLExecutorOutputSchema(
                results=[],
                row_count=0,
                error=str(e),
            )
        finally:
            # Previously the connection leaked whenever execute() raised.
            if conn is not None:
                conn.close()
# SQLite file shared by the schema provider here and the SQL executor below.
db_path = "sales.db"
# Supplies the table/column listing that is registered with the agent as a
# system-prompt context provider.
schema_provider = DatabaseSchemaProvider(db_path=db_path)
# Assemble the text-to-SQL agent in stages: LLM client, then system prompt,
# then the agent itself. The client is built first, as in a single nested call.
_llm_client = instructor.from_openai(openai.OpenAI())

_sql_prompt = SystemPromptGenerator(
    background=[
        "You are an expert SQL query generator.",
        "You convert natural language questions into SQL queries.",
        "You only generate SELECT queries - no modifications allowed.",
    ],
    steps=[
        "Analyze the user's question to understand what data they want",
        "Identify the relevant tables and columns from the schema",
        "Generate a syntactically correct SQL query",
        "Provide a brief explanation of what the query does",
    ],
    output_instructions=[
        "Always generate valid SQL for SQLite",
        "Use proper JOIN conditions when querying multiple tables",
        "Include appropriate WHERE clauses and aggregations",
        "Never generate INSERT, UPDATE, DELETE, or DROP statements",
    ],
    # The live schema listing is injected into the prompt under this key.
    context_providers={"db_schema": schema_provider},
)

text_to_sql_agent = BaseAgent(
    config=BaseAgentConfig(
        client=_llm_client,
        model="gpt-4o-mini",
        system_prompt_generator=_sql_prompt,
        input_schema=TextToSQLInputSchema,
        output_schema=TextToSQLOutputSchema,
    )
)
# Tool that runs the agent's generated SQL against the same database file.
sql_executor = SQLExecutorTool(BaseToolConfig(), db_path=db_path)
if __name__ == "__main__":
    # Usage example. It is guarded so that importing this module does not fire
    # an LLM request or a database query as a side effect.
    user_question = "What are the top 5 products by total sales revenue?"

    # Generate SQL from the natural-language question
    sql_response = text_to_sql_agent.run(
        TextToSQLInputSchema(question=user_question)
    )
    print(f"Generated SQL: {sql_response.sql_query}")
    print(f"Explanation: {sql_response.explanation}")

    # Execute the generated query
    execution = sql_executor.run(
        SQLExecutorInputSchema(sql_query=sql_response.sql_query)
    )

    # Compare against None: an exception with an empty message still means
    # the query failed.
    if execution.error is not None:
        print(f"Error: {execution.error}")
    else:
        preview_limit = 5
        print(f"\nResults ({execution.row_count} rows):")
        for row in execution.results[:preview_limit]:
            print(row)
        if execution.row_count > preview_limit:
            print(f"... ({execution.row_count - preview_limit} more rows not shown)")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment