Skip to content

Instantly share code, notes, and snippets.

@mentix02
Last active January 18, 2026 02:30
Show Gist options
  • Select an option

  • Save mentix02/52799b9728a614d0fb6c721d2db270a4 to your computer and use it in GitHub Desktop.

Select an option

Save mentix02/52799b9728a614d0fb6c721d2db270a4 to your computer and use it in GitHub Desktop.
Redis who?
import time
import heapq
import unittest
import threading
from typing import Any, Optional
from collections.abc import Hashable
class _MissingType:
    """
    A singleton sentinel class to represent a missing value. A key for the ExpiryDict
    could be set to None, so we need a distinct value to represent "not found". This
    class serves that purpose.

    Callers test for the sentinel by identity (``x is Missing``), so the class enforces
    that only one instance ever exists: constructing it again, copying it, deep-copying
    it, or pickling and unpickling it all yield the same object.
    """

    __slots__ = ()

    # Cached sole instance; a class attribute, so it does not conflict with __slots__.
    _instance = None

    def __new__(cls):
        # Reuse the one instance so identity checks against ``Missing`` always hold.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self) -> str:
        return "MISSING"

    def __bool__(self) -> bool:
        # Falsy, so ``if not value:`` treats the sentinel like an absent value.
        return False

    def __reduce__(self) -> str:
        # Returning a global name makes copy/deepcopy hand back this very object and
        # makes pickle store a reference to the module-level ``Missing`` instead of
        # rebuilding a new instance.
        return "Missing"


Missing = _MissingType()
class ExpiryDict:
    """
    A dictionary-like object where each key-value pair can have an optional time-to-live (TTL).

    Once the TTL expires, the key-value pair is treated as absent and removed. Uses an
    active-passive cleanup strategy:

    - Passive: on access (``get``, ``delete``, ``in``, ``[]``), an expired key is treated
      as missing and removed.
    - Active: while the instance is used as a context manager, a background daemon thread
      periodically removes expired keys.

    The dictionary also works outside a ``with`` block, but then only passive expiry runs,
    so expired entries that are never accessed again stay in memory until touched.
    """

    _DEFAULT_CLEANUP_INTERVAL = 0.5  # seconds

    # Private methods

    def _expire(self, key: Hashable) -> bool:
        """
        Expects the lock to be held by the caller.
        Return True if the key was present and deleted, False otherwise.
        """
        if key in self._d:
            del self._d[key]
            return True
        return False

    def _cleanup(self) -> None:
        """
        Background-thread loop.

        Every ``cleanup_interval`` seconds, purge expired keys. Exits as soon as the stop
        event is set, because ``Event.wait`` then returns True.
        """
        while not self._stop_event.wait(timeout=self._cleanup_interval):
            self._remove_expired()

    def _remove_expired(self) -> None:
        """
        Pop every heap entry whose deadline has passed and delete the matching key.

        Entries are skipped when they are stale, i.e. the key was since deleted or
        re-set with a different deadline. The lock is re-acquired per entry so other
        threads can interleave between removals.
        """
        while True:
            now = time.monotonic()
            with self._lock:
                if not self._expiry_heap:
                    return
                expires_at, _, key = self._expiry_heap[0]  # peek at the earliest deadline
                if expires_at > now:
                    # The earliest deadline is still in the future - done for now.
                    return
                heapq.heappop(self._expiry_heap)
                entry = self._d.get(key)
                # Verify and delete within the same lock acquisition: releasing it in
                # between would let another thread re-set ``key`` and lose its fresh value.
                # The check uses the stored deadline rather than the key's truthiness, so
                # falsy keys such as 0 or "" are expired too.
                if entry is not None and entry[1] == expires_at:
                    del self._d[key]

    # Public methods

    def get(self, key: Hashable, default: Any = Missing) -> Any | _MissingType:
        """
        Return the value for ``key``, or ``default`` if it is absent or expired.

        An expired entry found here is removed immediately (passive expiry).
        """
        now = time.monotonic()
        with self._lock:
            entry = self._d.get(key)
            if entry is None:
                return default
            val, expires_at = entry
            if expires_at is not None and expires_at <= now:
                self._expire(key)
                return default
            return val

    def delete(self, key: Hashable) -> bool:
        """
        Remove ``key`` from the dictionary.

        Returns True if a live entry was removed. Returns False if the key was absent or
        its TTL had already elapsed. An expired entry is still purged, so ``delete``
        agrees with ``get`` and ``in``, which already treat expired keys as missing.
        """
        now = time.monotonic()
        with self._lock:
            entry = self._d.get(key)
            if entry is None:
                return False
            self._expire(key)
        _, expires_at = entry
        return expires_at is None or expires_at > now

    def set(self, key: Hashable, value: Any, ttl: Optional[float] = None) -> None:
        """
        Set a key-value pair with an optional TTL (in seconds).

        ``ttl=None`` means the entry never expires. Re-setting a key replaces both its
        value and its deadline.
        """
        expires_at = time.monotonic() + ttl if ttl is not None else None
        with self._lock:
            self._d[key] = (value, expires_at)
            if expires_at is not None:
                # The sequence number breaks ties between equal deadlines, so heapq never
                # falls back to comparing keys, which need not be mutually orderable.
                self._seq += 1
                heapq.heappush(self._expiry_heap, (expires_at, self._seq, key))

    # Dunder (magic) methods

    def __init__(self, *, cleanup_interval: float = _DEFAULT_CLEANUP_INTERVAL):
        # Cleanup interval for the background thread.
        # A lower value means more frequent checks for expired keys.
        self._cleanup_interval = cleanup_interval
        # The internal dict - each key is just the key
        # the value is a pair: first element the value
        # & second element the monotonic expiry deadline (or None)
        self._d = dict[Hashable, tuple[Any, Optional[float]]]()
        # Priority queue to track expiration times: (expires_at, seq, key)
        # We don't manually remove entries from here on update/delete;
        # instead, we check for staleness during cleanup. This is an
        # intentional design tradeoff to keep operations O(log n) on
        # average at the cost of some extra memory usage.
        self._expiry_heap = list[tuple[float, int, Hashable]]()
        # Monotonic tie-breaker for heap entries with equal deadlines.
        self._seq = 0
        # Created here rather than in __enter__ so the passive API works without a
        # context manager, and re-entering never swaps the lock under other threads.
        self._lock = threading.Lock()

    def __enter__(self):
        # Spawn the cleanup thread; a fresh stop event allows re-entry after __exit__.
        self._stop_event = threading.Event()
        self._cleaner = threading.Thread(target=self._cleanup, daemon=True)
        self._cleaner.start()
        return self

    def __exit__(self, *exc):
        # Stop the cleanup thread and wait for it to finish.
        self._stop_event.set()
        self._cleaner.join()

    def __delitem__(self, key: Hashable):
        if not self.delete(key):
            raise KeyError(key)

    def __getitem__(self, key: Hashable) -> Any:
        val = self.get(key)
        if val is Missing:
            raise KeyError(key)
        return val

    def __setitem__(self, key: Hashable, value: Any):
        self.set(key, value)

    def __contains__(self, key: Hashable) -> bool:
        return self.get(key) is not Missing
class TestExpiryDict(unittest.TestCase):
    """Behavioural tests for ExpiryDict; TTLs are kept short so the suite runs quickly."""

    def test_set_get(self):
        with ExpiryDict() as ed:
            ed.set("key1", "value1", ttl=0.1)  # 0.1 second TTL
            self.assertEqual(ed.get("key1"), "value1")
            time.sleep(0.3)
            self.assertIs(ed.get("key1"), Missing)
            with self.assertRaises(KeyError):
                ed["key1"]

    def test_get_default(self):
        with ExpiryDict() as ed:
            # A caller-supplied default is returned for unknown keys.
            self.assertEqual(ed.get("nope", 42), 42)
            self.assertIsNone(ed.get("nope", None))

    def test_no_ttl_persists(self):
        with ExpiryDict(cleanup_interval=0.01) as ed:
            ed["forever"] = "value"
            time.sleep(0.05)
            self.assertEqual(ed["forever"], "value")

    def test_delete(self):
        with ExpiryDict() as ed:
            ed.set("key2", "value2")
            self.assertTrue(ed.delete("key2"))
            self.assertIs(ed.get("key2"), Missing)
            with self.assertRaises(KeyError):
                del ed["key2_invalid"]
            # Test success with __delitem__
            ed.set("key2", "value2")
            del ed["key2"]
            # Should return False now
            self.assertFalse(ed.delete("key2"))

    def test_contains(self):
        with ExpiryDict() as ed:
            ed.set("key3", "value3", ttl=0.5)  # 0.5 second TTL
            self.assertIn("key3", ed)
            time.sleep(1)
            self.assertNotIn("key3", ed)

    def test_active_cleanup(self):
        # The key is never accessed after being set, so only the background thread
        # can remove it; inspect the internal dict directly to observe that.
        with ExpiryDict(cleanup_interval=0.01) as ed:
            ed.set("key5", "value5", ttl=0.05)
            time.sleep(0.3)
            self.assertNotIn("key5", ed._d)

    def test_updated_ttl(self):
        with ExpiryDict() as ed:
            ed.set("key4", "value4", ttl=0.5)
            time.sleep(0.3)
            # Re-setting replaces the deadline; the stale heap entry must be ignored.
            ed.set("key4", "value4_updated", ttl=0.5)
            time.sleep(0.3)
            self.assertEqual(ed["key4"], "value4_updated")
            time.sleep(0.3)
            self.assertIs(ed.get("key4"), Missing)
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when executed as a script.
    unittest.main(verbosity=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment