Skip to content

Instantly share code, notes, and snippets.

@tobalsan
Last active November 5, 2025 17:38
Show Gist options
  • Select an option

  • Save tobalsan/c450445e5f7dd93bcd4b678d92bbaae4 to your computer and use it in GitHub Desktop.

Select an option

Save tobalsan/c450445e5f7dd93bcd4b678d92bbaae4 to your computer and use it in GitHub Desktop.
Override of LangGraph's native RedisSaver checkpointer class to fix Interrupt deserialization. See https://github.com/redis-developer/langgraph-redis/issues/113
"""Factory for creating and caching RedisSaver checkpointer instances.
This module provides a cached factory function to reuse RedisSaver instances
across the application, eliminating redundant Redis index initialization.
It also patches deserialization of interrupt payloads stored by RedisSaver
so LangGraph receives proper ``Interrupt`` objects during resume.
"""
from functools import lru_cache
from os import getenv
from typing import Dict, Iterable, List, Optional, Tuple
from langgraph.checkpoint.base import CheckpointTuple, PendingWrite
from langgraph.checkpoint.redis import RedisSaver as _BaseRedisSaver
from langgraph.checkpoint.serde.types import INTERRUPT
from langgraph.types import Interrupt
class RedisSaver(_BaseRedisSaver):
    """RedisSaver with interrupt rehydration fix.

    A regression in langgraph 1.0.x causes interrupts persisted via Redis
    to be deserialized as plain dictionaries. LangGraph expects instances
    of :class:`langgraph.types.Interrupt` (or objects exposing ``.id``),
    so resuming a workflow would crash with ``'dict' object has no attribute 'id'``.

    This subclass patches the key read paths to rebuild ``Interrupt`` objects
    before LangGraph consumes pending writes.
    """

    @staticmethod
    def _rehydrate_interrupts(writes: Iterable[PendingWrite]) -> List[PendingWrite]:
        """Convert any serialized interrupt dicts back into Interrupt objects.

        Only writes on the ``INTERRUPT`` channel are inspected:

        * a list or tuple value has each serialized dict rebuilt as an
          ``Interrupt`` (other items are kept as-is) and keeps its container type;
        * a single serialized dict value is rebuilt and wrapped in a list.

        A dict counts as a serialized interrupt when it carries an ``"id"`` key.

        Args:
            writes: Pending ``(task_id, channel, value)`` triples. Any iterable
                is accepted, including one-shot iterators.

        Returns:
            A list of pending writes. When nothing needed repairing and
            ``writes`` was already a list, that same list object is returned,
            so callers can detect "no change" with an identity check.
        """

        def is_serialized(obj: object) -> bool:
            # Plain dict carrying an "id" key == interrupt that lost its type.
            return isinstance(obj, dict) and "id" in obj

        # Materialize exactly once so a generator is not consumed twice
        # (a second pass over an exhausted iterator would yield nothing).
        items = writes if isinstance(writes, list) else list(writes)

        changed = False
        repaired: List[PendingWrite] = []
        for task_id, channel, value in items:
            if channel == INTERRUPT:
                if isinstance(value, (list, tuple)):
                    if any(is_serialized(item) for item in value):
                        rebuilt = [
                            Interrupt(**item) if is_serialized(item) else item
                            for item in value
                        ]
                        # Keep the container type that was stored.
                        value = tuple(rebuilt) if isinstance(value, tuple) else rebuilt
                        changed = True
                elif is_serialized(value):
                    value = [Interrupt(**value)]
                    changed = True
            repaired.append((task_id, channel, value))
        return repaired if changed else items

    def get_tuple(self, config) -> CheckpointTuple | None:  # type: ignore[override]
        """Return the base checkpoint tuple with interrupts in ``pending_writes`` rehydrated."""
        result = super().get_tuple(config)
        if result and result.pending_writes:
            fixed_writes = self._rehydrate_interrupts(result.pending_writes)
            # _rehydrate_interrupts hands back the same list when nothing was
            # repaired, so only build a new tuple via _replace when needed.
            if fixed_writes is not result.pending_writes:
                result = result._replace(pending_writes=fixed_writes)
        return result

    def _load_pending_writes_with_registry_check(  # type: ignore[override]
        self,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        checkpoint_has_writes: bool,
        registry_has_writes: bool,
    ) -> List[PendingWrite]:
        """Load pending writes via the base implementation, then rehydrate interrupts."""
        writes = super()._load_pending_writes_with_registry_check(
            thread_id,
            checkpoint_ns,
            checkpoint_id,
            checkpoint_has_writes,
            registry_has_writes,
        )
        return self._rehydrate_interrupts(writes)

    def _batch_load_pending_writes(  # type: ignore[override]
        self, batch_keys: List[Tuple[str, str, str]]
    ) -> Dict[Tuple[str, str, str], List[PendingWrite]]:
        """Batch-load pending writes via the base class and rehydrate each non-empty entry.

        The returned mapping is the base implementation's dict, updated in place.
        """
        results = super()._batch_load_pending_writes(batch_keys)
        if not results:
            return results
        # Iterate over a snapshot of the items because values are reassigned below.
        for key, writes in list(results.items()):
            if writes:
                results[key] = self._rehydrate_interrupts(writes)
        return results
@lru_cache(maxsize=4)
def _get_checkpointer_cached(redis_url: str) -> RedisSaver:
    """Build a ready-to-use RedisSaver for an already-normalized Redis URL.

    ``lru_cache`` memoizes the result for up to 4 distinct URLs, so repeated
    calls with the same URL return the same instance and ``setup()`` runs once
    per cached entry. A URL evicted from the cache gets a new instance (and
    another ``setup()`` call) the next time it is requested.

    Args:
        redis_url: Normalized Redis connection URL (never None).

    Returns:
        The cached RedisSaver for ``redis_url``, or a freshly created one.
    """
    saver = RedisSaver(redis_url)
    # Initialize the saver's Redis indices before handing it out; caching the
    # instance is what keeps this from repeating on every call.
    saver.setup()
    return saver
def get_checkpointer(redis_url: Optional[str] = None) -> RedisSaver:
    """Get or create cached RedisSaver for given redis_url.

    Uses LRU cache to reuse checkpointer instances, reducing redundant
    Redis index initialization (avoids duplicate "Index already exists" logs).

    Normalizes redis_url before caching to ensure that None and the default
    URL string resolve to the same cached instance.

    Supports up to 4 different redis URLs (prod, test, dev, etc) to preserve
    test flexibility when tests use custom redis_url values.

    Args:
        redis_url: Optional Redis connection URL. If None or empty, uses the
            REDIS_URL environment variable when it is set to a non-empty
            value, otherwise "redis://localhost:6379/0".

    Returns:
        Cached or new RedisSaver instance.

    Note:
        Tests should call _get_checkpointer_cached.cache_clear() after monkeypatching
        redis_url to ensure fresh instances.

    Example:
        >>> # Normal usage
        >>> checkpointer = get_checkpointer()
        >>> graph = create_graph(checkpointer=checkpointer)
        >>>
        >>> # Test usage with custom Redis
        >>> def test_with_custom_redis():
        ...     from infra.checkpointer_factory import _get_checkpointer_cached
        ...     _get_checkpointer_cached.cache_clear()
        ...     checkpointer = get_checkpointer(redis_url="redis://test:6379")
        ...     # ... test code ...
        ...     _get_checkpointer_cached.cache_clear()
    """
    # Normalize the URL before the cache lookup so that None, "" and the
    # default resolve to the same entry. Chaining ``or`` instead of passing the
    # default to getenv also treats an empty REDIS_URL (e.g. "REDIS_URL=" in a
    # .env file) as unset, rather than handing "" to RedisSaver.
    normalized_url = redis_url or getenv("REDIS_URL") or "redis://localhost:6379/0"
    return _get_checkpointer_cached(normalized_url)
# Names exported by ``from ... import *``. The private cache accessor is listed
# so tests can reach ``_get_checkpointer_cached.cache_clear()``.
__all__ = [
    "get_checkpointer",
    "_get_checkpointer_cached",
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment