Skip to content

Instantly share code, notes, and snippets.

@KennyVaneetvelde
Created July 14, 2025 14:58
Show Gist options
  • Select an option

  • Save KennyVaneetvelde/3377e0aefe89c2e1eb9c5d796a73ef11 to your computer and use it in GitHub Desktop.

Select an option

Save KennyVaneetvelde/3377e0aefe89c2e1eb9c5d796a73ef11 to your computer and use it in GitHub Desktop.
Atomic Agents DB query generator agent example
import instructor
import openai
from pydantic import Field
from typing import List, Optional
import sqlite3
from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseIOSchema
from atomic_agents.lib.components.system_prompt_generator import SystemPromptContextProviderBase, SystemPromptGenerator
from atomic_agents.lib.base.base_tool import BaseTool, BaseToolConfig
class DatabaseSchemaProvider(SystemPromptContextProviderBase):
    """Provides database schema info to the agent.

    Every call to :meth:`get_info` opens the SQLite database at ``db_path``
    and renders its user tables and their columns as plain text. The schema
    shown to the agent therefore reflects the database as it is at prompt
    time.
    """

    def __init__(self, db_path: str):
        """Register the provider under the title "Database Schema".

        Args:
            db_path: Path (or any name accepted by ``sqlite3.connect``) of the
                SQLite database to describe.
        """
        super().__init__(title="Database Schema")
        self.db_path = db_path

    def get_info(self) -> str:
        """Return a plain-text description of the database's tables.

        The text starts with ``DATABASE SCHEMA:`` and a blank line. Each table
        then contributes:

        - a ``Table: <name>`` line,
        - a ``Columns:`` line,
        - one `` - <column> (<declared type>)`` line per column,
        - a trailing blank line.

        SQLite-internal tables (names starting with ``sqlite_``, such as
        ``sqlite_sequence``) are bookkeeping rather than user data. They are
        skipped so they do not clutter the prompt.

        Raises:
            sqlite3.Error: if the database cannot be opened or read.
        """
        parts = ["DATABASE SCHEMA:\n\n"]
        conn = sqlite3.connect(self.db_path)
        try:
            # GLOB (unlike LIKE) treats "_" literally, so only the reserved
            # "sqlite_" prefix is excluded.
            table_names = [
                row[0]
                for row in conn.execute(
                    "SELECT name FROM sqlite_master "
                    "WHERE type='table' AND name NOT GLOB 'sqlite_*';"
                )
            ]
            for table_name in table_names:
                # PRAGMA arguments cannot be bound as query parameters, so the
                # name is quoted as an SQL identifier (embedded quotes doubled).
                # Unquoted, names with spaces or reserved words like "order"
                # make the PRAGMA fail with a syntax error.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                columns = conn.execute(f"PRAGMA table_info({quoted_name})").fetchall()
                parts.append(f"Table: {table_name}\n")
                parts.append("Columns:\n")
                # table_info rows are (cid, name, type, notnull, dflt_value, pk).
                parts.extend(f" - {col[1]} ({col[2]})\n" for col in columns)
                parts.append("\n")
        finally:
            # Release the connection even when one of the queries raised.
            conn.close()
        return "".join(parts)
class TextToSQLInputSchema(BaseIOSchema):
    """Input schema for text-to-SQL agent"""

    # Docstring above is kept verbatim: Pydantic uses a model's docstring as
    # its JSON-schema description, which may be shown to the LLM.

    # Required: the free-form question the agent should turn into SQL.
    question: str = Field(
        default=...,
        description="Natural language question about the data",
    )
class TextToSQLOutputSchema(BaseIOSchema):
    """Output schema with SQL query"""

    # Docstring above is kept verbatim: Pydantic uses a model's docstring as
    # its JSON-schema description, which may be shown to the LLM.

    # Required: the SQL text; the usage script hands it to SQLExecutorTool.
    sql_query: str = Field(
        default=...,
        description="Generated SQL query",
    )
    # Required: a short human-readable summary of what sql_query does.
    explanation: str = Field(
        default=...,
        description="Brief explanation of what the query does",
    )
class SQLExecutorInputSchema(BaseIOSchema):
    """Input for SQL executor tool"""

    # Docstring above is kept verbatim: Pydantic uses a model's docstring as
    # its JSON-schema description.

    # Required: the SQL statement SQLExecutorTool.run passes to sqlite3.
    sql_query: str = Field(
        default=...,
        description="SQL query to execute",
    )
class SQLExecutorOutputSchema(BaseIOSchema):
    """Output from SQL executor"""

    # Docstring above is kept verbatim: Pydantic uses a model's docstring as
    # its JSON-schema description.

    # Required: one dict per row (SQLExecutorTool fills it from sqlite3.Row,
    # keyed by column name; it is empty when the query failed).
    results: List[dict] = Field(
        default=...,
        description="Query results as list of dictionaries",
    )
    # Required: SQLExecutorTool sets this to len(results).
    row_count: int = Field(
        default=...,
        description="Number of rows returned",
    )
    # Optional: None on success, otherwise the failure message.
    error: Optional[str] = Field(
        default=None,
        description="Error message if query failed",
    )
class SQLExecutorTool(BaseTool):
    """Tool that executes SQL queries.

    Every :meth:`run` call opens a fresh connection to the SQLite database at
    ``db_path`` and executes one statement with writes disabled
    (``PRAGMA query_only``). It reports the rows, or the error message,
    through :class:`SQLExecutorOutputSchema` instead of raising.
    """

    input_schema = SQLExecutorInputSchema
    output_schema = SQLExecutorOutputSchema

    def __init__(self, config: BaseToolConfig, db_path: str):
        """Create the tool.

        Args:
            config: Tool configuration forwarded to ``BaseTool``.
            db_path: Path of the SQLite database that queries run against.
        """
        super().__init__(config)
        self.db_path = db_path

    def run(self, params: SQLExecutorInputSchema) -> SQLExecutorOutputSchema:
        """Execute ``params.sql_query`` and return its rows.

        The query is LLM-generated, so the connection is first switched to
        ``query_only`` mode. INSERT, UPDATE, DELETE, CREATE and DROP
        statements then fail with an error instead of modifying the
        database. This enforces the agent's "SELECT only" instruction rather
        than merely requesting it.

        Args:
            params: Holds the single SQL statement to execute.

        Returns:
            On success: ``results`` has one dict per row keyed by column name,
            ``row_count == len(results)``, and ``error`` is ``None``.
            On any failure: ``results`` is empty, ``row_count`` is 0, and
            ``error`` holds the exception message.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row  # enables column access by name
            conn.execute("PRAGMA query_only = ON")
            rows = conn.execute(params.sql_query).fetchall()
            results = [dict(row) for row in rows]
            return SQLExecutorOutputSchema(
                results=results,
                row_count=len(results),
                error=None,
            )
        except Exception as e:
            # Deliberate tool boundary: the failure is reported through the
            # output schema instead of propagating to the caller.
            return SQLExecutorOutputSchema(
                results=[],
                row_count=0,
                error=str(e),
            )
        finally:
            # Runs on both paths, so a failing query no longer leaks the
            # connection (the original closed it only on success).
            if conn is not None:
                conn.close()
# --- Shared configuration ----------------------------------------------------
# SQLite file used both to describe the schema and to run generated queries.
db_path = "sales.db"

# Supplies the current table/column listing of db_path to the agent's prompt.
schema_provider = DatabaseSchemaProvider(db_path)

# --- Text-to-SQL agent -------------------------------------------------------
# instructor wraps the OpenAI client so responses are parsed into the
# agent's output schema. Built before the prompt pieces, as in the original.
_llm_client = instructor.from_openai(openai.OpenAI())

_sql_prompt_background = [
    "You are an expert SQL query generator.",
    "You convert natural language questions into SQL queries.",
    "You only generate SELECT queries - no modifications allowed.",
]
_sql_prompt_steps = [
    "Analyze the user's question to understand what data they want",
    "Identify the relevant tables and columns from the schema",
    "Generate a syntactically correct SQL query",
    "Provide a brief explanation of what the query does",
]
_sql_prompt_output_instructions = [
    "Always generate valid SQL for SQLite",
    "Use proper JOIN conditions when querying multiple tables",
    "Include appropriate WHERE clauses and aggregations",
    "Never generate INSERT, UPDATE, DELETE, or DROP statements",
]

_sql_system_prompt = SystemPromptGenerator(
    background=_sql_prompt_background,
    steps=_sql_prompt_steps,
    output_instructions=_sql_prompt_output_instructions,
    context_providers={"db_schema": schema_provider},
)

text_to_sql_agent = BaseAgent(
    config=BaseAgentConfig(
        client=_llm_client,
        model="gpt-4o-mini",
        system_prompt_generator=_sql_system_prompt,
        input_schema=TextToSQLInputSchema,
        output_schema=TextToSQLOutputSchema,
    )
)

# --- SQL executor tool -------------------------------------------------------
sql_executor = SQLExecutorTool(config=BaseToolConfig(), db_path=db_path)
# --- Usage example -------------------------------------------------------------
# Translate a natural-language question into SQL, then run that SQL.
user_question = "What are the top 5 products by total sales revenue?"

# Step 1: question -> SQL (this calls the LLM).
sql_response = text_to_sql_agent.run(TextToSQLInputSchema(question=user_question))
print(f"Generated SQL: {sql_response.sql_query}")
print(f"Explanation: {sql_response.explanation}")

# Step 2: run the generated SQL against the database.
results = sql_executor.run(SQLExecutorInputSchema(sql_query=sql_response.sql_query))

if not results.error:
    print(f"\nResults ({results.row_count} rows):")
    # Only a preview of at most five rows is printed; row_count is the full total.
    for row in results.results[:5]:
        print(row)
else:
    print(f"Error: {results.error}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment