| """ | |
| Vector Processing Pipeline Service | |
| This module provides a FastAPI service for text segment processing through various | |
| stages including embedding, clustering, and dimensionality reduction (planarization). | |
| The service allows for: | |
| - Individual processing of segments through specific stages | |
| - Pipeline processing that combines multiple stages efficiently | |
| - Customizable function selection for each processing stage | |
| - Flexible output field naming | |
| """ | |
| from fastapi import FastAPI, Depends, HTTPException | |
| from pydantic import BaseModel, Field | |
| from typing import List, Dict, Any, Optional, Callable, Mapping, Tuple, Union, TypeVar | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from functools import wraps | |
| import asyncio | |
| app = FastAPI(title="Vector Processing Pipeline", | |
| description="Service for text segment processing through embedding, clustering, and planarization") | |
| T = TypeVar('T') | |
| Context = Dict[str, Any] | |
| class ProcessingStage(str, Enum): | |
| """Processing stages available in the pipeline""" | |
| EMBEDDER = "embedder" | |
| CLUSTERER = "clusterer" | |
| PLANARIZER = "planarizer" | |
| class PipelineRequest(BaseModel): | |
| """Request model for pipeline processing""" | |
| segment: str | |
| output_fields: List[ProcessingStage] | |
| embedder: str = "default" | |
| clusterer: str = "default" | |
| planarizer: str = "default" | |
| cluster_field: str = "clusters" | |
| x_field: str = "x" | |
| y_field: str = "y" | |
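
# Example request body for the /pipeline endpoint defined further below
# (a sketch; the values are illustrative and the field names follow the
# defaults declared in PipelineRequest):
#
#     {
#         "segment": "some text to process",
#         "output_fields": ["embedder", "clusterer"],
#         "embedder": "default",
#         "cluster_field": "clusters"
#     }
#
# With these output_fields, the response contains an "embeddings" list and a
# "clusters" list (the latter named by cluster_field).
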
# Type definitions for processing functions
EmbedderFunc = Callable[[str], List[float]]
ClustererFunc = Callable[[List[float]], List[int]]
PlanarizerFunc = Callable[[List[float]], List[Tuple[float, float]]]

# Function registry stores
Store = Mapping[str, Callable]
Mall = Dict[str, Store]


async def simple_test_embedder(segment: str) -> List[float]:
    """Generate basic vector embeddings from text segment

    >>> import asyncio
    >>> embeddings = asyncio.run(simple_test_embedder("test"))
    >>> len(embeddings) > 0
    True
    """
    await asyncio.sleep(0.1)  # Simulate network call
    return [0.1, 0.2, 0.3]  # Simplified example


async def simple_test_clusterer(embeddings: List[float]) -> List[int]:
    """Generate basic clusters from embeddings"""
    await asyncio.sleep(0.1)  # Simulate network call
    return [0, 1, 0]  # Simplified example


async def simple_test_planarizer(embeddings: List[float]) -> List[Tuple[float, float]]:
    """Generate basic 2D projections from embeddings"""
    await asyncio.sleep(0.1)  # Simulate network call
    return [(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)]  # Simplified example


def get_mall() -> Mall:
    """Get the registry mall containing all function stores

    Returns:
        A dictionary of stores, each containing registered processing functions
    """
    return {
        "embedders": {
            "default": simple_test_embedder,
        },
        "clusterers": {
            "default": simple_test_clusterer,
        },
        "planarizers": {
            "default": simple_test_planarizer,
        },
    }

def handle_lookup_errors(status_code: int = 404):
    """Decorator to handle KeyError exceptions and convert to HTTP exceptions

    Args:
        status_code: HTTP status code to use for the exception

    Returns:
        Decorated function that handles KeyErrors
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except KeyError as e:
                message = f"Resource not found: {str(e)}"
                raise HTTPException(status_code=status_code, detail=message)
        return wrapper
    return decorator


async def _get_function_from_store(store: str, key: str, mall: Mall) -> Callable:
    """Get a processing function from the mall

    Args:
        store: Name of the store ("embedders", "clusterers", "planarizers")
        key: Key to look up in the store
        mall: The function registry mall

    Returns:
        The requested function

    Raises:
        KeyError: If the store or key is not found
    """
    if store not in mall:
        raise KeyError(f"Store '{store}' not found in mall")
    store_obj = mall[store]
    if key not in store_obj:
        raise KeyError(f"Key '{key}' not found in store '{store}'")
    return store_obj[key]


@handle_lookup_errors(status_code=404)
async def process_stage(stage: ProcessingStage, context: Context, mall: Mall) -> Context:
    """Process a single pipeline stage

    Args:
        stage: The processing stage to execute
        context: The current pipeline context
        mall: The function registry mall

    Returns:
        Updated context with stage results

    Raises:
        HTTPException: If required inputs are missing or lookup fails
    """
    if stage == ProcessingStage.EMBEDDER:
        if "segment" not in context:
            raise HTTPException(status_code=400, detail="Segment required for embedding")
        embedder = await _get_function_from_store("embedders", context["embedder"], mall)
        context["embeddings"] = await embedder(context["segment"])
    elif stage == ProcessingStage.CLUSTERER:
        if "embeddings" not in context:
            raise HTTPException(status_code=400, detail="Embeddings required for clustering")
        clusterer = await _get_function_from_store("clusterers", context["clusterer"], mall)
        context["cluster_result"] = await clusterer(context["embeddings"])
    elif stage == ProcessingStage.PLANARIZER:
        if "embeddings" not in context:
            raise HTTPException(status_code=400, detail="Embeddings required for planarization")
        planarizer = await _get_function_from_store("planarizers", context["planarizer"], mall)
        projections = await planarizer(context["embeddings"])
        # Split tuples into separate x and y lists for easier client consumption
        x_values, y_values = zip(*projections) if projections else ([], [])
        context["planarization_result"] = {
            "x": list(x_values),
            "y": list(y_values),
        }
    return context

@handle_lookup_errors(status_code=404)
async def process_pipeline(request: PipelineRequest, mall: Optional[Mall] = None) -> Dict[str, Any]:
    """Process all requested stages for the pipeline

    Args:
        request: Pipeline request containing segment and output_fields
        mall: Function registry mall (will be fetched if not provided)

    Returns:
        Dictionary containing results from all requested output fields

    Raises:
        HTTPException: For any processing errors
    """
    if mall is None:
        mall = get_mall()
    context = {
        "segment": request.segment,
        "embedder": request.embedder,
        "clusterer": request.clusterer,
        "planarizer": request.planarizer,
        "cluster_field": request.cluster_field,
        "x_field": request.x_field,
        "y_field": request.y_field,
    }
    # Ensure the embedding stage always runs first: clustering and planarization
    # both consume embeddings, so EMBEDDER is moved to (or inserted at) the front
    # of the stage order even when it is not an explicitly requested output.
    ordered_stages = [s for s in request.output_fields if s != ProcessingStage.EMBEDDER]
    if ordered_stages or ProcessingStage.EMBEDDER in request.output_fields:
        ordered_stages.insert(0, ProcessingStage.EMBEDDER)
    # Process each stage in sequence
    for stage in ordered_stages:
        context = await process_stage(stage, context, mall)
    # Only return the requested outputs
    result = {}
    if ProcessingStage.EMBEDDER in request.output_fields:
        result["embeddings"] = context.get("embeddings")
    if ProcessingStage.CLUSTERER in request.output_fields:
        result[request.cluster_field] = context.get("cluster_result")
    if ProcessingStage.PLANARIZER in request.output_fields:
        planarization = context.get("planarization_result", {})
        result[request.x_field] = planarization.get("x")
        result[request.y_field] = planarization.get("y")
    return result
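
# A direct (non-HTTP) usage sketch of process_pipeline with a custom mall.
# "uppercase_embedder" and "_example_custom_mall_run" are hypothetical names
# introduced here for illustration only; they are not used by the endpoints.
async def _example_custom_mall_run() -> Dict[str, Any]:
    async def uppercase_embedder(segment: str) -> List[float]:
        # Toy embedding: one float per character of the upper-cased segment.
        return [float(ord(ch)) for ch in segment.upper()]

    custom_mall: Mall = {
        "embedders": {"default": uppercase_embedder},
        "clusterers": {"default": simple_test_clusterer},
        "planarizers": {"default": simple_test_planarizer},
    }
    request = PipelineRequest(
        segment="hello",
        output_fields=[ProcessingStage.EMBEDDER, ProcessingStage.CLUSTERER],
    )
    # Returns {"embeddings": [...], "clusters": [...]} with the default field names.
    return await process_pipeline(request, custom_mall)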

# Individual service endpoints
@app.post("/embeddings", response_model=Dict[str, List[float]])
@handle_lookup_errors()
async def embedding_endpoint(segment: str, embedder: str = "default"):
    """Generate embeddings from text segment"""
    mall = get_mall()
    embedder_func = await _get_function_from_store("embedders", embedder, mall)
    embeddings = await embedder_func(segment)
    return {"embeddings": embeddings}


@app.post("/clusters")
@handle_lookup_errors()
async def clusters_endpoint(
    embeddings: List[float],
    clusterer: str = "default",
    field_name: str = "clusters",
):
    """Generate clusters from embeddings"""
    mall = get_mall()
    clusterer_func = await _get_function_from_store("clusterers", clusterer, mall)
    clusters = await clusterer_func(embeddings)
    return {field_name: clusters}


@app.post("/planarize")
@handle_lookup_errors()
async def planarize_endpoint(
    embeddings: List[float],
    planarizer: str = "default",
    x_field: str = "x",
    y_field: str = "y",
):
    """Generate 2D projections from embeddings"""
    mall = get_mall()
    planarizer_func = await _get_function_from_store("planarizers", planarizer, mall)
    projections = await planarizer_func(embeddings)
    # Split tuples into separate x and y lists
    x_values, y_values = zip(*projections) if projections else ([], [])
    return {
        x_field: list(x_values),
        y_field: list(y_values),
    }
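
# A sketch of exercising the individual endpoints above with FastAPI's
# TestClient (assumes the test client's dependencies are installed;
# "_example_individual_calls" is a hypothetical helper, not part of the
# original gist). "segment" and the function-selection arguments are query
# parameters, while /clusters and /planarize take the embeddings list as the
# JSON body.
def _example_individual_calls():
    from fastapi.testclient import TestClient

    client = TestClient(app)
    embeddings = client.post("/embeddings", params={"segment": "hello"}).json()["embeddings"]
    clusters = client.post("/clusters", json=embeddings).json()["clusters"]
    xy = client.post("/planarize", json=embeddings).json()
    return clusters, xy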

# Combined pipeline endpoint
@app.post("/pipeline")
async def pipeline_endpoint(request: PipelineRequest):
    """Process text through a customizable pipeline

    The user can specify which fields they want included in the results.
    Intermediate results (like embeddings) are computed but not returned
    unless specifically requested.
    """
    try:
        mall = get_mall()
        result = await process_pipeline(request, mall)
        return result
    except HTTPException:
        # Re-raise HTTP exceptions
        raise
    except Exception as e:
        # Convert any other exceptions to HTTP exceptions
        raise HTTPException(status_code=500, detail=f"Pipeline processing error: {str(e)}")
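
# The same flow through the combined /pipeline endpoint, in one request
# (another illustrative sketch using TestClient; "_example_pipeline_call" is
# a hypothetical helper and the payload values are only examples):
def _example_pipeline_call():
    from fastapi.testclient import TestClient

    client = TestClient(app)
    response = client.post(
        "/pipeline",
        json={
            "segment": "some text to process",
            "output_fields": ["embedder", "planarizer"],
        },
    )
    # With the default test functions, this yields {"embeddings": [...],
    # "x": [...], "y": [...]}, the last two named by x_field / y_field.
    return response.json()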

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)