Skip to content

Instantly share code, notes, and snippets.

@KennyVaneetvelde
Created September 15, 2024 14:29
Show Gist options
  • Select an option

  • Save KennyVaneetvelde/c41ca30b7e18bb75cf924a9dd5a2632c to your computer and use it in GitHub Desktop.

Select an option

Save KennyVaneetvelde/c41ca30b7e18bb75cf924a9dd5a2632c to your computer and use it in GitHub Desktop.
Persistently caching pure functions with a custom memoization decorator
import functools
import sqlite3
from typing import Callable, Dict, TypeVar, ParamSpec, List
import pickle
from rich.console import Console
from rich.table import Table
from rich.prompt import Prompt
from rich import box
from tqdm import tqdm
# Shared rich console used for all terminal output in the demo below.
console = Console()
# Generic parameters for the memoize decorator's signature,
# Callable[P, T]: P stands for the wrapped function's parameters and
# T for its return type.
P = ParamSpec("P")
T = TypeVar("T")
# Mock database for caching with persistent storage
class MockDB:
    """
    A small persistent key-value cache stored in a SQLite table named ``cache``.

    Keys are strings. Values are pickled on write and unpickled on read, so
    any picklable object can be stored (including ``None``).

    Security note: ``pickle.loads`` can execute arbitrary code. Only open cache
    files you trust, i.e. ones written by your own code.
    """

    def __init__(self, filepath: str = "cache.db") -> None:
        """
        Open (or create) the SQLite database and ensure the ``cache`` table exists.
        Args:
            filepath (str): Path to the SQLite database file. ``":memory:"``
                gives a throwaway in-memory database.
        """
        self.filepath = filepath
        self.conn = sqlite3.connect(self.filepath)
        self.cursor = self.conn.cursor()
        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS cache (
                key TEXT PRIMARY KEY,
                value BLOB
            )
            """
        )
        self.conn.commit()

    def get(self, key: str, default: T | None = None) -> T | None:
        """
        Retrieve a value from the cache by key.
        Args:
            key (str): The key to retrieve the value for.
            default: Value returned when ``key`` is not stored. Pass a private
                sentinel object to tell a missing key apart from a stored ``None``.
        Returns:
            T | None: The unpickled value associated with the key, or ``default``
            if the key is not present.
        """
        self.cursor.execute("SELECT value FROM cache WHERE key = ?", (key,))
        row = self.cursor.fetchone()
        if row is None:
            return default
        # Unpickling trusts the database contents; see the class docstring.
        return pickle.loads(row[0])

    def set(self, key: str, value: T) -> None:
        """
        Store a key-value pair in the cache and commit it to the SQLite database.
        An existing entry with the same key is replaced.
        Args:
            key (str): The key under which the value is stored.
            value (T): The value to store; must be picklable.
        """
        serialized_value = pickle.dumps(value)
        self.cursor.execute(
            "REPLACE INTO cache (key, value) VALUES (?, ?)", (key, serialized_value)
        )
        self.conn.commit()

    def get_all_keys(self) -> List[str]:
        """
        Retrieve all cache keys from the database.
        Returns:
            List[str]: A list of all keys stored in the cache.
        """
        self.cursor.execute("SELECT key FROM cache")
        return [row[0] for row in self.cursor.fetchall()]

    def close(self) -> None:
        """Close the underlying SQLite connection; the instance is unusable afterwards."""
        self.conn.close()

    def __enter__(self) -> "MockDB":
        """Allow ``with MockDB(path) as store:`` so the connection is always closed."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection on leaving the ``with`` block, even on error."""
        self.close()


db = MockDB()

# Sentinel telling "key not cached" apart from a cached ``None`` result.
_MISSING = object()


def _cache_key(name: str, args: tuple[object, ...], kwargs: dict[str, object]) -> str:
    """Build the cache key ``name(a1, a2, k=v)`` from the reprs of the call's arguments."""
    parts = [repr(arg) for arg in args]
    parts.extend(f"{k}={v!r}" for k, v in kwargs.items())
    return f"{name}({', '.join(parts)})"


def memoize(func: Callable[P, T], cache: MockDB | None = None) -> Callable[P, T]:
    """
    A decorator that caches the results of a function in a persistent MockDB.

    The cache key is the function's ``__name__`` plus the ``repr`` of each
    argument. As a consequence:
      * same-named functions from different modules share entries (the key
        format is kept as-is so existing cache files stay valid);
      * ``f(1)`` vs ``f(x=1)``, or differently ordered keyword arguments,
        produce different keys and therefore separate cache entries;
      * arguments whose ``repr`` is not stable across runs will never hit.
    Results equal to ``None`` are cached like any other value.

    Args:
        func (Callable[P, T]): The function to be memoized.
        cache (MockDB | None): Store to use. Defaults to the module-level
            ``db``, which is looked up at call time.
    Returns:
        Callable[P, T]: The wrapped function with memoization.
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        store = db if cache is None else cache
        key = _cache_key(func.__name__, args, kwargs)
        cached = store.get(key, _MISSING)
        if cached is not _MISSING:
            return cached
        # Cache miss: compute once and persist. A raised exception propagates
        # and nothing is stored.
        result = func(*args, **kwargs)
        store.set(key, result)
        return result

    return wrapper
if __name__ == "__main__":
    from time import perf_counter, sleep

    def heavy_api_call(param: str) -> Dict[str, str]:
        """
        Simulate a slow API call with a deterministic response.

        Sleeps for two seconds, then returns a payload that depends only on
        ``param``. That determinism is what makes it safe to memoize.
        Args:
            param (str): Value echoed back in the ``result`` field.
        Returns:
            Dict[str, str]: The simulated response payload.
        """
        sleep(2)
        return {
            "result": f"Response for {param}",
            "status": "Success",
            "processing_time": "2.00 seconds",
        }

    def api_call_without_memoization(param: str) -> Dict[str, str]:
        """
        Call heavy_api_call directly, without caching, as the baseline.
        Args:
            param (str): The parameter for the API call.
        Returns:
            Dict[str, str]: The result of the API call.
        """
        return heavy_api_call(param)

    @memoize
    def api_call(param: str) -> Dict[str, str]:
        """
        Call heavy_api_call through the persistent @memoize cache.
        Args:
            param (str): The parameter for the API call.
        Returns:
            Dict[str, str]: The result of the API call.
        """
        return heavy_api_call(param)

    table = Table(title="Memoization Performance", box=box.MINIMAL_DOUBLE_HEAD)
    table.add_column("Call", style="cyan", no_wrap=True)
    table.add_column("Function Signature", style="green")
    table.add_column("Time Taken (seconds)", justify="right", style="magenta")

    # Report how many results survive from previous runs.
    cached_keys = db.get_all_keys()
    if cached_keys:
        console.print(
            f"[bold blue]Loaded {len(cached_keys)} cached result(s) from the database.[/bold blue]"
        )
    else:
        console.print(
            "[bold blue]No cached results found. Starting fresh...[/bold blue]"
        )

    # Each step: (row label, signature shown in the table, callable, argument).
    steps = [
        (
            "Without Memoization",
            'api_call_without_memoization("param_one")',
            api_call_without_memoization,
            "param_one",
        ),
        ("With Memoization (First Call)", 'api_call("param_two")', api_call, "param_two"),
        ("With Memoization (Cached Call)", 'api_call("param_two")', api_call, "param_two"),
        ("With Memoization (First Call)", 'api_call("param_three")', api_call, "param_three"),
    ]

    timings: list[float] = []
    with tqdm(total=len(steps), desc="Processing") as pbar:
        for label, signature, call, param in steps:
            # perf_counter is monotonic and high-resolution, unlike time(),
            # so it is the right clock for measuring elapsed durations.
            start = perf_counter()
            call(param)
            elapsed = round(perf_counter() - start, 2)
            table.add_row(label, signature, f"[cyan]{elapsed}[/cyan]")
            timings.append(elapsed)
            pbar.update(1)
    time_without, time_with, time_cached, time_with_new = timings

    console.print(table)
    console.print(
        "[bold green]Summary:[/bold green]\n"
        f"- Without memoization took {time_without} seconds.\n"
        f"- With memoization, first call for 'param_two' took {time_with} seconds.\n"
        f"- With memoization, cached call for 'param_two' took {time_cached} seconds.\n"
        f"- With memoization, first call for 'param_three' took {time_with_new} seconds.\n\n"
        f"[italic yellow]Note: Cached results are persisted in '{db.filepath}'. If you restart the script, previously cached results will be used, making the first memoized calls instant.[/italic yellow]"
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment