Last active
January 4, 2026 06:55
-
-
Save Alphanimble/3d51057bc6fe2e154d2a5a17164e9a9e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import re | |
| import os | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from google.adk.sessions import DatabaseSessionService, InMemorySessionService | |
| from google.adk.runners import Runner | |
| from user_agent.agent import root_agent | |
| from google.genai import types | |
| from dotenv import load_dotenv | |
| import logging | |
| import uuid | |
| import asyncio | |
| import sys | |
# Load environment variables from the .env file.
load_dotenv()

# Database URL used to persist agent sessions.
# Fail fast when it is missing: wrapping os.getenv() in str() would turn an
# unset variable into the literal string "None", which only surfaces later as
# a confusing URL-parsing error inside the session service.
session_db_url = os.getenv("SESSION_DB")
if not session_db_url:
    raise RuntimeError(
        "SESSION_DB environment variable is not set; "
        "it must contain the session database URL."
    )

# Configure logging for the application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def clean_json_response(text: str) -> str:
    """
    Strip a surrounding Markdown code fence from an LLM response.

    Handles fences with or without a language tag (```json, ```JSON, ```),
    LF or CRLF line endings, and leading/trailing whitespace around the
    fence. Only the outermost fence is removed; fence-like text inside the
    payload is left untouched.

    Args:
        text: Raw response text, possibly wrapped in a code fence.

    Returns:
        The payload with the fence and surrounding whitespace removed.
    """
    # Strip first so the anchors below line up with the fence markers even
    # when the model pads the response with blank lines.
    text = text.strip()
    # Opening fence: ``` + optional language tag + optional newline.
    text = re.sub(r"\A```[\w-]*[ \t]*\r?\n?", "", text)
    # Closing fence: optional newline + ```.
    text = re.sub(r"\r?\n?[ \t]*```\Z", "", text)
    return text.strip()
# Initialize the FastAPI application
app = FastAPI()

# Initialize the database-backed session service.
# (InMemorySessionService is another choice.)
try:
    db_session_service = DatabaseSessionService(db_url=session_db_url)
    # Log only the URL scheme: the full URL usually embeds credentials
    # (user:password@host), which must not end up in log files.
    _db_scheme, _db_sep, _ = session_db_url.partition("://")
    logger.info(
        "DatabaseSessionService initialized (scheme: %s)",
        _db_scheme if _db_sep else "<unknown>",
    )
except Exception:
    # Log with traceback, then abort start-up: the app cannot work without
    # a session service.
    logger.exception("Failed to initialize DatabaseSessionService")
    raise

# CORS configuration.
# Allowed origins are read from the optional CORS_ALLOW_ORIGINS environment
# variable (comma-separated); when it is unset every origin is allowed, as
# before. Credentialed requests are only enabled for an explicit origin list,
# because combining a wildcard origin with credentials lets any website make
# authenticated cross-origin calls.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],  # Allows all HTTP methods (GET, POST, PUT, DELETE, etc.)
    allow_headers=["*"],  # Allows all headers in the request
)

# Initialize the ADK Runner with the application name, root agent, and session service
runner = Runner(
    app_name="data-agent-app",
    agent=root_agent,
    session_service=db_session_service,
)
def _sse_event(data, event=None):
    """Format *data* as a single Server-Sent Events frame.

    Every line of the payload gets its own ``data:`` field so that text
    containing newlines does not break SSE framing (a blank line would
    otherwise terminate the event early). Single-line payloads produce
    exactly ``data: <payload>\\n\\n``.
    """
    normalized = str(data).replace("\r\n", "\n").replace("\r", "\n")
    frame = "".join(f"data: {line}\n" for line in normalized.split("\n"))
    if event:
        frame = f"event: {event}\n{frame}"
    return frame + "\n"


@app.get("/chart-agent-stream/")
async def process_chart_request_stream(
    query: str = Query(
        ..., description="The natural language query for the chart agent"
    ),
):
    """
    Processes the input query using the chart agent pipeline and streams
    the events back to the client as Server-Sent Events (SSE).

    A fresh session is created for every request. The stream carries, in
    order of occurrence: agent transfer targets, function call names,
    function response names (prefixed with 'response'), and the final
    response text. If the agent run fails mid-stream, an ``event: error``
    frame is sent before the stream closes.

    Raises:
        HTTPException: 500 if the session cannot be created.
    """
    user_id = "default"  # Default user ID for streaming requests

    # Session creation happens before streaming starts, so a failure here can
    # still be reported as a regular HTTP error.
    try:
        session = await db_session_service.create_session(
            app_name="data-agent-app", user_id=user_id, session_id=uuid.uuid4().hex
        )
    except Exception as e:
        logger.exception("Unexpected error in chart request stream")
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during processing: {str(e)}",
        ) from e

    # Create a user message from the input query
    user_message = types.Content(role="user", parts=[types.Part(text=query)])

    async def event_stream():
        """Yield runner events as SSE frames."""
        logger.info("event_stream started for session %s", session.id)
        try:
            async for event in runner.run_async(
                user_id=user_id, new_message=user_message, session_id=session.id
            ):
                # Agent transfer actions
                if event.actions.transfer_to_agent:
                    yield _sse_event(event.actions.transfer_to_agent)

                # Function call names
                function_calls = event.get_function_calls()
                if function_calls:
                    yield _sse_event(function_calls[-1].name)

                # Function response names
                function_responses = event.get_function_responses()
                if function_responses:
                    yield _sse_event(f"'response' {function_responses[-1].name}")

                # Final response text; skip non-text parts instead of
                # streaming the string "None".
                if event.is_final_response() and event.content and event.content.parts:
                    final_response_text = event.content.parts[-1].text
                    if final_response_text is not None:
                        yield _sse_event(final_response_text)
        except Exception as e:
            # The response headers are already sent at this point, so an
            # HTTPException can no longer reach the client; report the
            # failure in-band instead.
            logger.exception("Unexpected error in chart request stream")
            yield _sse_event(
                f"An unexpected error occurred during processing: {str(e)}",
                event="error",
            )

    # Return a StreamingResponse with the event stream
    return StreamingResponse(event_stream(), media_type="text/event-stream")
@app.get("/chart-agent/")
async def process_chart_request(
    query: str = Query(
        ..., description="The natural language query for the chart agent"
    ),
    session_id: str = Query(default=None, description="Session ID for the request"),
    user_id: str = Query(default="default", description="User ID for the session"),
):
    """
    Processes the input query using the chart agent pipeline and attempts
    to return the final structured JSON output.

    When ``session_id`` names an existing session for this user, that session
    is reused; otherwise a new session is created (with the given ID, if any).

    Returns:
        The agent's final response parsed as JSON.

    Raises:
        HTTPException: 500 if the agent produces no final text, if the final
            text is not valid JSON, or on any unexpected error.
    """
    try:
        # Retrieve the session when the caller supplies a known ID; calling
        # create_session with an ID that already exists would fail.
        session = None
        if session_id:
            session = await db_session_service.get_session(
                app_name="data-agent-app", user_id=user_id, session_id=session_id
            )
        if session is None:
            session = await db_session_service.create_session(
                app_name="data-agent-app", user_id=user_id, session_id=session_id
            )

        # Create a user message from the input query
        user_message = types.Content(role="user", parts=[types.Part(text=query)])

        final_response_text = None
        # Run the agent asynchronously and stop at the first final response
        async for event in runner.run_async(
            user_id=user_id, new_message=user_message, session_id=session.id
        ):
            if event.is_final_response() and event.content and event.content.parts:
                final_response_text = event.content.parts[-1].text
                break

        # If no final response text was produced, raise an error
        if final_response_text is None:
            logger.error("Agent pipeline did not produce a final text response")
            raise HTTPException(
                status_code=500,
                detail="Agent pipeline did not produce a final text response.",
            )

        # Attempt to clean and parse the final text as JSON
        try:
            cleaned_response = clean_json_response(final_response_text)
            parsed_json = json.loads(cleaned_response)
        except json.JSONDecodeError as e:
            logger.error("JSON decode error: %s", e)
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "The final agent response was not valid JSON.",
                    "raw_response": final_response_text,
                    "json_error": str(e),
                },
            ) from e

        logger.info("Successfully parsed JSON response")
        return parsed_json
    except HTTPException:
        # Re-raise HTTPExceptions untouched so their status/detail survive
        raise
    except Exception as e:
        # Log with traceback, then surface the failure as an HTTP 500
        logger.exception("Unexpected error in chart request")
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during processing: {str(e)}",
        ) from e
@app.get("/table-agent/")
async def process_table_request(
    query: str = Query(
        ..., description="The natural language query for the simple table agent"
    ),
    session_id: str = Query(..., description="The session ID for the simple table agent"),
    user_id: str = Query(..., description="The user ID for the simple table agent"),
):
    """
    Processes the input query using the simple table agent pipeline
    for straightforward single-dimensional table requests.

    When ``session_id`` names an existing session for this user, that session
    is reused; otherwise a new session with that ID is created.

    Returns:
        ``{"table": <final agent text, whitespace-stripped>, "type": "simple"}``.

    Raises:
        HTTPException: 500 if the agent produces no final text or on any
            unexpected error.
    """
    try:
        # Retrieve the session if it already exists; calling create_session
        # with an ID that is already stored would fail.
        session = await db_session_service.get_session(
            app_name=runner.app_name, user_id=user_id, session_id=session_id
        )
        if session is None:
            session = await db_session_service.create_session(
                app_name=runner.app_name, user_id=user_id, session_id=session_id
            )
        logger.info("Using session %s for simple table request", session.id)

        # Direct message to table_agent for simple requests
        user_message = types.Content(
            role="user", parts=[types.Part(text=f"Create a table for: {query}")]
        )

        final_response_text = None
        # Run the agent asynchronously and stop at the first final response
        async for event in runner.run_async(
            user_id=user_id, new_message=user_message, session_id=session.id
        ):
            if event.is_final_response() and event.content and event.content.parts:
                final_response_text = event.content.parts[-1].text
                break

        # If no final response text was produced, raise an error
        if final_response_text is None:
            logger.error("Agent pipeline did not produce a final text response")
            raise HTTPException(
                status_code=500,
                detail="Agent pipeline did not produce a final text response.",
            )

        # Only surrounding whitespace is trimmed; the table text itself is
        # returned exactly as the agent produced it.
        cleaned_response = final_response_text.strip()
        logger.info("Successfully processed simple table request")
        return {"table": cleaned_response, "type": "simple"}
    except HTTPException:
        # Re-raise HTTPExceptions untouched so their status/detail survive
        raise
    except Exception as e:
        # Log with traceback, then surface the failure as an HTTP 500
        logger.exception("Unexpected error in simple table request")
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during processing: {str(e)}",
        ) from e
@app.get("/health/")
async def health_check():
    """Liveness probe: report an OK status together with the runner's app name."""
    payload = {"status": "ok"}
    payload["app_name"] = runner.app_name
    return payload
@app.get("/test-stream/")
async def test_stream():
    """
    Demo endpoint for Server-Sent Events (SSE): emits the integers 0
    through 4, pausing one second after each emitted value.
    """

    async def number_stream():
        counter = 0
        while counter < 5:
            yield f"data: {counter}\n\n"
            await asyncio.sleep(1)
            counter += 1

    response = StreamingResponse(number_stream(), media_type="text/event-stream")
    return response
if __name__ == "__main__":
    from pathlib import Path

    import uvicorn

    # Run the FastAPI application using Uvicorn with auto-reload for
    # development. Reload requires the app as an "module:attribute" import
    # string; passing the app object together with reload=True makes uvicorn
    # print a warning and exit. The module name is derived from this file so
    # the script keeps working if it is renamed.
    uvicorn.run(
        f"{Path(__file__).stem}:app", host="0.0.0.0", port=8000, reload=True
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment