Last active
June 28, 2023 21:36
-
-
Save achadha235/7375ae1694beba49b0988674a04e75dc to your computer and use it in GitHub Desktop.
A simple, fast Redis-based caching system with TTL and cron-based expiry
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import json | |
| import pytz | |
| import redis | |
| import pickle | |
| import hashlib | |
| import datetime as dt | |
| from typing import Optional, Union | |
| from croniter import croniter | |
| from urllib.parse import urlparse | |
# Connection URL for the shared Redis instance. Override it with the REDIS_URL
# environment variable; the fallback targets a Redis on the local machine.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
# Parsed form of REDIS_URL, kept as a module-level name.
url = urlparse(REDIS_URL)
default_client = redis.Redis.from_url(REDIS_URL, socket_connect_timeout=3.0)
# Issue a PING at import time so an unreachable server surfaces immediately.
default_client.ping()
def cache(
    client=default_client,
    loads=pickle.loads,
    dumps=pickle.dumps,
    expire_at: Optional[Union[str, dt.datetime]] = None,
    ttl: Optional[Union[int, float]] = None,
    tz=pytz.timezone("America/New_York"),
    prefix="redis-cache",
):
    """Return a decorator that caches a function's results in Redis.

    The cache key is built from ``prefix``, the function's module and
    qualified name, and a SHA-256 hash of its positional and keyword
    arguments.

    The decorated function accepts two extra keyword arguments that are
    stripped before the wrapped function is called and are not part of the
    cache key:

    * ``no_cache``: when truthy, skip the lookup, recompute and re-store.
    * ``on_cache_hit``: callable receiving the cached value; a truthy return
      forces a recompute and re-store.

    The decorated function also exposes ``invalidate(*args, **kwargs)`` to
    drop one entry and ``invalidate_all()`` to drop every entry stored for
    that function under ``prefix``.

    Args:
        client: Redis client used for reads and writes. Defaults to
            ``default_client``.
        loads: Deserializer for cached payloads. Defaults to ``pickle.loads``.
        dumps: Serializer for results. Defaults to ``pickle.dumps``.
        expire_at: Either a datetime at which the entry should expire, or a
            cron expression whose next occurrence is used as the expiry.
            A datetime that is already in the past contributes no expiry.
            A naive datetime is interpreted in ``tz``.
        ttl: Time to live in seconds, or a callable ``ttl(args, kwargs, res)``
            returning the number of seconds. ``None``/``0`` means no TTL.
            When both ``ttl`` and ``expire_at`` apply, the shorter one wins.
        tz: Time zone used for "now" when evaluating ``expire_at``.
            Defaults to America/New_York.
        prefix: Key namespace. Defaults to "redis-cache".

    Returns:
        A decorator that wraps a function with the caching behaviour above.
    """
    # Imported locally so this block does not depend on the file's import list.
    from functools import wraps

    reserved_kwargs = ("on_cache_hit", "no_cache")

    def _cache(f):
        fn_id = f.__module__ + "." + f.__qualname__

        def remove_reserved_keys(kwargs):
            """Drop the cache-control kwargs so they never reach ``f``."""
            return {k: v for k, v in kwargs.items() if k not in reserved_kwargs}

        def get_key(args, kwargs):
            """Build the Redis key for one call signature."""
            payload = str(
                (args, json.dumps(remove_reserved_keys(kwargs), sort_keys=True))
            )
            arg_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
            return f"{prefix}:{fn_id}:{arg_hash}"

        def as_aware(moment):
            """Attach ``tz`` to a naive datetime so it can be compared to now."""
            if moment.tzinfo is not None:
                return moment
            if hasattr(tz, "localize"):  # pytz zones must be applied via localize()
                return tz.localize(moment)
            return moment.replace(tzinfo=tz)

        def get_ttl_seconds(args, kwargs, res) -> Optional[int]:
            """Return the expiry in whole seconds, or None for no expiry."""
            limits = []

            if expire_at:
                now = dt.datetime.now(tz=tz)
                deadline = None
                if isinstance(expire_at, dt.datetime):
                    candidate = as_aware(expire_at)
                    if candidate > now:
                        deadline = candidate
                elif isinstance(expire_at, str):
                    deadline = croniter(expire_at, now).get_next(dt.datetime)
                if deadline is not None:
                    # Round sub-second remainders up to 1 so they are not
                    # mistaken for "no expiry".
                    limits.append(max(1, int((deadline - now).total_seconds())))

            if callable(ttl):
                ttl_seconds = ttl(args, kwargs, res)
            elif ttl:
                ttl_seconds = int(ttl)
            else:
                ttl_seconds = None
            # A missing or non-positive TTL must not cancel an expire_at expiry.
            if isinstance(ttl_seconds, (int, float)) and ttl_seconds > 0:
                limits.append(max(1, int(ttl_seconds)))

            return min(limits) if limits else None

        @wraps(f)
        def wrapper(*args, **kwargs):
            key = get_key(args, kwargs)

            def evaluate_and_cache():
                res = f(*args, **remove_reserved_keys(kwargs))
                serialized = dumps(res)
                set_params = {}
                cache_ttl = get_ttl_seconds(args, kwargs, res)
                if cache_ttl:
                    set_params["ex"] = cache_ttl  # expiry for the key, in seconds
                client.set(key, serialized, **set_params)
                return res

            if kwargs.get("no_cache", False):
                return evaluate_and_cache()

            existing = client.get(key)
            # Compare against None so an empty payload still counts as a hit.
            if existing is None:
                return evaluate_and_cache()

            # NOTE(security): with the default pickle loader, anything able to
            # write to this Redis instance can execute code here on load.
            existing_val = loads(existing)

            on_hit = kwargs.get("on_cache_hit", None)
            if callable(on_hit) and on_hit(existing_val):
                return evaluate_and_cache()
            return existing_val

        def invalidate(*args, **kwargs):
            """Delete the cached entry for this exact call signature."""
            client.delete(get_key(args, kwargs))

        def invalidate_all():
            """Delete every cached entry stored for this function."""
            pattern = f"{prefix}:{fn_id}:*"
            for key in client.scan_iter(pattern):
                client.delete(key)

        wrapper.invalidate = invalidate
        wrapper.invalidate_all = invalidate_all
        return wrapper

    return _cache
def invalidate_all(client=default_client, prefix="stock-data") -> int:
    """Delete every key in the ``prefix:`` namespace.

    Args:
        client: Redis client to operate on. Defaults to ``default_client``.
        prefix: Namespace to clear, i.e. all keys matching ``<prefix>:*``.
            NOTE(review): this default ("stock-data") differs from the
            "redis-cache" default of ``cache()`` -- confirm which is intended.

    Returns:
        int: Number of keys for which a delete was issued.
    """
    pattern = f"{prefix}:" + "*"
    cleared = 0
    for key in client.scan_iter(pattern):
        client.delete(key)
        cleared += 1
    return cleared
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from datetime import timedelta | |
| import datetime as dt | |
# Cron expressions (minute hour day-of-month month day-of-week).
# 18:00, Monday through Friday.
SIX_PM_ON_WEEKDAYS_CRON = "0 18 * * 1-5"
# Minute 0 and 30 of hours 9 through 16, Monday through Friday.
EVERY_HALF_HOUR_DURING_TRADING = "30,0 9-16 * * 1-5"
# Schedule used by the example below.
# NOTE(review): the original referenced REFRESH_CRON without defining it;
# aliasing the 6 PM schedule is an assumption -- confirm the intended cron.
REFRESH_CRON = SIX_PM_ON_WEEKDAYS_CRON

# Durations expressed in seconds (floats, as returned by total_seconds()).
ONE_HOUR_SECONDS = dt.timedelta(hours=1).total_seconds()
TEN_MINS_SECONDS = dt.timedelta(minutes=10).total_seconds()
@cache(ttl=ONE_HOUR_SECONDS, expire_at=REFRESH_CRON)
def hello_world(val):
    """Return ``val`` unchanged.

    Example function for the ``cache`` decorator: it is cached with a TTL of
    ``ONE_HOUR_SECONDS`` and an ``expire_at`` schedule of ``REFRESH_CRON``.
    """
    return val
def on_cache_hit(cached_val):
    """Tell the cache whether a hit should be recomputed.

    Args:
        cached_val: The value that was found in the cache.

    Returns:
        bool: True when the cached value equals 15, which forces the entry
        to be refreshed; False otherwise.
    """
    return bool(cached_val == 15)
# Exercise the cached function with the refresh hook attached. Only a cached
# value of 15 makes on_cache_hit request a recompute on a cache hit.
for val in (25, 15, 18):
    hello_world(val, on_cache_hit=on_cache_hit)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment