Skip to content

Instantly share code, notes, and snippets.

@haijohn
Created January 21, 2026 15:39
Show Gist options
  • Select an option

  • Save haijohn/9b1080a8bb1ff844a941bc1291276893 to your computer and use it in GitHub Desktop.

Select an option

Save haijohn/9b1080a8bb1ff844a941bc1291276893 to your computer and use it in GitHub Desktop.
ttlcache
import time
import threading
import unittest
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from unittest.mock import patch
# Third-party dependencies
# pip install fastapi uvicorn httpx starlette
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.concurrency import run_in_threadpool
# ==========================================
# PART 1: Custom TTL Cache Implementation
# ==========================================
class SimpleTTLCache:
    """
    A simple thread-safe LRU cache with per-entry Time-To-Live expiration.

    Entries live in an OrderedDict ordered from least- to most-recently
    used; expired entries are removed lazily on access.  Expiration stamps
    use ``time.monotonic()`` rather than the wall clock, so TTL behavior is
    immune to NTP/DST/manual system-clock adjustments.

    Note: ``get`` returns ``None`` for both "missing" and "expired", so a
    stored value of ``None`` is indistinguishable from a cache miss.
    """

    def __init__(self, ttl: float = 60, maxsize: int = 1000):
        # ttl: seconds each entry remains valid; maxsize: LRU capacity bound.
        self.ttl = ttl
        self.maxsize = maxsize
        # Store: key -> (value, expire_at) where expire_at is a
        # time.monotonic() timestamp.
        self.cache: OrderedDict[str, Tuple[Any, float]] = OrderedDict()
        self.lock = threading.Lock()

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent or expired."""
        with self.lock:
            if key not in self.cache:
                return None
            value, expire_at = self.cache[key]
            # Lazy expiration: check the deadline on access instead of
            # running a background sweeper thread.
            if time.monotonic() > expire_at:
                del self.cache[key]
                return None
            # LRU bookkeeping: mark the entry as most recently used.
            self.cache.move_to_end(key)
            return value

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, refreshing its TTL and LRU position."""
        with self.lock:
            # If the key exists, move it to the end so the reassignment
            # below counts as a fresh use.
            if key in self.cache:
                self.cache.move_to_end(key)
            self.cache[key] = (value, time.monotonic() + self.ttl)
            # Eviction: drop the least recently used entry once over capacity.
            if len(self.cache) > self.maxsize:
                self.cache.popitem(last=False)

    def clear(self) -> None:
        """Remove all entries."""
        with self.lock:
            self.cache.clear()

    def __len__(self) -> int:
        # NOTE: includes expired-but-not-yet-evicted entries, since
        # expiration is lazy.
        with self.lock:
            return len(self.cache)
# ==========================================
# PART 2: Service Layer (Business Logic)
# ==========================================
class GroupService:
    """Resolves group membership for a user, caching results in-process."""

    def __init__(self, ttl: int = 60):
        # Backing TTL/LRU cache for resolved group lists.
        self.cache = SimpleTTLCache(ttl=ttl, maxsize=1000)
        # Guards the slow path so concurrent cache misses cannot stampede
        # the remote API.
        self.service_lock = threading.Lock()

    def _fetch_from_graph_api(self, email: str) -> List[str]:
        """Simulate a blocking, synchronous call to a remote directory API
        (e.g. via the `requests` library)."""
        # Pretend network latency.
        time.sleep(0.05)
        # Pretend business logic.
        if "admin" in email:
            return ["admin", "editor"]
        return ["viewer"]

    def get_groups(self, email: str) -> List[str]:
        """Return the user's groups, serving from cache when possible."""
        # Fast path: a cache hit needs no service lock at all.
        hit = self.cache.get(email)
        if hit is not None:
            return hit
        # Slow path: serialize fetches behind the lock.
        with self.service_lock:
            # Double-checked locking: another thread may have populated
            # the cache while we were waiting for the lock.
            hit = self.cache.get(email)
            if hit is not None:
                return hit
            print(f"[Service] Fetching groups for {email} from remote Graph...")
            groups = self._fetch_from_graph_api(email)
            self.cache.set(email, groups)
            return groups
# ==========================================
# PART 3: Middleware (The Bridge)
# ==========================================
class AuthMiddleware(BaseHTTPMiddleware):
    """Resolves the caller's groups and attaches them to ``request.state``."""

    def __init__(self, app, group_service: GroupService):
        super().__init__(app)
        self.group_service = group_service

    async def dispatch(self, request: Request, call_next):
        # Identity comes from a header in this simplified example; a real
        # deployment would extract it from a verified token instead.
        email = request.headers.get("X-User-Email")
        if email:
            # `get_groups` performs blocking I/O (time.sleep), so it MUST
            # run in a worker thread to keep the event loop responsive.
            groups = await run_in_threadpool(self.group_service.get_groups, email)
            # Make the identity and groups available to downstream handlers.
            request.state.user_email = email
            request.state.user_groups = groups
        return await call_next(request)
# ==========================================
# PART 4: Unit Tests
# ==========================================
class TestAppAndCache(unittest.TestCase):
    def setUp(self):
        """Build a fresh app, middleware, and service for every test case."""
        self.service = GroupService(ttl=2)  # short TTL keeps tests quick
        self.app = FastAPI()
        self.app.add_middleware(AuthMiddleware, group_service=self.service)

        @self.app.get("/me")
        def read_me(request: Request):
            # The middleware is expected to have populated request.state.
            if not hasattr(request.state, "user_groups"):
                return {"error": "unauthorized"}
            return {
                "email": request.state.user_email,
                "groups": request.state.user_groups,
            }

        self.client = TestClient(self.app)

    def test_custom_cache_lru(self):
        """The least-recently-used entry is evicted once over capacity."""
        lru = SimpleTTLCache(maxsize=2, ttl=10)
        lru.set("A", 1)
        lru.set("B", 2)
        lru.get("A")   # touch A so B becomes the LRU entry
        lru.set("C", 3)  # third insert -> B should be pushed out
        self.assertEqual(lru.get("A"), 1)
        self.assertIsNone(lru.get("B"))
        self.assertEqual(lru.get("C"), 3)

    def test_custom_cache_ttl(self):
        """Entries become unavailable once their TTL has elapsed."""
        short_lived = SimpleTTLCache(ttl=1)  # 1 second TTL
        short_lived.set("A", 1)
        self.assertEqual(short_lived.get("A"), 1)
        time.sleep(1.1)
        self.assertIsNone(short_lived.get("A"))

    def test_middleware_injection_and_caching(self):
        """
        Integration test: the middleware injects group data, and the
        service caches it so the backend is only contacted once.
        """
        email = "admin@company.com"
        headers = {"X-User-Email": email}
        # Spy on _fetch_from_graph_api to count calls while preserving
        # its real behavior.
        real_fetch = self.service._fetch_from_graph_api
        with patch.object(
            self.service, '_fetch_from_graph_api', side_effect=real_fetch
        ) as spy:
            first = self.client.get("/me", headers=headers)
            self.assertEqual(first.status_code, 200)
            self.assertEqual(first.json()["groups"], ["admin", "editor"])
            spy.assert_called_once_with(email)  # fetched exactly once
            spy.reset_mock()
            second = self.client.get("/me", headers=headers)
            self.assertEqual(second.status_code, 200)
            spy.assert_not_called()  # cache hit: no second fetch
# ==========================================
# Execution Entry Point
# ==========================================
# Run the unittest suite when this module is executed as a script.
if __name__ == "__main__":
    print("Running Tests...")
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment