Skip to content

Instantly share code, notes, and snippets.

@thorwhalen
Created March 21, 2025 07:50
Show Gist options
  • Select an option

  • Save thorwhalen/36092351f7cc535326e169e108083caa to your computer and use it in GitHub Desktop.

Select an option

Save thorwhalen/36092351f7cc535326e169e108083caa to your computer and use it in GitHub Desktop.
"""
This module defines a FastAPI application that demonstrates a simple pipeline
processing system. The system allows users to specify which stages they want
to include in the processing pipeline, and the system will automatically
compute and return the results for those stages.
"""
from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from typing import (
List,
Dict,
Any,
Optional,
Callable,
AsyncGenerator,
TypeVar,
Awaitable,
)
from dataclasses import dataclass
from enum import Enum, auto
import asyncio
# FastAPI application instance; the @app.post decorators below register
# their handlers on it.
app = FastAPI()
# Generic type variable; not referenced elsewhere in the visible code.
T = TypeVar('T')
# Alias for the mutable dict of named pipeline values that the stage
# functions read from and write to ("text", "embeddings", "clusters",
# "projections").
Context = Dict[str, Any]
class ProcessingStage(Enum):
    """Stages a caller can request from the pipeline.

    The integer values are spelled out (they match what ``auto()`` would
    assign) so the value of each member is visible at a glance.
    """

    EMBEDDING = 1
    CLUSTERING = 2
    PROJECTION = 3
class PipelineRequest(BaseModel):
    """Body of a pipeline request: the input text plus the stages to return."""

    # Raw text handed to the embedding stage.
    text: str
    # Stages whose results the caller wants back.
    stages: List[ProcessingStage]
async def _get_embeddings(text: str) -> List[float]:
"""Generate vector embeddings from text
>>> import asyncio
>>> embeddings = asyncio.run(_get_embeddings("test"))
>>> len(embeddings) > 0
True
"""
# In a real implementation, this would call your embedding service
await asyncio.sleep(0.1) # Simulate network call
return [0.1, 0.2, 0.3] # Simplified example
async def _get_clusters(embeddings: List[float]) -> Dict[str, Any]:
"""Generate clusters from embeddings"""
# In a real implementation, this would call your clustering service
await asyncio.sleep(0.1) # Simulate network call
return {"clusters": [0, 1, 0]} # Simplified example
async def _get_projections(embeddings: List[float]) -> Dict[str, Any]:
"""Generate 2D projections from embeddings"""
# In a real implementation, this would call your projection service
await asyncio.sleep(0.1) # Simulate network call
return {"x": [1, 2, 3], "y": [4, 5, 6]} # Simplified example
async def _process_stage(stage: ProcessingStage, context: Context) -> Context:
    """Run one pipeline stage and store its output in ``context``.

    Args:
        stage: The stage to execute.
        context: Values produced so far; it is updated in place.

    Returns:
        The same ``context`` object, with the stage's output added under
        ``"embeddings"``, ``"clusters"`` or ``"projections"``.

    Raises:
        ValueError: If the input the stage needs (``"text"`` for embedding,
            ``"embeddings"`` for the other two) is missing from ``context``.
    """
    if stage == ProcessingStage.EMBEDDING:
        if "text" not in context:
            raise ValueError("Text required for embedding")
        source_text = context["text"]
        context["embeddings"] = await _get_embeddings(source_text)
        return context

    if stage == ProcessingStage.CLUSTERING:
        if "embeddings" not in context:
            raise ValueError("Embeddings required for clustering")
        vectors = context["embeddings"]
        context["clusters"] = await _get_clusters(vectors)
        return context

    if stage == ProcessingStage.PROJECTION:
        if "embeddings" not in context:
            raise ValueError("Embeddings required for projection")
        vectors = context["embeddings"]
        context["projections"] = await _get_projections(vectors)
        return context

    # Any other stage leaves the context untouched.
    return context
async def process_pipeline(text: str, stages: List[ProcessingStage]) -> Context:
    """Run the requested stages over ``text`` and return their outputs.

    The embedding stage is always executed first whenever any stage is
    requested that needs embeddings, regardless of where (or whether)
    ``EMBEDDING`` appears in ``stages``. Each stage runs at most once even
    if it is listed several times.

    Args:
        text: The input text to process.
        stages: Stages whose results should be returned; order and
            duplicates do not matter.

    Returns:
        Dictionary with one entry per requested stage, keyed by
        ``"embeddings"``, ``"clusters"`` and/or ``"projections"``.
        Embeddings computed only as a prerequisite are not included.

    Raises:
        ValueError: Propagated from ``_process_stage`` if a stage's
            required input is missing from the working context.
    """
    context: Context = {"text": text}

    # De-duplicate while keeping the caller's relative order.
    unique_stages = list(dict.fromkeys(stages))
    requested = set(unique_stages)

    # Stages that consume embeddings and therefore must run after EMBEDDING.
    embedding_consumers = [
        stage for stage in unique_stages if stage != ProcessingStage.EMBEDDING
    ]

    # EMBEDDING goes first when asked for directly or needed by a consumer;
    # putting it at the head fixes requests like [CLUSTERING, EMBEDDING].
    execution_plan: List[ProcessingStage] = []
    if ProcessingStage.EMBEDDING in requested or embedding_consumers:
        execution_plan.append(ProcessingStage.EMBEDDING)
    execution_plan.extend(embedding_consumers)

    for stage in execution_plan:
        context = await _process_stage(stage, context)

    # Expose only what the caller asked for, in a fixed key order.
    output_keys = (
        (ProcessingStage.EMBEDDING, "embeddings"),
        (ProcessingStage.CLUSTERING, "clusters"),
        (ProcessingStage.PROJECTION, "projections"),
    )
    return {
        key: context.get(key) for stage, key in output_keys if stage in requested
    }
# Individual service endpoints
@app.post("/embeddings")
async def embedding_endpoint(text: str):
    """Compute embeddings for ``text`` and return them under ``"embeddings"``."""
    return {"embeddings": await _get_embeddings(text)}
@app.post("/clusters")
async def clusters_endpoint(embeddings: List[float]):
    """Cluster the supplied embeddings and return the clustering payload."""
    return await _get_clusters(embeddings)
@app.post("/projections")
async def projections_endpoint(embeddings: List[float]):
    """Project the supplied embeddings to 2D and return the coordinates."""
    return await _get_projections(embeddings)
# Combined pipeline endpoint
@app.post("/pipeline")
async def pipeline_endpoint(request: PipelineRequest):
    """Process text through a customizable pipeline.

    The caller lists the stages whose results should be returned.
    Intermediate results (like embeddings) are computed when needed but are
    not returned unless specifically requested.

    Args:
        request: The text to process and the stages to include.

    Returns:
        The dictionary produced by ``process_pipeline``.

    Raises:
        HTTPException: With status 400 when ``process_pipeline`` raises
            ``ValueError``; the error message becomes the response detail.
    """
    # Keep the try body to the one call that can raise ValueError.
    try:
        outputs = await process_pipeline(request.text, request.stages)
    except ValueError as exc:
        # Chain explicitly so the traceback shows the ValueError as the cause
        # rather than as an unrelated error raised during handling.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return outputs
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running this file directly.
    from uvicorn import run as serve

    # Serve the app on all interfaces, port 8000.
    serve(app, host="0.0.0.0", port=8000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment