Skip to content

Instantly share code, notes, and snippets.

@thorwhalen
Created March 21, 2025 07:54
Show Gist options
  • Select an option

  • Save thorwhalen/e91fcaee1fc9e0b3dca6d12a9d6fcca2 to your computer and use it in GitHub Desktop.

Select an option

Save thorwhalen/e91fcaee1fc9e0b3dca6d12a9d6fcca2 to your computer and use it in GitHub Desktop.
"""
Vector Processing Pipeline Service
This module provides a FastAPI service for text segment processing through various
stages including embedding, clustering, and dimensionality reduction (planarization).
The service allows for:
- Individual processing of segments through specific stages
- Pipeline processing that combines multiple stages efficiently
- Customizable function selection for each processing stage
- Flexible output field naming
"""
import asyncio
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel, Field
# Application object; the route decorators further down register handlers on it.
app = FastAPI(
    title="Vector Processing Pipeline",
    description="Service for text segment processing through embedding, clustering, and planarization",
)

# Generic type variable (not referenced elsewhere in this section of the file).
T = TypeVar("T")

# Mutable state passed from stage to stage: the request settings plus the
# intermediate results, all keyed by name.
Context = Dict[str, Any]
class ProcessingStage(str, Enum):
    """Names of the stages a pipeline request can ask for.

    The ``str`` mixin makes every member compare equal to its plain string
    value (``ProcessingStage.EMBEDDER == "embedder"``).
    """

    EMBEDDER = "embedder"
    CLUSTERER = "clusterer"
    PLANARIZER = "planarizer"
class PipelineRequest(BaseModel):
    """Payload accepted by the pipeline: what to process, which stage results
    to return, which registered function to use per stage, and how to name the
    output keys.
    """

    # Text to run through the pipeline.
    segment: str
    # Stages whose results should appear in the response.
    output_fields: List[ProcessingStage]

    # Registry keys selecting the function used for each stage.
    embedder: str = "default"
    clusterer: str = "default"
    planarizer: str = "default"

    # Response keys under which the cluster and x/y results are returned.
    cluster_field: str = "clusters"
    x_field: str = "x"
    y_field: str = "y"
# Type definitions for processing functions.
# The registered functions are coroutine functions and every call site awaits
# their result, so each return type is wrapped in Awaitable.
EmbedderFunc = Callable[[str], Awaitable[List[float]]]
ClustererFunc = Callable[[List[float]], Awaitable[List[int]]]
PlanarizerFunc = Callable[[List[float]], Awaitable[List[Tuple[float, float]]]]

# Function registry stores.
# A store maps a registry key (e.g. "default") to a processing function.
Store = Mapping[str, Callable]
# The mall maps a store name ("embedders", "clusterers", "planarizers") to its store.
Mall = Dict[str, Store]
async def simple_test_embedder(segment: str) -> List[float]:
    """Return a fixed placeholder embedding for any text segment.

    Stand-in for a real embedding backend: it pauses briefly and then returns
    the same three-element vector without looking at ``segment``.

    >>> import asyncio
    >>> embeddings = asyncio.run(simple_test_embedder("test"))
    >>> len(embeddings) > 0
    True
    """
    simulated_latency = 0.1  # Simulate network call
    await asyncio.sleep(simulated_latency)
    placeholder_vector = [0.1, 0.2, 0.3]  # Simplified example
    return placeholder_vector
async def simple_test_clusterer(embeddings: List[float]) -> List[int]:
    """Return fixed placeholder cluster labels.

    Stand-in for a real clustering backend: it pauses briefly and then returns
    the same three labels without looking at ``embeddings``.
    """
    simulated_latency = 0.1  # Simulate network call
    await asyncio.sleep(simulated_latency)
    placeholder_labels = [0, 1, 0]  # Simplified example
    return placeholder_labels
async def simple_test_planarizer(embeddings: List[float]) -> List[Tuple[float, float]]:
    """Return fixed placeholder 2D points.

    Stand-in for a real dimensionality-reduction backend: it pauses briefly and
    then returns the same three ``(x, y)`` pairs without looking at
    ``embeddings``.
    """
    simulated_latency = 0.1  # Simulate network call
    await asyncio.sleep(simulated_latency)
    placeholder_points = [(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)]  # Simplified example
    return placeholder_points
def get_mall() -> Mall:
    """Build the registry of processing functions.

    Returns:
        A new dictionary on every call with one store per stage
        ("embedders", "clusterers", "planarizers"); each store maps a
        registry key to a processing function. Only the "default" key is
        registered in each store.
    """
    embedders = {"default": simple_test_embedder}
    clusterers = {"default": simple_test_clusterer}
    planarizers = {"default": simple_test_planarizer}
    return {
        "embedders": embedders,
        "clusterers": clusterers,
        "planarizers": planarizers,
    }
def handle_lookup_errors(status_code: int = 404):
    """Build a decorator that turns ``KeyError`` into ``HTTPException``.

    Args:
        status_code: HTTP status code to put on the raised ``HTTPException``.

    Returns:
        A decorator for coroutine functions. The wrapped coroutine behaves
        like the original, except that a ``KeyError`` escaping from it is
        re-raised as ``HTTPException(status_code, "Resource not found: ...")``
        with the original ``KeyError`` attached as the cause. Every other
        exception propagates unchanged.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except KeyError as e:
                # str() of a single-argument KeyError is the repr of that
                # argument, which would wrap the message in an extra pair of
                # quotes; use the raw argument instead.
                reason = e.args[0] if len(e.args) == 1 else e
                message = f"Resource not found: {reason}"
                raise HTTPException(status_code=status_code, detail=message) from e
        return wrapper
    return decorator
async def _get_function_from_store(store: str, key: str, mall: Mall) -> Callable:
    """Look up a processing function in the mall.

    Args:
        store: Name of the store ("embedders", "clusterers", "planarizers").
        key: Key of the function inside that store.
        mall: The function registry mall.

    Returns:
        The function registered under ``key`` in ``mall[store]``.

    Raises:
        KeyError: If ``store`` is not in the mall, or ``key`` is not in the
            store. The message names the missing item.
    """
    if store in mall:
        registry = mall[store]
        if key in registry:
            return registry[key]
        raise KeyError(f"Key '{key}' not found in store '{store}'")
    raise KeyError(f"Store '{store}' not found in mall")
@handle_lookup_errors(status_code=404)
async def process_stage(stage: ProcessingStage, context: Context, mall: Mall) -> Context:
    """Run one pipeline stage and record its result in the context.

    The context is updated in place and also returned:

    - EMBEDDER reads ``context["segment"]`` and writes ``context["embeddings"]``.
    - CLUSTERER reads ``context["embeddings"]`` and writes
      ``context["cluster_result"]``.
    - PLANARIZER reads ``context["embeddings"]`` and writes
      ``context["planarization_result"]`` as ``{"x": [...], "y": [...]}``.

    Any other stage value leaves the context untouched.

    Args:
        stage: The processing stage to execute.
        context: The current pipeline context; must hold the stage's input and
            the registry key of the function to use ("embedder", "clusterer"
            or "planarizer").
        mall: The function registry mall.

    Returns:
        The same context object, with the stage result added.

    Raises:
        HTTPException: 400 if the stage's input is missing from the context;
            404 (via the decorator) if a KeyError is raised, e.g. for an
            unknown store or registry key.
    """
    if stage == ProcessingStage.EMBEDDER:
        if "segment" not in context:
            raise HTTPException(status_code=400, detail="Segment required for embedding")
        embed = await _get_function_from_store("embedders", context["embedder"], mall)
        context["embeddings"] = await embed(context["segment"])
        return context

    if stage == ProcessingStage.CLUSTERER:
        if "embeddings" not in context:
            raise HTTPException(status_code=400, detail="Embeddings required for clustering")
        cluster = await _get_function_from_store("clusterers", context["clusterer"], mall)
        context["cluster_result"] = await cluster(context["embeddings"])
        return context

    if stage == ProcessingStage.PLANARIZER:
        if "embeddings" not in context:
            raise HTTPException(status_code=400, detail="Embeddings required for planarization")
        project = await _get_function_from_store("planarizers", context["planarizer"], mall)
        points = await project(context["embeddings"])
        # Split tuples into separate x and y lists for easier client consumption
        if points:
            xs, ys = zip(*points)
        else:
            xs, ys = [], []
        context["planarization_result"] = {"x": list(xs), "y": list(ys)}

    return context
@handle_lookup_errors(status_code=404)
async def process_pipeline(request: PipelineRequest, mall: Optional[Mall] = None) -> Dict[str, Any]:
    """Run every stage needed for the requested outputs and collect the results.

    The embedder always runs first whenever any stage is requested, because
    clustering and planarization both consume the embeddings. This holds
    regardless of where (or whether) the caller listed the embedder in
    ``output_fields``. Each stage runs at most once even if it is listed
    several times.

    Args:
        request: Pipeline request containing the segment, the requested
            output fields, the function selectors and the output key names.
        mall: Function registry mall (fetched with ``get_mall()`` if not
            provided).

    Returns:
        Dictionary holding only the requested outputs: ``"embeddings"`` for the
        embedder, ``request.cluster_field`` for the clusterer, and
        ``request.x_field`` / ``request.y_field`` for the planarizer.

    Raises:
        HTTPException: For stage errors; a KeyError raised during processing
            is converted to a 404 by the decorator.
    """
    if mall is None:
        mall = get_mall()

    context: Context = {
        "segment": request.segment,
        "embedder": request.embedder,
        "clusterer": request.clusterer,
        "planarizer": request.planarizer,
        "cluster_field": request.cluster_field,
        "x_field": request.x_field,
        "y_field": request.y_field,
    }

    # Drop duplicates (keeping first-seen order) so no stage is computed twice.
    requested = list(dict.fromkeys(request.output_fields))
    downstream = [stage for stage in requested if stage != ProcessingStage.EMBEDDER]

    # Inserting the embedder only when it is absent is not enough: a request
    # such as [CLUSTERER, EMBEDDER] would run the clusterer before any
    # embeddings exist. Always move it to the front instead.
    needs_embeddings = bool(downstream) or ProcessingStage.EMBEDDER in requested
    ordered_stages = [ProcessingStage.EMBEDDER, *downstream] if needs_embeddings else []

    # Process each stage in sequence
    for stage in ordered_stages:
        context = await process_stage(stage, context, mall)

    # Only return the requested outputs
    result: Dict[str, Any] = {}
    if ProcessingStage.EMBEDDER in requested:
        result["embeddings"] = context.get("embeddings")
    if ProcessingStage.CLUSTERER in requested:
        result[request.cluster_field] = context.get("cluster_result")
    if ProcessingStage.PLANARIZER in requested:
        planarization = context.get("planarization_result", {})
        result[request.x_field] = planarization.get("x")
        result[request.y_field] = planarization.get("y")
    return result
# Individual service endpoints
@app.post("/embeddings", response_model=Dict[str, List[float]])
@handle_lookup_errors()
async def embedding_endpoint(segment: str, embedder: str = "default"):
    """Embed a text segment with the selected embedder.

    Looks up ``embedder`` in the "embedders" store and returns its result under
    the ``"embeddings"`` key. An unknown embedder key surfaces as a 404 through
    the lookup-error decorator.
    """
    registry = get_mall()
    embed = await _get_function_from_store("embedders", embedder, registry)
    vector = await embed(segment)
    return {"embeddings": vector}
@app.post("/clusters")
@handle_lookup_errors()
async def clusters_endpoint(
    embeddings: List[float],
    clusterer: str = "default",
    field_name: str = "clusters",
):
    """Cluster embeddings with the selected clusterer.

    Looks up ``clusterer`` in the "clusterers" store and returns its result
    under the key given by ``field_name``. An unknown clusterer key surfaces as
    a 404 through the lookup-error decorator.
    """
    registry = get_mall()
    cluster = await _get_function_from_store("clusterers", clusterer, registry)
    labels = await cluster(embeddings)
    return {field_name: labels}
@app.post("/planarize")
@handle_lookup_errors()
async def planarize_endpoint(
    embeddings: List[float],
    planarizer: str = "default",
    x_field: str = "x",
    y_field: str = "y",
):
    """Project embeddings to 2D with the selected planarizer.

    Looks up ``planarizer`` in the "planarizers" store and returns the
    projected points as two parallel lists, keyed by ``x_field`` and
    ``y_field``. An unknown planarizer key surfaces as a 404 through the
    lookup-error decorator.
    """
    registry = get_mall()
    project = await _get_function_from_store("planarizers", planarizer, registry)
    points = await project(embeddings)
    # Split tuples into separate x and y lists
    if points:
        xs, ys = zip(*points)
    else:
        xs, ys = [], []
    return {x_field: list(xs), y_field: list(ys)}
# Combined pipeline endpoint
@app.post("/pipeline")
async def pipeline_endpoint(request: PipelineRequest):
    """Process text through a customizable pipeline.

    The caller chooses which results are included in the response.
    Intermediate results (like embeddings) are computed when a later stage
    needs them, but are only returned if specifically requested.

    ``HTTPException`` raised during processing is passed through untouched; any
    other exception is reported as a 500 whose detail carries the error text.
    """
    try:
        registry = get_mall()
        outcome = await process_pipeline(request, registry)
    except HTTPException:
        # Already an HTTP error: let it through as-is.
        raise
    except Exception as exc:
        # Anything else becomes a generic server error.
        raise HTTPException(status_code=500, detail=f"Pipeline processing error: {exc!s}")
    return outcome
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the module is run as a script.
    import uvicorn

    # Serve the app on all network interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment