Skip to content

Instantly share code, notes, and snippets.

@achadha235
Last active June 28, 2023 21:36
Show Gist options
  • Select an option

  • Save achadha235/7375ae1694beba49b0988674a04e75dc to your computer and use it in GitHub Desktop.

Select an option

Save achadha235/7375ae1694beba49b0988674a04e75dc to your computer and use it in GitHub Desktop.
A simple and fast Redis-based caching system with TTL and cron-based expiry
import datetime as dt
import functools
import hashlib
import json
import os
import pickle
from typing import Callable, Optional, Union
from urllib.parse import urlparse

import pytz
import redis
from croniter import croniter
# Connection URL for the shared Redis instance. Override it with the
# REDIS_URL environment variable; the fallback is a local server on the
# default port.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
# Parsed components of REDIS_URL.
url = urlparse(REDIS_URL)
default_client = redis.Redis.from_url(REDIS_URL, socket_connect_timeout=3.0)
# Round-trip to the server at import time so that a bad URL or an
# unreachable server surfaces immediately rather than on first cache use.
default_client.ping()
def cache(
    client=default_client,
    loads=pickle.loads,
    dumps=pickle.dumps,
    expire_at: Optional[Union[str, dt.datetime]] = None,
    ttl: Optional[Union[int, float, Callable]] = None,
    tz=pytz.timezone("America/New_York"),
    prefix="redis-cache",
):
    """Build a decorator that caches a function's results in Redis.

    Keys have the form ``{prefix}:{module}.{qualname}:{sha256 of arguments}``.
    Positional arguments are hashed through ``str()`` and keyword arguments
    through ``json.dumps(..., sort_keys=True)``, so keyword values must be
    JSON-serialisable.

    The decorated function accepts two extra keyword arguments. Both are
    stripped before the key is computed and before the wrapped function runs:

    * ``no_cache`` -- when truthy, skip the lookup, recompute and overwrite
      the stored value.
    * ``on_cache_hit`` -- a callable given the cached value; a truthy return
      forces a recompute, a falsy return serves the cached value.

    The decorated function also exposes ``invalidate(*args, **kwargs)``, which
    deletes the entry for one argument set, and ``invalidate_all()``, which
    deletes every entry stored for that function under ``prefix``.

    Args:
        client: Redis client used for all reads and writes. Defaults to
            ``default_client``.
        loads: Deserialiser applied to stored bytes. Defaults to
            ``pickle.loads``.
        dumps: Serialiser applied to results before storing. Defaults to
            ``pickle.dumps``.
        expire_at: Either a cron expression (the entry expires at the next
            matching time) or a timezone-aware ``datetime`` (the entry expires
            at that moment; a moment already in the past contributes no
            expiry). Defaults to None.
        ttl: Lifetime in seconds, or a callable ``ttl(args, kwargs, result)``
            returning one. When both ``ttl`` and ``expire_at`` yield a
            positive value the smaller one wins. A ``ttl`` of 0 (or a callable
            returning 0) stores the entry without any expiry. Defaults to
            None.
        tz: Timezone in which "now" is taken when resolving ``expire_at``.
            Defaults to America/New_York.
        prefix: First component of every key. Defaults to "redis-cache".

    Returns:
        A decorator that wraps a function with the caching behaviour above.
    """
    # NOTE: with the default pickle.loads, anything able to write to this
    # Redis instance can execute code in the reading process. Only point the
    # cache at a trusted Redis, or pass safer loads/dumps.
    reserved_keys = ("on_cache_hit", "no_cache")

    def _cache(f):
        fn_id = f.__module__ + "." + f.__qualname__

        def remove_reserved_keys(kwargs):
            """Drop the wrapper-only keyword arguments."""
            return {k: v for k, v in kwargs.items() if k not in reserved_keys}

        def get_key(args, kwargs):
            """Return the Redis key for one set of call arguments."""
            arg_hash = hashlib.sha256(
                str(
                    (args, json.dumps(remove_reserved_keys(kwargs), sort_keys=True))
                ).encode("utf-8")
            ).hexdigest()
            return f"{prefix}:{fn_id}:{arg_hash}"

        def get_ttl_seconds(args, kwargs, res) -> Optional[int]:
            """Resolve ``expire_at`` and ``ttl`` into whole seconds, or None."""
            expiry = None
            if expire_at:
                now = dt.datetime.now(tz=tz)
                if isinstance(expire_at, dt.datetime):
                    if expire_at > now:
                        # Clamp to 1: a sub-second remainder would truncate
                        # to 0 and the entry would be stored with no expiry.
                        expiry = max(1, int((expire_at - now).total_seconds()))
                elif isinstance(expire_at, str):
                    next_time = croniter(expire_at, now).get_next(dt.datetime)
                    expiry = max(1, int((next_time - now).total_seconds()))
            if ttl:
                ttl_seconds = ttl(args, kwargs, res) if callable(ttl) else int(ttl)
                if ttl_seconds and expiry:
                    expiry = int(min(ttl_seconds, expiry))
                elif isinstance(ttl_seconds, (int, float)):
                    # Either there is no expire_at-derived expiry, or the
                    # callable returned 0, which disables expiry.
                    expiry = int(ttl_seconds)
            elif isinstance(ttl, (int, float)):
                # An explicit numeric 0 disables expiry even when expire_at
                # is set; the caller skips a falsy result.
                expiry = int(ttl)
            return expiry

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            key = get_key(args, kwargs)

            def evaluate_and_cache():
                res = f(*args, **remove_reserved_keys(kwargs))
                pickled = dumps(res)
                set_params = {}
                cache_ttl = get_ttl_seconds(args, kwargs, res)
                if cache_ttl:
                    set_params["ex"] = cache_ttl  # key expiry, in seconds
                client.set(key, pickled, **set_params)
                return res

            if kwargs.get("no_cache", False):
                return evaluate_and_cache()

            existing = client.get(key)
            # Compare against None rather than truthiness so that a custom
            # ``dumps`` producing empty bytes still counts as a hit.
            if existing is None:
                return evaluate_and_cache()

            existing_val = loads(existing)
            on_hit = kwargs.get("on_cache_hit", None)
            if callable(on_hit) and on_hit(existing_val):
                return evaluate_and_cache()
            return existing_val

        def invalidate(*args, **kwargs):
            """Delete the cached entry for this exact argument set."""
            client.delete(get_key(args, kwargs))

        def invalidate_all():
            """Delete every cached entry belonging to the wrapped function."""
            prefix_blob = f"{prefix}:{fn_id}:*"
            for key in client.scan_iter(prefix_blob):
                client.delete(key)

        wrapper.invalidate = invalidate
        wrapper.invalidate_all = invalidate_all
        return wrapper

    return _cache
def invalidate_all(client=default_client, prefix="stock-data") -> int:
    """Delete every key in the ``prefix`` namespace and count them.

    Scans for keys matching ``{prefix}:*`` and deletes them one at a time.

    NOTE(review): the default prefix here is "stock-data" while ``cache``
    defaults to "redis-cache", so calling this with no arguments does not
    touch entries written by a default ``cache`` -- confirm which is intended.

    Args:
        client: Redis client to scan and delete with. Defaults to
            ``default_client``.
        prefix: Namespace to clear, without the trailing colon.

    Returns:
        The number of keys that were iterated over and deleted.
    """
    namespace = f"{prefix}:"
    removed = 0
    for stale_key in client.scan_iter(namespace + "*"):
        client.delete(stale_key)
        removed += 1
    return removed
from datetime import timedelta
import datetime as dt
# Cron schedule: 18:00, Monday through Friday.
SIX_PM_ON_WEEKDAYS_CRON = "0 18 * * 1-5"
# Cron schedule: on the hour and the half hour from 09:00 through 16:30,
# Monday through Friday.
EVERY_HALF_HOUR_DURING_TRADING = "30,0 9-16 * * 1-5"
# Durations expressed in seconds (as floats, per timedelta.total_seconds()).
ONE_HOUR_SECONDS = dt.timedelta(hours=1).total_seconds()
TEN_MINS_SECONDS = dt.timedelta(minutes=10).total_seconds()
# Entries live for at most one hour, or until the next cron tick if that
# comes sooner.
# NOTE(review): the refresh schedule is assumed to be the half-hourly trading
# cron; swap in SIX_PM_ON_WEEKDAYS_CRON if a daily refresh is intended.
@cache(ttl=ONE_HOUR_SECONDS, expire_at=EVERY_HALF_HOUR_DURING_TRADING)
def hello_world(val):
    """Return ``val`` unchanged; a minimal target for demonstrating ``cache``."""
    return val
def on_cache_hit(cached_val):
    """Decide whether a cached value should be thrown away and recomputed.

    Returns True (forcing a refresh) only when the cached value equals 15;
    every other value is served from the cache.
    """
    return bool(cached_val == 15)
# Exercise the cached function with a few inputs; the hook forces a recompute
# whenever the value found in the cache is 15.
for demo_value in (25, 15, 18):
    hello_world(demo_value, on_cache_hit=on_cache_hit)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment