Override of LangGraph's native RedisSaver module to fix Interrupt deserialization. See https://github.com/redis-developer/langgraph-redis/issues/113
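For orientation, here is a minimal sketch of how the patched saver might be wired into a human-in-the-loop graph and resumed through Redis. The state shape, node name, and the one-node graph are illustrative only (not part of this gist), and it assumes a reachable Redis instance plus the module living at `infra.checkpointer_factory`, as in the docstring example further down.

from typing import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, interrupt

from infra.checkpointer_factory import get_checkpointer


class State(TypedDict):
    answer: str


def ask(state: State) -> State:
    # interrupt() pauses the run; its return value is the payload passed on resume.
    return {"answer": interrupt("Approve this step?")}


builder = StateGraph(State)
builder.add_node("ask", ask)
builder.add_edge(START, "ask")
builder.add_edge("ask", END)
graph = builder.compile(checkpointer=get_checkpointer())

config = {"configurable": {"thread_id": "demo"}}
graph.invoke({"answer": ""}, config)              # pauses at the interrupt
graph.invoke(Command(resume="approved"), config)  # the resume path the override fixes

The patched module itself follows.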
| """Factory for creating and caching RedisSaver checkpointer instances. | |
| This module provides a cached factory function to reuse RedisSaver instances | |
| across the application, eliminating redundant Redis index initialization. | |
| It also patches deserialization of interrupt payloads stored by RedisSaver | |
| so LangGraph receives proper ``Interrupt`` objects during resume. | |
| """ | |
| from functools import lru_cache | |
| from os import getenv | |
| from typing import Dict, Iterable, List, Optional, Tuple | |
| from langgraph.checkpoint.base import CheckpointTuple, PendingWrite | |
| from langgraph.checkpoint.redis import RedisSaver as _BaseRedisSaver | |
| from langgraph.checkpoint.serde.types import INTERRUPT | |
| from langgraph.types import Interrupt | |


class RedisSaver(_BaseRedisSaver):
    """RedisSaver with interrupt rehydration fix.

    A regression in langgraph 1.0.x causes interrupts persisted via Redis
    to be deserialized as plain dictionaries. LangGraph expects instances
    of :class:`langgraph.types.Interrupt` (or objects exposing ``.id``),
    so resuming a workflow would crash with ``'dict' object has no attribute 'id'``.

    This subclass patches the key read paths to rebuild ``Interrupt`` objects
    before LangGraph consumes pending writes.
    """

    @staticmethod
    def _rehydrate_interrupts(writes: Iterable[PendingWrite]) -> List[PendingWrite]:
        """Convert any serialized interrupt dicts back into Interrupt objects."""
        changed = False
        repaired: List[PendingWrite] = []
        for task_id, channel, value in writes:
            if channel == INTERRUPT:
                if isinstance(value, list):
                    fixed_list = []
                    for item in value:
                        if isinstance(item, Interrupt):
                            fixed_list.append(item)
                        elif isinstance(item, dict) and "id" in item:
                            fixed_list.append(Interrupt(**item))
                            changed = True
                        else:
                            fixed_list.append(item)
                    value = fixed_list
                elif isinstance(value, dict) and "id" in value:
                    value = [Interrupt(**value)]
                    changed = True
            repaired.append((task_id, channel, value))
        return repaired if changed else list(writes)
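
    # Shape of the repair this method performs (values are illustrative, not
    # taken from the library): a pending write read back from Redis as
    #     ("task-1", INTERRUPT, {"id": "abc", "value": "Approve this step?"})
    # is rewritten to
    #     ("task-1", INTERRUPT, [Interrupt(value="Approve this step?", id="abc")])
    # so downstream LangGraph code can safely access ``.id`` on each interrupt.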

    def get_tuple(self, config) -> CheckpointTuple | None:  # type: ignore[override]
        result = super().get_tuple(config)
        if result and result.pending_writes:
            fixed_writes = self._rehydrate_interrupts(result.pending_writes)
            if fixed_writes is not result.pending_writes:
                result = result._replace(pending_writes=fixed_writes)
        return result

    def _load_pending_writes_with_registry_check(  # type: ignore[override]
        self,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        checkpoint_has_writes: bool,
        registry_has_writes: bool,
    ) -> List[PendingWrite]:
        writes = super()._load_pending_writes_with_registry_check(
            thread_id,
            checkpoint_ns,
            checkpoint_id,
            checkpoint_has_writes,
            registry_has_writes,
        )
        return self._rehydrate_interrupts(writes)

    def _batch_load_pending_writes(  # type: ignore[override]
        self, batch_keys: List[Tuple[str, str, str]]
    ) -> Dict[Tuple[str, str, str], List[PendingWrite]]:
        results = super()._batch_load_pending_writes(batch_keys)
        if not results:
            return results
        for key, writes in list(results.items()):
            if writes:
                results[key] = self._rehydrate_interrupts(writes)
        return results


@lru_cache(maxsize=4)
def _get_checkpointer_cached(redis_url: str) -> RedisSaver:
    """Internal cached function that creates a RedisSaver for a normalized URL.

    Args:
        redis_url: Normalized Redis connection URL (never None).

    Returns:
        Cached or new RedisSaver instance.
    """
    checkpointer = RedisSaver(redis_url)
    checkpointer.setup()
    return checkpointer


def get_checkpointer(redis_url: Optional[str] = None) -> RedisSaver:
    """Get or create a cached RedisSaver for the given redis_url.

    Uses an LRU cache to reuse checkpointer instances, reducing redundant
    Redis index initialization (avoids duplicate "Index already exists" logs).
    Normalizes redis_url before caching so that None and the default URL
    string resolve to the same cached instance.

    Supports up to 4 different Redis URLs (prod, test, dev, etc.) to preserve
    test flexibility when tests use custom redis_url values.

    Args:
        redis_url: Optional Redis connection URL. If None, uses the REDIS_URL
            env var or "redis://localhost:6379/0".

    Returns:
        Cached or new RedisSaver instance.

    Note:
        Tests should call _get_checkpointer_cached.cache_clear() after
        monkeypatching redis_url to ensure fresh instances.

    Example:
        >>> # Normal usage
        >>> checkpointer = get_checkpointer()
        >>> graph = create_graph(checkpointer=checkpointer)
        >>>
        >>> # Test usage with custom Redis
        >>> def test_with_custom_redis():
        ...     from infra.checkpointer_factory import _get_checkpointer_cached
        ...     _get_checkpointer_cached.cache_clear()
        ...     checkpointer = get_checkpointer(redis_url="redis://test:6379")
        ...     # ... test code ...
        ...     _get_checkpointer_cached.cache_clear()
    """
    # Normalize URL before cache lookup so None and the default resolve to the same entry
    normalized_url = redis_url or getenv("REDIS_URL", "redis://localhost:6379/0")
    return _get_checkpointer_cached(normalized_url)


__all__ = ["get_checkpointer", "_get_checkpointer_cached"]
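
# A quick illustration of the caching/normalization guarantee documented above.
# It assumes REDIS_URL is unset and a local Redis is reachable so setup()
# succeeds; the values are illustrative, not part of the original gist:
#
#   >>> from infra.checkpointer_factory import get_checkpointer
#   >>> get_checkpointer() is get_checkpointer("redis://localhost:6379/0")
#   True
#
# Both calls normalize to "redis://localhost:6379/0" and hit the same lru_cache
# entry, so only one RedisSaver.setup() runs per distinct URL.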