@yai333
Created September 7, 2025 10:45
#!/usr/bin/env python3
# Requirements
# pydantic >= 2.5.0
# google-adk >= 1.13.0
# google-genai (provides the google.genai module imported below; installed with google-adk)
# presidio-analyzer >= 2.2.354
# presidio-anonymizer >= 2.2.354
# spacy >= 3.4.0
# typing-extensions >= 4.8.0
import asyncio
import json
import os
import re
import sqlite3
import time
from enum import Enum
from typing import List, Literal, Union
from pydantic import BaseModel, Field, field_validator, create_model
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
from google.adk import Runner
from google.adk.agents import LlmAgent
from google.adk.sessions import InMemorySessionService
from google.genai import types
class ParameterType(Enum):
STRING = "string"
INTEGER = "integer"
FLOAT = "float"
DATE = "date"
BOOLEAN = "boolean"
class SQLParameter(BaseModel):
param_name: str = Field(
description="Parameter name used in SQL (e.g., 'first_name', 'last_name')")
placeholder_ref: str = Field(
description="Reference to the original placeholder (e.g., '<PERSON_1>')")
param_type: ParameterType = Field(
default=ParameterType.STRING, description="SQL parameter data type")
class StructuredSQLQuery(BaseModel):
"""Structured SQL query with parameter validation."""
sql_query: str = Field(
description="SQL query with named parameters for PII (e.g., :person_1, :email_1)")
parameters: List[SQLParameter] = Field(
description="List of parameters with their placeholder mappings")
@field_validator('sql_query')
@classmethod
def validate_sql_structure(cls, v):
v = v.strip()
if not v:
raise ValueError("SQL query cannot be empty")
upper_v = v.upper()
valid_starts = ('SELECT', 'WITH', '(')
if not any(upper_v.startswith(start) for start in valid_starts):
raise ValueError("SQL must be a valid SQL statement")
return v
def extract_placeholders_from_text(text: str) -> List[str]:
if not text:
return []
placeholders = re.findall(r'<[A-Z_]+_\d+>', text)
return sorted(list(set(placeholders)))
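# Illustrative example (not in the original gist): for the anonymized question
# "Find customer <PERSON_1> with email <EMAIL_ADDRESS_1>", this returns the sorted,
# de-duplicated list ['<EMAIL_ADDRESS_1>', '<PERSON_1>'].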
def create_dynamic_sql_models(anonymized_question: str):
"""Create dynamic Pydantic models with placeholder constraints based on the anonymized question."""
available_placeholders = extract_placeholders_from_text(
anonymized_question)
if not available_placeholders:
return SQLParameter, StructuredSQLQuery
# Union type: PII placeholders OR any string for literals
PlaceholderType = Union[Literal[tuple(available_placeholders)], str]
DynamicSQLParameter = create_model(
'DynamicSQLParameter',
param_name=(str, Field(
description="Parameter name used in SQL (e.g., 'first_name', 'last_name')")),
placeholder_ref=(PlaceholderType, Field(
description=f"PII placeholders: {', '.join(available_placeholders)} OR literal values like '100', 'Jazz'")),
param_type=(ParameterType, Field(
default=ParameterType.STRING, description="SQL parameter data type")),
__base__=BaseModel
)
DynamicStructuredSQLQuery = create_model(
'DynamicStructuredSQLQuery',
sql_query=(str, Field(
description="SQL query with named parameters for PII (e.g., :person_1, :email_1)")),
parameters=(List[DynamicSQLParameter], Field(
description="List of parameters with their placeholder mappings")),
__base__=BaseModel
)
return DynamicSQLParameter, DynamicStructuredSQLQuery
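# Illustrative example (assumption, not in the original gist): for the anonymized
# question "Find customer <PERSON_1> in Canada", placeholder_ref is typed as
# Union[Literal['<PERSON_1>'], str]. The Literal arm documents the valid PII
# placeholders in the schema sent to the model; the str arm still permits literal
# values such as 'Canada', so validation is intentionally permissive.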
class MappingStorage:
"""Persistent storage for PII mappings using SQLite
NOTE: This is a simplified implementation for demo purposes.
In production, you should:
1. Add session_id to the schema for session isolation
2. Include user_id for multi-tenant scenarios
3. Add TTL/expiration for compliance (GDPR right to be forgotten)
4. Encrypt sensitive mappings at rest
Production schema example:
CREATE TABLE pii_mappings (
session_id TEXT NOT NULL,
user_id TEXT NOT NULL,
original TEXT NOT NULL,
entity_type TEXT NOT NULL,
pseudonym TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP,
PRIMARY KEY (session_id, original, entity_type)
)
"""
def __init__(self, db_path="pii_mappings.db"):
self.db_path = db_path
self._init_db()
def _init_db(self):
"""Initialize database with mappings table
NOTE: For demo purposes, mappings are global and persistent.
Production should scope mappings by session_id.
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS pii_mappings (
original TEXT NOT NULL,
entity_type TEXT NOT NULL,
pseudonym TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (original, entity_type)
)
""")
conn.commit()
conn.close()
def store_mapping(self, original: str, pseudonym: str, entity_type: str):
"""Store PII mapping
NOTE: Currently stores globally without session isolation.
Production should include session_id in the storage:
INSERT OR REPLACE INTO pii_mappings
(session_id, original, entity_type, pseudonym)
VALUES (?, ?, ?, ?)
"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO pii_mappings
(original, entity_type, pseudonym)
VALUES (?, ?, ?)
""", (original, entity_type, pseudonym))
conn.commit()
conn.close()
def get_mapping(self, original: str, entity_type: str):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
SELECT original, pseudonym, entity_type FROM pii_mappings
WHERE original = ? AND entity_type = ?
""", (original, entity_type))
result = cursor.fetchone()
conn.close()
if result:
return {
'original': result[0],
'pseudonym': result[1],
'type': result[2]
}
return None
def get_all_mappings(self) -> dict:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT original, pseudonym, entity_type FROM pii_mappings")
results = cursor.fetchall()
conn.close()
mappings = {}
for i, result in enumerate(results):
mappings[f"mapping_{i}"] = {
'original': result[0],
'pseudonym': result[1],
'type': result[2]
}
return mappings
def clear_mappings(self):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("DELETE FROM pii_mappings")
conn.commit()
conn.close()
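# Minimal usage sketch (illustrative values, not in the original gist):
#   storage = MappingStorage("pii_mappings.db")
#   storage.store_mapping("John Smith", "<PERSON_1>", "person")
#   storage.get_mapping("John Smith", "person")
#   -> {'original': 'John Smith', 'pseudonym': '<PERSON_1>', 'type': 'person'}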
class PresidioPIIDetector:
def __init__(self, db_path="pii_mappings.db"):
self.analyzer = AnalyzerEngine()
self.anonymizer = AnonymizerEngine()
self.mapping_storage = MappingStorage(db_path)
# NOTE: Session mappings are in-memory only for demo purposes
# Production should store these in DB with proper session_id scoping
self.session_mappings = {}
def pseudonymize(self, text):
allowed_entities = [
'EMAIL_ADDRESS', 'PERSON', 'DATE_TIME', 'LOCATION',
'PHONE_NUMBER',
'AU_ABN',
'AU_ACN',
'AU_TFN',
'AU_MEDICARE'
]
results = self.analyzer.analyze(
text=text, language='en', entities=allowed_entities)
if not results:
return text, []
anonymized_text = text
detected_entities = []
session_entity_counters = {}
for result in sorted(results, key=lambda x: x.start, reverse=True):
original_value = text[result.start:result.end]
entity_type = result.entity_type.lower()
existing_mapping = self.mapping_storage.get_mapping(
original_value, entity_type)
if existing_mapping:
placeholder = existing_mapping['pseudonym']
else:
if entity_type not in session_entity_counters:
session_entity_counters[entity_type] = self._get_next_entity_counter(
entity_type)
else:
session_entity_counters[entity_type] += 1
next_counter = session_entity_counters[entity_type]
placeholder = f"<{entity_type.upper()}_{next_counter}>"
self.mapping_storage.store_mapping(
original_value, placeholder, entity_type)
anonymized_text = (
anonymized_text[:result.start] +
placeholder +
anonymized_text[result.end:]
)
self.session_mappings[f"{original_value}_{entity_type}"] = {
'original': original_value,
'pseudonym': placeholder,
'type': entity_type,
'score': result.score
}
detected_entities.append({
'entity_type': entity_type,
'value': original_value,
'start': result.start,
'end': result.end,
'score': result.score
})
return anonymized_text, detected_entities
def _get_next_entity_counter(self, entity_type: str) -> int:
"""Get the next available counter for an entity type across all existing mappings."""
all_mappings = self.mapping_storage.get_all_mappings()
existing_counters = []
for mapping in all_mappings.values():
if mapping['type'] == entity_type:
pseudonym = mapping['pseudonym']
match = re.search(rf'<{entity_type.upper()}_(\d+)>', pseudonym)
if match:
existing_counters.append(int(match.group(1)))
for mapping in self.session_mappings.values():
if mapping['type'] == entity_type:
pseudonym = mapping['pseudonym']
match = re.search(rf'<{entity_type.upper()}_(\d+)>', pseudonym)
if match:
existing_counters.append(int(match.group(1)))
if existing_counters:
return max(existing_counters) + 1
else:
return 1
def deanonymize(self, anonymized_text):
result = anonymized_text
for mapping in self.session_mappings.values():
if mapping['pseudonym'] in result:
result = result.replace(
mapping['pseudonym'], mapping['original'])
for mapping in self.mapping_storage.get_all_mappings().values():
if mapping['pseudonym'] in result:
result = result.replace(
mapping['pseudonym'], mapping['original'])
return result
def get_mapping_storage(self):
return self.session_mappings
def clear_all_mappings(self):
self.session_mappings.clear()
self.mapping_storage.clear_mappings()
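# Illustrative round trip (entity indices depend on mappings already stored in the DB):
#   detector = PresidioPIIDetector()
#   anonymized, entities = detector.pseudonymize("Find customer John Smith with email john@test.com")
#   -> anonymized ~ "Find customer <PERSON_1> with email <EMAIL_ADDRESS_1>"
#   detector.deanonymize(anonymized)
#   -> "Find customer John Smith with email john@test.com"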
class SQLResponse(BaseModel):
"""SQL response with dynamic bindings"""
sql: str = Field(...,
description="SQLite SQL with dynamic placeholder bindings")
bindings: dict = Field(default_factory=dict,
description="Only includes PII types actually used")
CHINOOK_SCHEMA = """
Chinook Music Database (SQLite) with Relationships:
- Customer: CustomerId(PK), FirstName, LastName, Email, Phone, Country
- Invoice: InvoiceId(PK), CustomerId(FK→Customer), InvoiceDate, Total
- InvoiceLine: InvoiceLineId(PK), InvoiceId(FK→Invoice), TrackId(FK→Track), Quantity, UnitPrice
- Track: TrackId(PK), Name, AlbumId(FK→Album), GenreId(FK→Genre), Milliseconds, UnitPrice
- Album: AlbumId(PK), Title, ArtistId(FK→Artist)
- Artist: ArtistId(PK), Name
- Genre: GenreId(PK), Name
Common JOINs:
- Customer→Invoice: ON c.CustomerId = i.CustomerId
- Invoice→InvoiceLine: ON i.InvoiceId = il.InvoiceId
- InvoiceLine→Track: ON il.TrackId = t.TrackId
- Track→Album: ON t.AlbumId = a.AlbumId
- Album→Artist: ON a.ArtistId = ar.ArtistId
- Track→Genre: ON t.GenreId = g.GenreId
"""
class TextToSQLADK:
"""Text-to-SQL with Google ADK and PII protection"""
def __init__(self, api_key: str):
# google-adk's Gemini models read the API key from the environment, so surface it there.
os.environ.setdefault("GOOGLE_API_KEY", api_key)
self.pii_detector = PresidioPIIDetector()
self.app_name = "chinook-sql-final"
self.user_id = "test-user-final"
self.types = types
self.pii_instruction = f"""Convert natural language to SQLite SQL for the Chinook database with PII placeholders.
{CHINOOK_SCHEMA}
Use SQLite syntax with exact table/column names. Use NAMED PARAMETERS for ALL WHERE clause values.
Use proper JOINs for multi-table queries.
PARAMETER RULES:
- Use named parameters for ALL values in WHERE, HAVING, and conditional clauses: :param_1, :param_2, etc.
- Map each :param_name to its source placeholder
IMPORTANT NAME HANDLING:
- When searching by person names (PERSON placeholders), ALWAYS use FULL NAME matching
- Use: (FirstName || ' ' || LastName) = :person_1
- NOT just FirstName = :person_1
Examples:
- Mixed: "SELECT * FROM Customer WHERE (FirstName || ' ' || LastName) = :person_1 AND Country = :country_1"
- Pure PII: "SELECT * FROM Customer WHERE Email = :email_1 AND Phone = :phone_1" """
self.no_pii_instruction = f"""Convert natural language to SQLite SQL for the Chinook database.
{CHINOOK_SCHEMA}
Use SQLite syntax with exact table/column names. Use LITERAL VALUES directly in SQL.
Use proper JOINs for multi-table queries.
RULES:
- Use literal string values directly in SQL with proper quotes
- Leave parameters array EMPTY
- Use proper SQL syntax: 'text values', numbers without quotes
Examples:
- Country: "SELECT * FROM Customer WHERE Country = 'Canada'"
- Gmail pattern: "SELECT * FROM Customer WHERE Email LIKE '%@gmail.com'"
- Artist: "SELECT * FROM Album a JOIN Artist ar ON a.ArtistId = ar.ArtistId WHERE ar.Name = 'AC/DC'"
- Number: "SELECT * FROM Track WHERE Milliseconds > 300000"
Parameters: [] (always empty)"""
self.sql_agent = None
self.session_service = InMemorySessionService()
def normalize_binding_format(self, parsed_data: dict, pii_items: list) -> dict:
"""Convert model's natural binding format to expected schema format"""
result = {'sql': parsed_data.get('sql', ''), 'bindings': {}}
original_bindings = parsed_data.get('bindings', {})
if isinstance(original_bindings, list):
pii_by_type = {}
for item in pii_items:
pii_type = item.get('entity_type', 'unknown').lower()
pii_by_type.setdefault(pii_type, []).append(item)
for i, value in enumerate(original_bindings):
for pii_type, items in pii_by_type.items():
if i < len(items):
result['bindings'].setdefault(pii_type, []).append(
f"<{pii_type.upper()}_{len(result['bindings'].get(pii_type, [])) + 1}>")
break
elif isinstance(original_bindings, dict) and original_bindings:
if all(isinstance(v, str) and v.startswith('<') and v.endswith('>') for v in original_bindings.values()):
result['bindings'] = original_bindings
elif all(isinstance(v, list) for v in original_bindings.values()):
result['bindings'] = original_bindings
else:
print(
f"Warning: Received unrecognized bindings format: {original_bindings}")
elif isinstance(original_bindings, dict):
result['bindings'] = original_bindings
return result
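# Illustrative normalization (assumption): a list-shaped model response such as
#   {'sql': '...', 'bindings': ['John Smith']} with one detected PERSON entity
# becomes {'sql': '...', 'bindings': {'person': ['<PERSON_1>']}}, while dict-shaped
# bindings (placeholder strings or per-type lists) pass through unchanged.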
async def generate_sql(self, query: str, pii_items: list = None) -> dict:
"""Generate SQL with ADK using dynamic structured output"""
try:
_, DynamicStructuredSQLQuery = create_dynamic_sql_models(query)
print(f"πŸ”§ Dynamic Schema for: '{query}'")
print(f"πŸ“‹ Schema Model: {DynamicStructuredSQLQuery.__name__}")
# available_placeholders = extract_placeholders_from_text(query)
# if available_placeholders:
# print(f"🏷️ Allowed Placeholders: {available_placeholders}")
# else:
# print("🏷️ No PII Placeholders - Using base schema")
# print("πŸ“ Schema Fields:")
for field_name, field_info in DynamicStructuredSQLQuery.model_fields.items():
print(f" - {field_name}: {field_info.annotation}")
print()
has_pii = bool(pii_items)
instruction = self.pii_instruction if has_pii else self.no_pii_instruction
agent_description = "Generates structured SQL with PII placeholder handling" if has_pii else "Generates structured SQL with literal values"
print(
f"πŸ“‹ Using {'PII' if has_pii else 'Non-PII'} Instruction Template")
sql_agent = LlmAgent(
model="gemini-2.5-flash-lite",
name="dynamic_sql_agent",
description=agent_description,
instruction=instruction,
output_schema=DynamicStructuredSQLQuery,
output_key="structured_sql_result",
tools=[],
generate_content_config=self.types.GenerateContentConfig(
temperature=0
),
disallow_transfer_to_parent=True,
disallow_transfer_to_peers=True
)
runner = Runner(
agent=sql_agent,
session_service=self.session_service,
app_name=self.app_name
)
session = await self.session_service.create_session(
app_name=self.app_name,
user_id=self.user_id
)
content = self.types.Content(
role='user',
parts=[self.types.Part(text=f"Convert to SQL: {query}")]
)
events = runner.run(
user_id=self.user_id,
session_id=session.id,
new_message=content
)
structured_response = None
for event in events:
if event.is_final_response() and event.content:
response_text = event.content.parts[0].text.strip()
print(f"πŸ” ADK Structured Response: {response_text}")
try:
if isinstance(response_text, str):
clean_response = re.sub(
r'```(?:json)?\s*', '', response_text)
clean_response = re.sub(
r'\s*```', '', clean_response)
structured_response = json.loads(
clean_response.strip())
else:
structured_response = response_text
# print(
# f"πŸ” Parsed Structured Data: {json.dumps(structured_response, indent=2)}")
break
except json.JSONDecodeError as e:
print(f"πŸ” JSON parsing error: {e}")
print(f"πŸ” Raw response: {response_text}")
raise e
if structured_response:
result = {
'sql': structured_response.get('sql_query', ''),
'parameters': structured_response.get('parameters', []),
'session_id': session.id[:8]
}
bindings = {}
for param in result.get('parameters', []):
param_name = param.get('param_name', '')
placeholder_ref = param.get('placeholder_ref', '')
if param_name and placeholder_ref:
bindings[param_name] = placeholder_ref
result['bindings'] = bindings
# print(
# f"πŸ” Final Structured Result: {json.dumps(result, indent=2)}")
return result
raise Exception("ADK produced no structured response")
except Exception as e:
print(f"ADK error: {str(e)}")
raise e
def generate_sql_sync(self, query: str, pii_items: list = None) -> dict:
return asyncio.run(self.generate_sql(query, pii_items))
def validate_sql(self, sql: str, bindings: dict = None) -> dict:
if not sql or sql.startswith('--'):
return {'valid': False, 'error': 'Invalid SQL'}
try:
# download from https://www.kaggle.com/datasets/nancyalaswad90/chinook-sample-database
conn = sqlite3.connect("chinook.db")
cursor = conn.cursor()
clean_sql = sql.rstrip(';')
if bindings:
dummy_bindings = {}
for param_name, placeholder in bindings.items():
param_key = param_name.lstrip(':')
dummy_bindings[param_key] = 'test_value'
cursor.execute(
f"EXPLAIN QUERY PLAN {clean_sql}", dummy_bindings)
if clean_sql.upper().startswith('SELECT'):
if 'LIMIT' not in clean_sql.upper():
clean_sql += ' LIMIT 3'
cursor.execute(clean_sql, dummy_bindings)
cursor.fetchall()
else:
cursor.execute(f"EXPLAIN QUERY PLAN {clean_sql}")
if clean_sql.upper().startswith('SELECT'):
if 'LIMIT' not in clean_sql.upper():
clean_sql += ' LIMIT 3'
cursor.execute(clean_sql)
cursor.fetchall()
conn.close()
return {'valid': True, 'row_count': 3}
except Exception as e:
return {'valid': False, 'error': str(e)}
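# Illustrative validation path (assumption): for
#   sql = "SELECT * FROM Customer WHERE Country = :country_1"
#   bindings = {'country_1': '<LOCATION_1>'}
# the query is checked with EXPLAIN QUERY PLAN against chinook.db using the dummy
# value 'test_value', then executed (LIMIT 3 is appended when missing); any sqlite3
# error is surfaced as {'valid': False, 'error': ...}.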
def process_query(self, query: str) -> dict:
"""Process query with PII protection and SQL generation"""
start_time = time.time()
anonymized, pii_entities = self.pii_detector.pseudonymize(query)
pii_items = []
for entity in pii_entities:
# Use the placeholder actually assigned by pseudonymize() for this value,
# rather than assuming the first index for each entity type.
mapping = self.pii_detector.mapping_storage.get_mapping(
entity['value'], entity['entity_type'])
pii_items.append({
'original': entity['value'],
'pseudonym': mapping['pseudonym'] if mapping else f"<{entity['entity_type'].upper()}_1>",
'type': entity['entity_type']
})
sql_result = self.generate_sql_sync(anonymized, pii_items)
# Extract non-PII parameters from the SQL result
non_pii_items = []
if 'parameters' in sql_result:
pii_placeholders = {item['pseudonym'] for item in pii_items}
for param in sql_result['parameters']:
placeholder_ref = param.get('placeholder_ref', '')
# If placeholder_ref is not a PII placeholder, it's a literal value
if placeholder_ref not in pii_placeholders and not placeholder_ref.startswith('<'):
non_pii_items.append({
'param_name': param.get('param_name', ''),
'value': placeholder_ref,
'type': param.get('param_type', 'string')
})
validation = self.validate_sql(sql_result.get(
'sql', ''), sql_result.get('bindings', {}))
result = {
'query': query,
'anonymized': anonymized,
'sql': sql_result.get('sql', ''),
'bindings': sql_result.get('bindings', {}),
'pii_items': pii_items,
'non_pii_items': non_pii_items,
'valid': validation['valid'],
'error': validation.get('error', None),
'confidence': sql_result.get('confidence', 0.0),  # generate_sql() does not supply a confidence score, so this defaults to 0.0
'execution_time': round((time.time() - start_time) * 1000, 1)
}
return result
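# Illustrative result shape (assumption; actual SQL and placeholders depend on what
# Presidio detects and on the model output):
#   process_query("Find customer with email alice@test.com") ->
#   {'query': 'Find customer with email alice@test.com',
#    'anonymized': 'Find customer with email <EMAIL_ADDRESS_1>',
#    'sql': 'SELECT * FROM Customer WHERE Email = :email_1',
#    'bindings': {'email_1': '<EMAIL_ADDRESS_1>'},
#    'pii_items': [...], 'non_pii_items': [], 'valid': True, 'error': None, ...}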
def main():
print("=" * 70)
print("πŸš€ GOOGLE ADK TEXT-TO-SQL TEST")
print("With PII protection and session management")
print("=" * 70)
adk = TextToSQLADK(os.environ["GOOGLE_API_KEY"])
print("βœ… Final ADK implementation ready")
test_queries = [
# PII with mixed literal values - tests Option 2 parameterization
# "Find customer John Smith in Australia",
# "Find customers named Bob Wilson with email bob@gmail.com in Australia",
# Pure PII queries
# "Find customer with email alice@test.com",
# "Show customers with phone number 555-999-8888",
# "List invoices for customer Mary Davis with email mary@example.com and phone 555-123-4567",
# "Find customers named Jane Smith with phone 555-444-3333 in location Paris",
# "Show all customers with emails john@gmail.com and sarah@yahoo.com",
# Pure non-PII queries
# "Show albums by AC/DC",
# "Find all customers in Canada",
# "List tracks longer than 300 seconds",
# "Show albums released in year 2000",
# "Find genres with more than 100 tracks",
# Complex mixed scenarios - Multi-PII + literals
"Find customers named Sarah Johnson in Canada who purchased rock albums",
"List customers named John Doe in USA with total invoice amount greater than 1000",
"List customers with phone 555-111-2222 or email test@example.com in France",
"Show customers named Michael Johnson in Australia who bought jazz albums costing more than 100"
]
print(f"Testing {len(test_queries)} queries...")
print()
results = []
for i, query in enumerate(test_queries, 1):
print("=" * 70)
print(f"TEST {i}: {query}")
print("=" * 70)
result = adk.process_query(query)
results.append(result)
print(f"Anonymized: {result['anonymized']}")
print(f"SQL: {result['sql']}")
print(f"PII ({len(result['pii_items'])}):")
for pii in result['pii_items']:
print(f" {pii['original']} β†’ {pii['pseudonym']} ({pii['type']})")
if result.get('non_pii_items'):
print(f"Non-PII ({len(result['non_pii_items'])}):")
for item in result['non_pii_items']:
print(
f" {item['param_name']} β†’ {item['value']} ({item['type']})")
print(f"Valid SQL: {'βœ…' if result['valid'] else '❌'}")
if result['error']:
print(f"Error: {result['error']}")
print(f"Time: {result['execution_time']}ms")
print()
print("=" * 70)
print("🎯 FINAL SUMMARY")
print("=" * 70)
total_tests = len(results)
valid_sql = sum(1 for r in results if r['valid'])
avg_confidence = sum(r['confidence'] for r in results) / \
total_tests if total_tests > 0 else 0
total_pii = sum(len(r['pii_items']) for r in results)
total_non_pii = sum(len(r.get('non_pii_items', [])) for r in results)
print(f"Tests: {total_tests}")
print(
f"Valid SQL: {valid_sql}/{total_tests} ({valid_sql/total_tests*100:.0f}%)")
print(f"Avg Confidence: {avg_confidence:.2f}")
print(f"Total PII Protected: {total_pii}")
print(f"Total Non-PII Parameters: {total_non_pii}")
print()
print("πŸ† Google ADK Status:")
if valid_sql == total_tests:
print(" βœ… PERFECT: All SQL queries valid")
elif valid_sql > 0:
print(" ⚠️ PARTIAL: ADK running but SQL needs improvement")
else:
print(" ❌ ISSUES: ADK running but SQL validation failing")
success = valid_sql > total_tests * 0.8
print(f"\n{'βœ… PASSED' if success else '❌ FAILED'}")
if __name__ == "__main__":
main()
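# Hedged sketch (not part of the original gist): one way to execute the generated SQL
# with real PII values at query time. Placeholder bindings are rehydrated through the
# detector's stored mappings; literal bindings (e.g. 'Canada') pass through unchanged.
# The function and variable names below are illustrative assumptions.
def execute_with_real_values(sql: str, bindings: dict, detector: PresidioPIIDetector) -> list:
    # ':param_name' -> original value; deanonymize() only replaces known pseudonyms.
    params = {name.lstrip(':'): detector.deanonymize(ref) for name, ref in bindings.items()}
    with sqlite3.connect("chinook.db") as conn:  # assumes the Chinook sample DB is present
        return conn.execute(sql, params).fetchall()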