Created
January 21, 2026 15:39
-
-
Save haijohn/9b1080a8bb1ff844a941bc1291276893 to your computer and use it in GitHub Desktop.
ttlcache
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import time | |
| import threading | |
| import unittest | |
| from collections import OrderedDict | |
| from typing import Any, List, Optional, Tuple | |
| from unittest.mock import patch | |
| # Third-party dependencies | |
| # pip install fastapi uvicorn httpx starlette | |
| from fastapi import FastAPI, Request | |
| from fastapi.testclient import TestClient | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.concurrency import run_in_threadpool | |
| # ========================================== | |
| # PART 1: Custom TTL Cache Implementation | |
| # ========================================== | |
class SimpleTTLCache:
    """
    A simple thread-safe LRU cache with Time-To-Live (TTL) expiration.

    Entries live in an OrderedDict ordered from least to most recently used.
    Expiration is lazy: an expired entry is only removed when it is looked up
    (or when it is evicted as the least recently used entry).

    Args:
        ttl: Lifetime of an entry in seconds, counted from the last `set`.
        maxsize: Maximum number of stored entries before the least recently
            used one is evicted.
        timer: Zero-argument callable returning the current time in seconds.
            Defaults to `time.monotonic`, which never jumps when the system
            clock is adjusted; tests may inject a fake clock.
    """

    def __init__(self, ttl: float = 60, maxsize: int = 1000, timer=time.monotonic):
        self.ttl = ttl
        self.maxsize = maxsize
        self.timer = timer
        # Store: key -> (value, expire_at_timestamp)
        self.cache: "OrderedDict[str, Tuple[Any, float]]" = OrderedDict()
        self.lock = threading.Lock()

    def get(self, key: str, default: Any = None) -> Optional[Any]:
        """
        Return the live value stored under `key`, or `default` on a miss.

        A hit marks the entry as most recently used; it does not extend the
        entry's lifetime. An expired entry is deleted and reported as a miss.
        """
        with self.lock:
            try:
                value, expire_at = self.cache[key]
            except KeyError:
                return default
            # Lazy expiration: the deadline is only checked upon access.
            if self.timer() > expire_at:
                del self.cache[key]
                return default
            # LRU logic: move the accessed item to the end (most recently used).
            self.cache.move_to_end(key)
            return value

    def set(self, key: str, value: Any):
        """
        Store `value` under `key` with a fresh TTL and mark it most recently used.

        If the insert pushes the cache above `maxsize`, the least recently
        used entry is evicted.
        """
        with self.lock:
            self.cache[key] = (value, self.timer() + self.ttl)
            # Assigning to an existing key keeps its old position, so move it
            # explicitly; for a new key this is a no-op.
            self.cache.move_to_end(key)
            if len(self.cache) > self.maxsize:
                self.cache.popitem(last=False)

    def clear(self):
        """Remove every entry from the cache."""
        with self.lock:
            self.cache.clear()

    def __len__(self):
        """Return the number of stored entries, including not-yet-purged expired ones."""
        with self.lock:
            return len(self.cache)
| # ========================================== | |
| # PART 2: Service Layer (Business Logic) | |
| # ========================================== | |
class GroupService:
    """
    Resolves the groups of a user, caching results for `ttl` seconds.

    Cache misses are serialized per email: concurrent requests for the same
    email trigger a single remote fetch, while requests for different emails
    do not wait on each other.
    """

    def __init__(self, ttl: int = 60):
        # Use our custom cache
        self.cache = SimpleTTLCache(ttl=ttl, maxsize=1000)
        # Guards `_key_locks` only; it is never held during the remote fetch.
        self.service_lock = threading.Lock()
        # email -> lock held while that email's groups are being fetched.
        # Prevents a "cache stampede" on one key without blocking other keys.
        self._key_locks = {}

    def _fetch_from_graph_api(self, email: str) -> List[str]:
        """
        Simulates a BLOCKING synchronous I/O request (e.g., using requests library).
        """
        # Simulate network latency (0.05s)
        time.sleep(0.05)
        # Simulate business logic
        return ["admin", "editor"] if "admin" in email else ["viewer"]

    def get_groups(self, email: str) -> List[str]:
        """
        Return the groups for `email`, fetching them remotely on a cache miss.

        The returned list is a copy, so callers may mutate it without
        affecting the cached value.
        """
        # 1. Fast path: check the cache without taking any service-level lock.
        cached = self.cache.get(email)
        if cached is not None:
            return list(cached)

        # 2. Slow path: obtain (or create) the lock dedicated to this email.
        with self.service_lock:
            key_lock = self._key_locks.setdefault(email, threading.Lock())
        try:
            with key_lock:
                # Re-check: another thread may have filled the cache while
                # we were waiting for the per-email lock.
                cached = self.cache.get(email)
                if cached is not None:
                    return list(cached)
                # 3. Perform blocking I/O
                print(f"[Service] Fetching groups for {email} from remote Graph...")
                groups = self._fetch_from_graph_api(email)
                # 4. Save to cache
                self.cache.set(email, groups)
                return list(groups)
        finally:
            # Drop the per-email lock so the table does not grow without
            # bound. Only remove the entry if it is still ours; a waiter that
            # already holds a reference to the lock keeps working with it.
            with self.service_lock:
                if self._key_locks.get(email) is key_lock:
                    del self._key_locks[email]
| # ========================================== | |
| # PART 3: Middleware (The Bridge) | |
| # ========================================== | |
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Middleware that attaches the caller's email and groups to `request.state`.

    The caller is identified by the `X-User-Email` header (a simplification
    standing in for token-based identification). Requests without the header
    pass through with no state attached.
    """

    def __init__(self, app, group_service: GroupService):
        super().__init__(app)
        self.group_service = group_service

    async def dispatch(self, request: Request, call_next):
        """Populate `request.state` for identified callers, then run the next handler."""
        email = request.headers.get("X-User-Email")
        if email:
            # `get_groups` performs blocking I/O, so it is pushed onto the
            # threadpool to keep the event loop free.
            lookup = self.group_service.get_groups
            resolved_groups = await run_in_threadpool(lookup, email)
            # Inject data into request state
            request.state.user_email = email
            request.state.user_groups = resolved_groups
        return await call_next(request)
| # ========================================== | |
| # PART 4: Unit Tests | |
| # ========================================== | |
class TestAppAndCache(unittest.TestCase):
    """Tests for SimpleTTLCache on its own and for the middleware/service stack."""

    def setUp(self):
        """Build an isolated service, app and HTTP client for each test case."""
        # Short TTL for testing
        self.service = GroupService(ttl=2)
        self.app = FastAPI()
        self.app.add_middleware(AuthMiddleware, group_service=self.service)

        @self.app.get("/me")
        def read_me(request: Request):
            # The endpoint relies on the middleware having populated state.
            state = request.state
            if hasattr(state, "user_groups"):
                return {
                    "email": state.user_email,
                    "groups": state.user_groups,
                }
            return {"error": "unauthorized"}

        self.client = TestClient(self.app)

    def test_custom_cache_lru(self):
        """SimpleTTLCache evicts the least recently used entry when full."""
        lru = SimpleTTLCache(maxsize=2, ttl=10)
        lru.set("A", 1)
        lru.set("B", 2)
        # Touch A so that B becomes the least recently used entry.
        lru.get("A")
        # Inserting C overflows the cache and must push out B.
        lru.set("C", 3)
        self.assertEqual(lru.get("A"), 1)
        self.assertIsNone(lru.get("B"))
        self.assertEqual(lru.get("C"), 3)

    def test_custom_cache_ttl(self):
        """SimpleTTLCache stops returning an entry once its TTL has elapsed."""
        short_lived = SimpleTTLCache(ttl=1)
        short_lived.set("A", 1)
        self.assertEqual(short_lived.get("A"), 1)
        # Wait slightly longer than the 1 second TTL.
        time.sleep(1.1)
        self.assertIsNone(short_lived.get("A"))

    def test_middleware_injection_and_caching(self):
        """
        Integration test:
        1. The middleware injects the user's data into the request.
        2. The service caches it, so the remote fetch happens only once.
        """
        email = "admin@company.com"
        headers = {"X-User-Email": email}
        real_fetch = self.service._fetch_from_graph_api
        # Spy on the fetch method: calls are recorded and forwarded to the
        # real implementation.
        with patch.object(self.service, "_fetch_from_graph_api", wraps=real_fetch) as fetch_spy:
            # --- Request 1: cache miss ---
            first = self.client.get("/me", headers=headers)
            self.assertEqual(first.status_code, 200)
            self.assertEqual(first.json()["groups"], ["admin", "editor"])
            fetch_spy.assert_called_once_with(email)

            # Forget the recorded call before the second request.
            fetch_spy.reset_mock()

            # --- Request 2: cache hit ---
            second = self.client.get("/me", headers=headers)
            self.assertEqual(second.status_code, 200)
            fetch_spy.assert_not_called()
| # ========================================== | |
| # Execution Entry Point | |
| # ========================================== | |
# Run the whole test suite when this file is executed directly as a script.
if __name__ == "__main__":
    print("Running Tests...")
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment