Last active
August 15, 2024 04:24
-
-
Save epireve/6c48ddc50fae29a9c171eafcdb510b30 to your computer and use it in GitHub Desktop.
The Fault-Tolerant LLM Gateway is a robust and reliable architecture designed to ensure high availability and redundancy in LLM deployments. This project aims to develop a gateway that can intelligently route requests to multiple LLM providers, detect failures, and automatically fall back to alternative models or providers in case of errors or unavailability.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
class LLMRouterException(Exception):
    """Base class for every error raised by the LLM routing gateway."""


class ProviderFailedException(LLMRouterException):
    """A single provider failed to produce a response."""


class AllProvidersFailedException(LLMRouterException):
    """Every provider in the fallback chain failed or was unhealthy."""
def setup_logging():
    """Configure the root logger to write INFO-and-above records to llm_router.log."""
    log_settings = {
        "level": logging.INFO,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        "filename": "llm_router.log",
    }
    logging.basicConfig(**log_settings)
def log_error(error_message: str):
    """Record *error_message* at ERROR level on the root logger."""
    logging.error(error_message)
def log_info(info_message: str):
    """Record *info_message* at INFO level on the root logger."""
    logging.info(info_message)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List

from error_handling import AllProvidersFailedException, log_error
from llm_provider import LLMProvider, OpenAIProvider, GroqProvider, DeepseekProvider
from model_config import TaskConfig, ModelConfig
class FallbackManager:
    """Routes prompts through an ordered chain of LLM providers.

    Providers are tried in priority order (primary → quaternary); the first
    provider that reports healthy and returns a response wins.
    """

    def __init__(self, task_config: TaskConfig):
        self.task_config = task_config
        self.providers = self._initialize_providers()

    def _initialize_providers(self) -> List[LLMProvider]:
        """Build provider instances in fallback priority order.

        Raises:
            ValueError: if a config names a provider with no implementation
                (previously such entries were silently skipped, shrinking
                the fallback chain without warning).
        """
        # Dispatch table replaces the if/elif chain over provider names.
        provider_classes = {
            "OpenAI": OpenAIProvider,
            "Groq": GroqProvider,
            "Deepseek": DeepseekProvider,
        }
        ordered_configs = [
            self.task_config.primary,
            self.task_config.secondary,
            self.task_config.tertiary,
            self.task_config.quaternary,
        ]
        providers: List[LLMProvider] = []
        for config in ordered_configs:
            provider_cls = provider_classes.get(config.provider)
            if provider_cls is None:
                raise ValueError(f"Unknown provider: {config.provider!r}")
            providers.append(provider_cls(config.api_key, config.model_name))
        return providers

    async def execute_with_fallback(self, prompt: str) -> str:
        """Return a response from the first healthy provider.

        Raises:
            AllProvidersFailedException: when every provider is unhealthy or
                errors while generating.
        """
        for provider in self.providers:
            try:
                if await provider.is_healthy():
                    return await provider.generate_response(prompt)
            except Exception as e:
                # Any single provider failure just moves us down the chain;
                # route it through the project logger instead of print().
                log_error(f"Provider {provider.__class__.__name__} failed: {str(e)}")
        # Raise the router-specific exception so callers (see main) can catch
        # AllProvidersFailedException — a bare Exception would slip past them.
        raise AllProvidersFailedException("All providers failed to generate a response.")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from abc import ABC, abstractmethod | |
| import aiohttp | |
| import asyncio | |
| import json | |
class LLMProvider(ABC):
    """Abstract interface every concrete LLM backend must implement."""

    @abstractmethod
    async def generate_response(self, prompt: str) -> str:
        """Return the model's completion text for *prompt*."""

    @abstractmethod
    async def is_healthy(self) -> bool:
        """Return True when the backend is reachable and able to serve."""
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI chat-completions REST API."""

    def __init__(self, api_key: str, model_name: str):
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = "https://api.openai.com/v1"

    async def generate_response(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Raises:
            Exception: on any non-200 HTTP status from the API.
        """
        async with aiohttp.ClientSession() as session:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            data = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": prompt}]
            }
            async with session.post(f"{self.base_url}/chat/completions", headers=headers, json=data) as response:
                if response.status == 200:
                    result = await response.json()
                    return result['choices'][0]['message']['content']
                else:
                    raise Exception(f"OpenAI API error: {response.status}")

    async def is_healthy(self) -> bool:
        """Probe the /models endpoint; True iff it answers 200.

        Never raises: any network/client error is reported as unhealthy.
        """
        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Authorization": f"Bearer {self.api_key}"}
                async with session.get(f"{self.base_url}/models", headers=headers) as response:
                    return response.status == 200
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit,
            # KeyboardInterrupt and asyncio task cancellation.
            return False
class GroqProvider(LLMProvider):
    """LLM provider backed by Groq's OpenAI-compatible chat API."""

    def __init__(self, api_key: str, model_name: str):
        self.api_key = api_key
        self.model_name = model_name
        # Groq serves its OpenAI-compatible API under /openai/v1;
        # the previous "/v1" base yields 404s for /chat/completions.
        self.base_url = "https://api.groq.com/openai/v1"

    async def generate_response(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Raises:
            Exception: on any non-200 HTTP status from the API.
        """
        async with aiohttp.ClientSession() as session:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            data = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": prompt}]
            }
            async with session.post(f"{self.base_url}/chat/completions", headers=headers, json=data) as response:
                if response.status == 200:
                    result = await response.json()
                    return result['choices'][0]['message']['content']
                else:
                    raise Exception(f"Groq API error: {response.status}")

    async def is_healthy(self) -> bool:
        """Probe the /models endpoint; True iff it answers 200.

        Never raises: any network/client error is reported as unhealthy.
        """
        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Authorization": f"Bearer {self.api_key}"}
                async with session.get(f"{self.base_url}/models", headers=headers) as response:
                    return response.status == 200
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit,
            # KeyboardInterrupt and asyncio task cancellation.
            return False
class DeepseekProvider(LLMProvider):
    """Mock LLM provider standing in for a Deepseek backend.

    NOTE(review): endpoint shape (/chat with a raw "prompt" field) is marked
    as a mock in the original; the real Deepseek API is OpenAI-compatible —
    confirm before pointing this at production.
    """

    def __init__(self, api_key: str, model_name: str):
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = "https://api.deepseek.com/v1"  # Mock URL

    async def generate_response(self, prompt: str) -> str:
        """Post *prompt* to the mock /chat endpoint and return its 'response' field.

        Raises:
            Exception: on any non-200 HTTP status from the API.
        """
        # Mock implementation
        async with aiohttp.ClientSession() as session:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            data = {
                "model": self.model_name,
                "prompt": prompt
            }
            async with session.post(f"{self.base_url}/chat", headers=headers, json=data) as response:
                if response.status == 200:
                    result = await response.json()
                    return result['response']
                else:
                    raise Exception(f"Deepseek API error: {response.status}")

    async def is_healthy(self) -> bool:
        """Probe the mock /status endpoint; True iff it answers 200.

        Never raises: any network/client error is reported as unhealthy.
        """
        # Mock implementation
        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Authorization": f"Bearer {self.api_key}"}
                async with session.get(f"{self.base_url}/status", headers=headers) as response:
                    return response.status == 200
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit,
            # KeyboardInterrupt and asyncio task cancellation.
            return False
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from model_config import ConfigManager, TaskConfig, ModelConfig | |
| from fallback_manager import FallbackManager | |
| from error_handling import setup_logging, log_error, log_info, AllProvidersFailedException | |
async def main():
    """Demo entry point: register one task's fallback chain and run a prompt."""
    setup_logging()

    # Register a four-deep fallback chain for the demo task.
    manager = ConfigManager()
    manager.add_task(TaskConfig(
        task_name="chat_with_knowledge",
        primary=ModelConfig("OpenAI", "GPT-4o", "openai_api_key_1", "account_1"),
        secondary=ModelConfig("Groq", "llama2-70b-4096", "groq_api_key", "account_1"),
        tertiary=ModelConfig("OpenAI", "GPT-4o", "openai_api_key_2", "account_2"),
        quaternary=ModelConfig("Groq", "mixtral-8x7b-32768", "groq_api_key", "account_1")
    ))

    # Look the task back up and build the router around its configuration.
    router = FallbackManager(manager.get_task_config("chat_with_knowledge"))

    # Exercise the fallback mechanism with a single test prompt.
    prompt = "What is the capital of France?"
    try:
        response = await router.execute_with_fallback(prompt)
        log_info(f"Response: {response}")
    except AllProvidersFailedException as e:
        log_error(f"All providers failed: {str(e)}")


if __name__ == "__main__":
    asyncio.run(main())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| from typing import List | |
@dataclass
class ModelConfig:
    """Credentials and model selection for one provider endpoint."""
    provider: str    # Provider label dispatched on by FallbackManager ("OpenAI", "Groq", "Deepseek")
    model_name: str  # Model identifier passed through to the provider's API
    api_key: str     # Bearer token for this provider account
    account_id: str  # Presumably distinguishes multiple accounts per provider — not read by visible code
@dataclass
class TaskConfig:
    """A named task together with its four-deep provider fallback chain."""
    task_name: str           # Lookup key used by ConfigManager.get_task_config
    primary: ModelConfig     # Tried first by FallbackManager
    secondary: ModelConfig   # First fallback
    tertiary: ModelConfig    # Second fallback
    quaternary: ModelConfig  # Last resort
class ConfigManager:
    """In-memory registry of per-task fallback configurations."""

    def __init__(self):
        # Registered TaskConfig objects, kept in insertion order.
        self.tasks = []

    def add_task(self, task_config: "TaskConfig"):
        """Register *task_config* for later lookup by its task name."""
        self.tasks.append(task_config)

    def get_task_config(self, task_name: str) -> "TaskConfig":
        """Return the registered config matching *task_name*.

        Raises:
            ValueError: when no registered task carries that name.
        """
        found = next((t for t in self.tasks if t.task_name == task_name), None)
        if found is None:
            raise ValueError(f"Task '{task_name}' not found in configuration")
        return found
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment