Skip to content

Instantly share code, notes, and snippets.

@epireve
Last active August 15, 2024 04:24
Show Gist options
  • Select an option

  • Save epireve/6c48ddc50fae29a9c171eafcdb510b30 to your computer and use it in GitHub Desktop.

Select an option

Save epireve/6c48ddc50fae29a9c171eafcdb510b30 to your computer and use it in GitHub Desktop.
The Fault-Tolerant LLM Gateway is a robust and reliable architecture designed to ensure high availability and redundancy in LLM deployments. This project aims to develop a gateway that can intelligently route requests to multiple LLM providers, detect failures, and automatically fall back to alternative models or providers in case of errors or unavailability.
import logging
class LLMRouterException(Exception):
    """Base class for all errors raised by the LLM router."""
    pass
class ProviderFailedException(LLMRouterException):
    """A single provider failed to serve a request (others may still succeed)."""
    pass
class AllProvidersFailedException(LLMRouterException):
    """Every configured provider in the fallback chain failed for a request."""
    pass
def setup_logging() -> None:
    """Configure root logging: INFO level, timestamped records, file 'llm_router.log'."""
    record_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(
        filename='llm_router.log',
        format=record_format,
        level=logging.INFO,
    )
def log_error(error_message: str) -> None:
    """Record *error_message* at ERROR level on the root logger."""
    logging.error(error_message)
def log_info(info_message: str) -> None:
    """Record *info_message* at INFO level on the root logger."""
    logging.info(info_message)
import logging
from typing import List

from error_handling import AllProvidersFailedException
from llm_provider import LLMProvider, OpenAIProvider, GroqProvider, DeepseekProvider
from model_config import TaskConfig, ModelConfig
class FallbackManager:
    """Builds a prioritized chain of LLM providers for one task and routes each
    prompt to the first healthy provider, falling down the chain on failure."""

    def __init__(self, task_config: TaskConfig):
        self.task_config = task_config
        # Ordered primary -> quaternary; tried in exactly this order.
        self.providers = self._initialize_providers()

    def _initialize_providers(self) -> List[LLMProvider]:
        """Instantiate one provider client per configured tier, in priority order.

        Tiers with an unrecognized `provider` string are skipped (as before),
        but now with a logged warning instead of silence.
        """
        # Mapping lives here (not at class level) so the provider classes are
        # only resolved when providers are actually built.
        provider_classes = {
            "OpenAI": OpenAIProvider,
            "Groq": GroqProvider,
            "Deepseek": DeepseekProvider,
        }
        tiers = (
            self.task_config.primary,
            self.task_config.secondary,
            self.task_config.tertiary,
            self.task_config.quaternary,
        )
        providers: List[LLMProvider] = []
        for config in tiers:
            cls = provider_classes.get(config.provider)
            if cls is None:
                logging.warning("Unknown provider %r in task %r; tier skipped",
                                config.provider, self.task_config.task_name)
                continue
            providers.append(cls(config.api_key, config.model_name))
        return providers

    async def execute_with_fallback(self, prompt: str) -> str:
        """Return the first successful response for *prompt*.

        Unhealthy providers are skipped; providers that raise are logged and
        the next tier is tried.

        Raises:
            AllProvidersFailedException: when every tier fails. This is a
                subclass of Exception, so callers with broad handlers still
                work, while main()'s specific handler now actually fires.
        """
        for provider in self.providers:
            name = provider.__class__.__name__
            try:
                if not await provider.is_healthy():
                    logging.warning("Provider %s unhealthy; trying next tier", name)
                    continue
                return await provider.generate_response(prompt)
            except Exception as e:
                # Log (not print) so failures land in llm_router.log.
                logging.error("Provider %s failed: %s", name, e)
        raise AllProvidersFailedException("All providers failed to generate a response.")
from abc import ABC, abstractmethod
import aiohttp
import asyncio
import json
class LLMProvider(ABC):
    """Abstract interface every concrete LLM backend must implement."""

    @abstractmethod
    async def generate_response(self, prompt: str) -> str:
        """Return the model's text reply to *prompt*."""

    @abstractmethod
    async def is_healthy(self) -> bool:
        """Return True when the backing service is reachable and ready."""
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI chat-completions HTTP API."""

    def __init__(self, api_key: str, model_name: str):
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = "https://api.openai.com/v1"

    async def generate_response(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the model's reply.

        Raises:
            Exception: with the HTTP status code on any non-200 response.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.base_url}/chat/completions",
                                    headers=headers, json=data) as response:
                if response.status != 200:
                    raise Exception(f"OpenAI API error: {response.status}")
                result = await response.json()
                return result['choices'][0]['message']['content']

    async def is_healthy(self) -> bool:
        """Best-effort health probe: True iff GET /models returns HTTP 200.

        Any network/client error is treated as unhealthy.
        """
        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Authorization": f"Bearer {self.api_key}"}
                async with session.get(f"{self.base_url}/models",
                                       headers=headers) as response:
                    return response.status == 200
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # still propagate instead of being swallowed as "unhealthy".
            return False
class GroqProvider(LLMProvider):
    """LLM provider backed by Groq's OpenAI-compatible chat-completions API."""

    def __init__(self, api_key: str, model_name: str):
        self.api_key = api_key
        self.model_name = model_name
        # Groq's OpenAI-compatible API is served under /openai/v1;
        # the previous "https://api.groq.com/v1" is not a valid endpoint.
        self.base_url = "https://api.groq.com/openai/v1"

    async def generate_response(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the model's reply.

        Raises:
            Exception: with the HTTP status code on any non-200 response.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.base_url}/chat/completions",
                                    headers=headers, json=data) as response:
                if response.status != 200:
                    raise Exception(f"Groq API error: {response.status}")
                result = await response.json()
                return result['choices'][0]['message']['content']

    async def is_healthy(self) -> bool:
        """Best-effort health probe: True iff GET /models returns HTTP 200.

        Any network/client error is treated as unhealthy.
        """
        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Authorization": f"Bearer {self.api_key}"}
                async with session.get(f"{self.base_url}/models",
                                       headers=headers) as response:
                    return response.status == 200
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # still propagate instead of being swallowed as "unhealthy".
            return False
class DeepseekProvider(LLMProvider):
    """Mock LLM provider sketching a Deepseek backend (endpoints are placeholders)."""

    def __init__(self, api_key: str, model_name: str):
        self.api_key = api_key
        self.model_name = model_name
        self.base_url = "https://api.deepseek.com/v1"  # Mock URL

    async def generate_response(self, prompt: str) -> str:
        """POST *prompt* to the mock /chat endpoint and return its 'response' field.

        Raises:
            Exception: with the HTTP status code on any non-200 response.
        """
        # Mock implementation
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model_name,
            "prompt": prompt,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.base_url}/chat",
                                    headers=headers, json=data) as response:
                if response.status != 200:
                    raise Exception(f"Deepseek API error: {response.status}")
                result = await response.json()
                return result['response']

    async def is_healthy(self) -> bool:
        """Best-effort mock health probe: True iff GET /status returns HTTP 200.

        Any network/client error is treated as unhealthy.
        """
        # Mock implementation
        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Authorization": f"Bearer {self.api_key}"}
                async with session.get(f"{self.base_url}/status",
                                       headers=headers) as response:
                    return response.status == 200
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # still propagate instead of being swallowed as "unhealthy".
            return False
import asyncio
from model_config import ConfigManager, TaskConfig, ModelConfig
from fallback_manager import FallbackManager
from error_handling import setup_logging, log_error, log_info, AllProvidersFailedException
async def main() -> None:
    """Demo entry point: register one task with four fallback tiers, then
    exercise the fallback chain with a single prompt and log the outcome."""
    setup_logging()

    # Set up configurations: four ordered tiers for the chat task.
    config_manager = ConfigManager()
    config_manager.add_task(TaskConfig(
        task_name="chat_with_knowledge",
        primary=ModelConfig("OpenAI", "GPT-4o", "openai_api_key_1", "account_1"),
        secondary=ModelConfig("Groq", "llama2-70b-4096", "groq_api_key", "account_1"),
        tertiary=ModelConfig("OpenAI", "GPT-4o", "openai_api_key_2", "account_2"),
        quaternary=ModelConfig("Groq", "mixtral-8x7b-32768", "groq_api_key", "account_1")
    ))

    # Get task configuration and build the fallback chain from it.
    task_config = config_manager.get_task_config("chat_with_knowledge")
    fallback_manager = FallbackManager(task_config)

    # Test the fallback mechanism.
    prompt = "What is the capital of France?"
    try:
        response = await fallback_manager.execute_with_fallback(prompt)
        log_info(f"Response: {response}")
    except AllProvidersFailedException as e:
        log_error(f"All providers failed: {str(e)}")
    except Exception as e:
        # FallbackManager historically raised a bare Exception on total
        # failure, which the handler above never caught — catch it here so
        # the demo logs the failure instead of crashing.
        log_error(f"Unexpected error: {str(e)}")


if __name__ == "__main__":
    asyncio.run(main())
from dataclasses import dataclass
from typing import List
@dataclass
class ModelConfig:
    """Configuration for one model/provider pairing used as a fallback tier."""
    # Provider name; dispatched on by FallbackManager ("OpenAI", "Groq", "Deepseek").
    provider: str
    # Model identifier passed through to the provider's API.
    model_name: str
    # API key for the provider account.
    api_key: str
    # Account label; presumably distinguishes multiple keys per provider — TODO confirm.
    account_id: str
@dataclass
class TaskConfig:
    """A named task with four fallback tiers, tried primary-first."""
    # Unique task name; ConfigManager.get_task_config looks tasks up by it.
    task_name: str
    primary: ModelConfig
    secondary: ModelConfig
    tertiary: ModelConfig
    quaternary: ModelConfig
class ConfigManager:
def __init__(self):
self.tasks: List[TaskConfig] = []
def add_task(self, task_config: TaskConfig):
self.tasks.append(task_config)
def get_task_config(self, task_name: str) -> TaskConfig:
for task in self.tasks:
if task.task_name == task_name:
return task
raise ValueError(f"Task '{task_name}' not found in configuration")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment