dynamodb-langgraph-async
import asyncio
import pickle
import time  # Needed for the TTL calculation example
from datetime import datetime, timezone, timedelta  # Needed for the TTL calculation example
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional, Sequence, Tuple, Dict, Any, List, Union, TypedDict

import aiobotocore.session
from botocore.exceptions import ClientError
from boto3.dynamodb.types import Binary, TypeDeserializer, TypeSerializer

from langgraph.serde.base import BaseSerializer
# CheckpointMetadata is intentionally not imported here: a local TypedDict of
# the same name is defined below and would shadow it.
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointTuple

# --- Constants for DynamoDB table structure (matching the original library defaults) ---
DEFAULT_TABLE_NAME = "langgraph_checkpoints"
DEFAULT_PK = "thread_id"
# --- NEW: Default name for the sort key attribute ---
DEFAULT_SK_NAME = "sk"
DEFAULT_SK_CHECKPOINT_VALUE = "checkpoint"  # Default value for the main checkpoint sort key
DEFAULT_SK_METADATA_PREFIX = "metadata|"  # Default prefix for metadata sort key values
DEFAULT_TTL_KEY = "ttl_timestamp"  # Optional TTL attribute
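
# For illustration (an assumption about the resulting layout, following from the
# defaults above and the write paths below): each thread stores one main item
# plus one metadata item per checkpoint, e.g.
#   {"thread_id": "t-1", "sk": "checkpoint",        "checkpoint": <serialized bytes>, "ts": "2025-..."}
#   {"thread_id": "t-1", "sk": "metadata|2025-...", "metadata": <serialized bytes>}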


class PickleSerializer(BaseSerializer):
    """Serializer that uses pickle."""

    def dumps(self, obj: Any) -> bytes:
        return pickle.dumps(obj)

    def loads(self, data: bytes) -> Any:
        return pickle.loads(data)
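

# A minimal alternative serializer sketch, assuming checkpoint payloads are
# JSON-serializable (pickle handles arbitrary objects, but unpickling data from
# an untrusted store is unsafe). This class is illustrative and not used below.
import json


class JsonSerializer(BaseSerializer):
    """Serializer that uses JSON instead of pickle (hypothetical alternative)."""

    def dumps(self, obj: Any) -> bytes:
        return json.dumps(obj).encode("utf-8")

    def loads(self, data: bytes) -> Any:
        return json.loads(data.decode("utf-8"))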


# --- Define CheckpointMetadata structure for clarity ---
# Based on typical LangGraph usage
class CheckpointMetadata(TypedDict, total=False):
    source: str  # The source of the write (e.g., "input", "loop", "update")
    step: int  # The step number
    writes: Dict[str, Any]  # The writes made by the step
    score: Optional[int]  # Optional score
    config: Dict[str, Any]  # The config associated with the metadata entry


class AsyncDynamoDBSaver(BaseCheckpointSaver):
    """
    An asynchronous checkpoint saver that stores checkpoints in DynamoDB,
    allowing configuration of the sort key attribute name.

    Args:
        table_name (str): The name of the DynamoDB table. Defaults to "langgraph_checkpoints".
        primary_key (str): The name of the partition key attribute. Defaults to "thread_id".
        sort_key_name (str): The name of the sort key attribute. Defaults to "sk".
        sort_key_checkpoint_value (str): The sort key value for the main checkpoint data. Defaults to "checkpoint".
        sort_key_metadata_prefix (str): The prefix for sort key values storing metadata. Defaults to "metadata|".
        ttl_key (Optional[str]): The name of the attribute to use for TTL. If None, TTL is not used. Defaults to "ttl_timestamp".
        ttl_duration (Optional[timedelta]): Duration for TTL. If set, items expire after this duration from write time. Defaults to None.
        serializer (Optional[BaseSerializer]): The serializer to use for checkpoint data. Defaults to PickleSerializer.
        aws_region (Optional[str]): AWS region name. If None, uses the default from the environment/config.
        aws_access_key_id (Optional[str]): AWS access key ID. If None, uses the default from the environment/config.
        aws_secret_access_key (Optional[str]): AWS secret access key. If None, uses the default from the environment/config.
        endpoint_url (Optional[str]): Custom DynamoDB endpoint URL (e.g., for DynamoDB Local).
    """

    def __init__(
        self,
        *,
        table_name: str = DEFAULT_TABLE_NAME,
        primary_key: str = DEFAULT_PK,
        # --- NEW: Parameter for the sort key name ---
        sort_key_name: str = DEFAULT_SK_NAME,
        sort_key_checkpoint_value: str = DEFAULT_SK_CHECKPOINT_VALUE,
        sort_key_metadata_prefix: str = DEFAULT_SK_METADATA_PREFIX,
        ttl_key: Optional[str] = DEFAULT_TTL_KEY,
        ttl_duration: Optional[timedelta] = None,  # Added TTL duration
        serializer: Optional[BaseSerializer] = None,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        endpoint_url: Optional[str] = None,
    ):
        # The base class stores the serializer as `serde` (referenced below),
        # so pass it under that keyword.
        super().__init__(serde=serializer or PickleSerializer())
        self.table_name = table_name
        self.primary_key = primary_key
        # --- Store the sort key name ---
        self.sort_key_name = sort_key_name
        self.sort_key_checkpoint_value = sort_key_checkpoint_value
        self.sort_key_metadata_prefix = sort_key_metadata_prefix
        self.ttl_key = ttl_key
        self.ttl_duration = ttl_duration  # Store TTL duration
        self.serializer = self.serde
        # --- aiobotocore session and client setup ---
        self.session = aiobotocore.session.get_session()
        self.client_config = {
            "service_name": "dynamodb",
            "region_name": aws_region,
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "endpoint_url": endpoint_url,
        }
        # --- DynamoDB type serializer/deserializer ---
        self._type_serializer = TypeSerializer()
        self._type_deserializer = TypeDeserializer()

    @asynccontextmanager
    async def _get_client(self) -> AsyncGenerator[Any, None]:
        """Provides an asynchronous DynamoDB client context."""
        async with self.session.create_client(**self.client_config) as client:
            yield client

    def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Dict]:
        """Serializes a Python dict into DynamoDB attribute-value format."""
        # Filter out None values before serialization, as DynamoDB doesn't accept them
        # unless specifically handled (e.g., Null=True, which isn't the default for TypeSerializer)
        filtered_item = {k: v for k, v in item.items() if v is not None}
        return {k: self._type_serializer.serialize(v) for k, v in filtered_item.items()}

    def _deserialize_item(self, item: Dict[str, Dict]) -> Dict[str, Any]:
        """Deserializes a DynamoDB item into a Python dict."""
        return {k: self._type_deserializer.deserialize(v) for k, v in item.items()}
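
    # For illustration: _serialize_item({"thread_id": "t-1", "step": 1}) yields
    # {"thread_id": {"S": "t-1"}, "step": {"N": "1"}}, the low-level
    # attribute-value format the raw DynamoDB client expects. Note that the
    # deserializer wraps binary ("B") attributes in boto3's Binary type, which
    # the read paths below unwrap before unpickling.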

    def _calculate_ttl(self) -> Optional[int]:
        """Calculates the TTL timestamp if ttl_key and ttl_duration are set."""
        if self.ttl_key and self.ttl_duration:
            expiry_time = datetime.now(timezone.utc) + self.ttl_duration
            return int(expiry_time.timestamp())
        return None
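
    # DynamoDB TTL expects the expiry as epoch seconds in a Number attribute.
    # For example, with ttl_duration=timedelta(days=30), an item written now gets
    # ttl_timestamp = int((utc_now + 30 days).timestamp()); DynamoDB then deletes
    # the item (eventually) after that time, provided TTL is enabled on the table.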

    async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """
        Asynchronously retrieves a checkpoint tuple (checkpoint, metadata, parent_config)
        for a given thread configuration.

        Args:
            config: The configuration identifying the thread (must contain 'thread_id').

        Returns:
            An optional CheckpointTuple if found, otherwise None.
        """
        thread_id = config["thread_id"]
        async with self._get_client() as client:
            try:
                # Query for both checkpoint and metadata items for the thread_id
                response = await client.query(
                    TableName=self.table_name,
                    KeyConditionExpression=f"{self.primary_key} = :pk",
                    ExpressionAttributeValues={
                        ":pk": self._type_serializer.serialize(thread_id),
                    },
                    ConsistentRead=True,  # Ensure we get the latest data
                )
            except ClientError as e:
                print(f"Error querying DynamoDB for thread {thread_id}: {e}")
                return None

            items = [self._deserialize_item(item) for item in response.get("Items", [])]
            checkpoint_item = None
            metadata_items = []
            # Separate checkpoint and metadata items using the configured sort_key_name
            for item in items:
                # --- Use self.sort_key_name ---
                sort_key_value = item.get(self.sort_key_name)
                if sort_key_value == self.sort_key_checkpoint_value:
                    checkpoint_item = item
                elif isinstance(sort_key_value, str) and sort_key_value.startswith(self.sort_key_metadata_prefix):
                    metadata_items.append(item)

            if not checkpoint_item:
                print(f"No checkpoint item found for thread {thread_id} with SK value {self.sort_key_checkpoint_value}")
                return None  # No checkpoint found for this thread_id
            # Deserialize the checkpoint data
            try:
                checkpoint_data = checkpoint_item.get("checkpoint")
                if isinstance(checkpoint_data, Binary):
                    # boto3's TypeDeserializer wraps DynamoDB 'B' attributes in
                    # Binary; unwrap to plain bytes before unpickling.
                    checkpoint_data = checkpoint_data.value
                if isinstance(checkpoint_data, (bytes, bytearray)):
                    checkpoint = self.serializer.loads(checkpoint_data)
                elif checkpoint_data is None:
                    # Create a default empty checkpoint if data is None
                    checkpoint = Checkpoint(v=1, ts="", id="", channel_values={}, channel_versions={}, seen_recovery_events=set(), config={})
                    print(f"Warning: Checkpoint data is None for thread {thread_id}. Using a default empty checkpoint.")
                else:
                    print(f"Warning: Unexpected checkpoint data type: {type(checkpoint_data)} for thread {thread_id}")
                    return None
            except (pickle.UnpicklingError, TypeError, EOFError, AttributeError, ModuleNotFoundError) as e:
                # Added more specific exceptions
                print(f"Error deserializing checkpoint data for thread {thread_id}: {e}")
                return None  # Corrupted data
            # Construct the metadata dictionary
            metadata: Dict[str, CheckpointMetadata] = {}
            for meta_item in metadata_items:
                # --- Use self.sort_key_name ---
                sort_key_value = meta_item.get(self.sort_key_name)
                if not isinstance(sort_key_value, str):
                    continue  # Skip if the sort key is missing or the wrong type
                try:
                    # Extract ts from the sort key value
                    ts = sort_key_value.split(self.sort_key_metadata_prefix, 1)[1]
                    meta_data_bytes = meta_item.get("metadata")
                    if isinstance(meta_data_bytes, Binary):
                        meta_data_bytes = meta_data_bytes.value  # Unwrap boto3's Binary
                    if isinstance(meta_data_bytes, (bytes, bytearray)):
                        meta_dict = self.serializer.loads(meta_data_bytes)
                        # Ensure it's a dictionary; default to empty if deserialization fails or the type is wrong
                        metadata[ts] = meta_dict if isinstance(meta_dict, dict) else {}
                    else:
                        # If metadata is not stored as bytes or is missing, use an empty dict
                        metadata[ts] = {}
                except (IndexError, pickle.UnpicklingError, TypeError, EOFError, AttributeError, ModuleNotFoundError) as e:
                    print(f"Warning: Could not process metadata item {sort_key_value} for thread {thread_id}: {e}")
                    # Assign an empty dict for this potentially corrupted/malformed metadata entry
                    ts_key = sort_key_value.split(self.sort_key_metadata_prefix, 1)[-1]
                    if ts_key != sort_key_value:  # Check whether the split actually happened
                        metadata[ts_key] = {}
            # Find the parent config
            parent_config = None
            # Ensure the checkpoint is valid and has a timestamp before proceeding
            if checkpoint and hasattr(checkpoint, 'ts') and checkpoint.ts:
                sorted_metadata_ts = sorted(metadata.keys(), reverse=True)
                try:
                    # Find the index of the current checkpoint's timestamp in the sorted list
                    current_checkpoint_idx = sorted_metadata_ts.index(checkpoint.ts)
                    # The parent is the next one in the sorted list (chronologically older)
                    if current_checkpoint_idx + 1 < len(sorted_metadata_ts):
                        parent_ts = sorted_metadata_ts[current_checkpoint_idx + 1]
                        # Get the config from the parent's metadata entry
                        parent_config = metadata.get(parent_ts, {}).get("config")
                except ValueError:
                    # checkpoint.ts was not found in the metadata keys, which could
                    # indicate an issue with saving or data consistency.
                    print(f"Warning: Checkpoint timestamp {checkpoint.ts} not found in metadata keys for thread {thread_id}.")
                    # As a fallback, consider the config from the chronologically previous metadata entry, if any
                    if sorted_metadata_ts:
                        # Find the first timestamp older than the current checkpoint's ts
                        older_ts = [ts for ts in sorted_metadata_ts if ts < checkpoint.ts]
                        if older_ts:
                            parent_ts = older_ts[0]  # The most recent older timestamp
                            parent_config = metadata.get(parent_ts, {}).get("config")
                        elif sorted_metadata_ts[0] != checkpoint.ts:
                            # Fallback: if the latest metadata isn't the current one, it may be the parent.
                            # This logic might be flawed depending on the exact saving order.
                            parent_config = metadata.get(sorted_metadata_ts[0], {}).get("config")

            # Return the config from the input args rather than checkpoint.config,
            # which might be stale if not updated correctly during 'put'.
            return CheckpointTuple(config=config, checkpoint=checkpoint, metadata=metadata, parent_config=parent_config)

    async def alist(
        self,
        config: Optional[Dict[str, Any]],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        """
        Asynchronously lists checkpoints for a given thread configuration,
        optionally filtered and limited.

        Args:
            config: The configuration identifying the thread (must contain 'thread_id').
                If None, attempts to scan the table (use with caution).
            filter: Optional dictionary for filtering based on metadata attributes
                (applied post-retrieval, potentially inefficient).
            before: Optional config to list checkpoints strictly before the one specified
                by this config's timestamp.
            limit: The maximum number of checkpoints to return.

        Yields:
            CheckpointTuple objects containing config and metadata (checkpoint is None).
        """
        if config is None:
            # Delegate to the scan helper if no specific thread config is provided
            print("Warning: Listing checkpoints without a thread_id will perform a DynamoDB Scan, which can be inefficient and costly on large tables.")
            async for item in self._scan_all_checkpoints(filter=filter, limit=limit):
                yield item
            return

        thread_id = config.get("thread_id")
        if not thread_id:
            raise ValueError("Configuration must include 'thread_id' to list checkpoints for a specific thread.")
        async with self._get_client() as client:
            # --- Use self.sort_key_name in the KeyConditionExpression ---
            key_condition = f"{self.primary_key} = :pk AND begins_with({self.sort_key_name}, :sk_prefix)"
            expression_values = {
                ":pk": self._type_serializer.serialize(thread_id),
                ":sk_prefix": self._type_serializer.serialize(self.sort_key_metadata_prefix),
            }

            # --- Filtering logic ---
            if before and before.get("thread_id") == thread_id:
                before_tuple = await self.aget_tuple(before)
                if before_tuple and before_tuple.checkpoint and hasattr(before_tuple.checkpoint, 'ts') and before_tuple.checkpoint.ts:
                    before_ts = before_tuple.checkpoint.ts
                    # --- Use self.sort_key_name ---
                    key_condition = f"{self.primary_key} = :pk AND {self.sort_key_name} < :sk_before"
                    # Use the full metadata sort key value for comparison
                    expression_values[":sk_before"] = self._type_serializer.serialize(f"{self.sort_key_metadata_prefix}{before_ts}")
                    # Remove the sk_prefix: DynamoDB allows only one condition per key,
                    # so begins_with cannot be combined with '<'. Note that '<' can also
                    # match the main checkpoint item (e.g. "checkpoint" < "metadata|...");
                    # such items are skipped below when the metadata prefix split fails.
                    expression_values.pop(":sk_prefix", None)
                else:
                    print(f"Warning: 'before' config {before} did not resolve to a valid checkpoint timestamp. Ignoring the 'before' filter.")

            query_kwargs = {
                "TableName": self.table_name,
                "KeyConditionExpression": key_condition,
                "ExpressionAttributeValues": expression_values,
                "ScanIndexForward": False,  # List newest first
                "ConsistentRead": True,
            }
            # --- Pagination and limit ---
            paginator = client.get_paginator("query")
            pages = paginator.paginate(**query_kwargs)
            count = 0
            async for page in pages:
                items = [self._deserialize_item(item) for item in page.get("Items", [])]
                # Already sorted by DynamoDB (ScanIndexForward=False)
                for meta_item in items:
                    if limit is not None and count >= limit:
                        return
                    # --- Use self.sort_key_name ---
                    sort_key_value = meta_item.get(self.sort_key_name)
                    if not isinstance(sort_key_value, str):
                        continue  # Skip malformed items
                    try:
                        ts = sort_key_value.split(self.sort_key_metadata_prefix, 1)[1]
                        meta_data_bytes = meta_item.get("metadata")
                        if isinstance(meta_data_bytes, Binary):
                            meta_data_bytes = meta_data_bytes.value  # Unwrap boto3's Binary
                        if isinstance(meta_data_bytes, (bytes, bytearray)):
                            metadata_entry: CheckpointMetadata = self.serializer.loads(meta_data_bytes)
                            if not isinstance(metadata_entry, dict):
                                metadata_entry = {}
                        else:
                            metadata_entry = {}

                        # --- Apply the post-retrieval filter ---
                        if filter:
                            match = True
                            for key, value in filter.items():
                                if metadata_entry.get(key) != value:
                                    match = False
                                    break
                            if not match:
                                continue

                        # Construct a config representing this specific point in time.
                        # Use the config stored within the metadata if available, otherwise fall back.
                        list_item_config = metadata_entry.get("config", config)
                        # Add the timestamp to the config for clarity
                        list_item_config = {**list_item_config, "checkpoint_ts": ts}
                        # Yield config/metadata only, as per the BaseCheckpointSaver.alist expectation
                        yield CheckpointTuple(config=list_item_config, checkpoint=None, metadata=metadata_entry, parent_config=None)
                        count += 1
                    except (IndexError, pickle.UnpicklingError, TypeError, EOFError, AttributeError, ModuleNotFoundError) as e:
                        print(f"Warning: Could not process metadata item {sort_key_value} during list for thread {thread_id}: {e}")
                        continue
                if limit is not None and count >= limit:
                    break
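
    # Example usage sketch (hypothetical thread id and filter values): list the
    # newest 10 checkpoints whose metadata recorded source == "loop":
    #   async for tup in saver.alist({"thread_id": "t-1"},
    #                                filter={"source": "loop"}, limit=10):
    #       print(tup.config["checkpoint_ts"], tup.metadata.get("step"))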

    async def _scan_all_checkpoints(
        self,
        *,
        filter: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        """Helper to scan the entire table. Yields the latest CheckpointTuple per thread_id found. Use with caution."""
        async with self._get_client() as client:
            paginator = client.get_paginator("scan")
            scan_kwargs = {"TableName": self.table_name}
            count = 0
            seen_thread_ids = set()
            # Store items per thread_id to process after scanning each page
            page_thread_data: Dict[str, Dict[str, Any]] = {}
            async for page in paginator.paginate(**scan_kwargs):
                items = [self._deserialize_item(item) for item in page.get("Items", [])]
                page_thread_data.clear()
                # Group items by thread_id found in the current page
                for item in items:
                    thread_id = item.get(self.primary_key)
                    # --- Use self.sort_key_name ---
                    sort_key = item.get(self.sort_key_name)
                    if not thread_id or not sort_key:
                        continue
                    if thread_id not in page_thread_data:
                        page_thread_data[thread_id] = {"checkpoint_item": None, "metadata_items": []}
                    if sort_key == self.sort_key_checkpoint_value:
                        page_thread_data[thread_id]["checkpoint_item"] = item
                    elif isinstance(sort_key, str) and sort_key.startswith(self.sort_key_metadata_prefix):
                        page_thread_data[thread_id]["metadata_items"].append(item)
                # Process each thread found in the page
                for thread_id, data in page_thread_data.items():
                    if thread_id in seen_thread_ids:
                        continue
                    if limit is not None and count >= limit:
                        return
                    checkpoint_item = data["checkpoint_item"]
                    metadata_items = data["metadata_items"]
                    if not checkpoint_item:
                        # If only metadata was found for a thread_id in this page, skip for now;
                        # the main checkpoint item is needed to construct the full tuple.
                        continue
                    # Deserialize the checkpoint
                    try:
                        checkpoint_data = checkpoint_item.get("checkpoint")
                        if isinstance(checkpoint_data, Binary):
                            checkpoint_data = checkpoint_data.value  # Unwrap boto3's Binary
                        if isinstance(checkpoint_data, (bytes, bytearray)):
                            checkpoint = self.serializer.loads(checkpoint_data)
                        elif checkpoint_data is None:
                            checkpoint = Checkpoint(v=1, ts="", id="", channel_values={}, channel_versions={}, seen_recovery_events=set(), config={})
                        else:
                            print(f"Warning: Invalid checkpoint data type during scan for thread {thread_id}")
                            continue
                    except Exception as e:
                        print(f"Error deserializing checkpoint data during scan for thread {thread_id}: {e}")
                        continue
                    # Construct the metadata dictionary for this thread
                    metadata: Dict[str, CheckpointMetadata] = {}
                    for meta_item in metadata_items:
                        # --- Use self.sort_key_name ---
                        sort_key_value = meta_item.get(self.sort_key_name)
                        if not isinstance(sort_key_value, str):
                            continue
                        try:
                            ts = sort_key_value.split(self.sort_key_metadata_prefix, 1)[1]
                            meta_data_bytes = meta_item.get("metadata")
                            if isinstance(meta_data_bytes, Binary):
                                meta_data_bytes = meta_data_bytes.value  # Unwrap boto3's Binary
                            if isinstance(meta_data_bytes, (bytes, bytearray)):
                                meta_dict = self.serializer.loads(meta_data_bytes)
                                metadata[ts] = meta_dict if isinstance(meta_dict, dict) else {}
                            else:
                                metadata[ts] = {}
                        except Exception as e:
                            print(f"Warning: Could not process metadata item {sort_key_value} during scan for thread {thread_id}: {e}")
                    # --- Apply the post-retrieval filter (on the latest checkpoint's metadata) ---
                    latest_ts = checkpoint.ts if hasattr(checkpoint, 'ts') else None
                    latest_metadata_entry = metadata.get(latest_ts, {}) if latest_ts else {}
                    if filter:
                        match = True
                        for key, value in filter.items():
                            if latest_metadata_entry.get(key) != value:
                                match = False
                                break
                        if not match:
                            continue  # Skip the thread if the filter doesn't match
                    # Determine the parent config (same logic as in aget_tuple)
                    parent_config = None
                    if latest_ts:
                        sorted_metadata_ts = sorted(metadata.keys(), reverse=True)
                        try:
                            current_checkpoint_idx = sorted_metadata_ts.index(latest_ts)
                            if current_checkpoint_idx + 1 < len(sorted_metadata_ts):
                                parent_ts = sorted_metadata_ts[current_checkpoint_idx + 1]
                                parent_config = metadata.get(parent_ts, {}).get("config")
                        except ValueError:
                            pass  # ts not found in metadata
                    # Use the config stored within the latest checkpoint
                    config = checkpoint.config if hasattr(checkpoint, 'config') else {}
                    yield CheckpointTuple(
                        config=config,
                        checkpoint=checkpoint,
                        metadata=metadata,
                        parent_config=parent_config,
                    )
                    seen_thread_ids.add(thread_id)
                    count += 1
                    if limit is not None and count >= limit:
                        return

    async def aput(
        self, config: Dict[str, Any], checkpoint: Checkpoint, metadata: CheckpointMetadata
    ) -> Dict[str, Any]:
        """
        Asynchronously saves a checkpoint and its metadata to DynamoDB.

        Args:
            config: The configuration identifying the thread (must contain 'thread_id').
            checkpoint: The checkpoint object to save.
            metadata: The metadata associated with the checkpoint.

        Returns:
            The original config dictionary used for putting the checkpoint.
        """
        thread_id = config["thread_id"]
        # Ensure the checkpoint has a timestamp; generate one if missing (though LangGraph usually provides it)
        if not hasattr(checkpoint, 'ts') or not checkpoint.ts:
            print(f"Warning: Checkpoint for thread {thread_id} is missing a timestamp. Generating one.")
            checkpoint.ts = datetime.now(timezone.utc).isoformat()

        serialized_checkpoint = self.serializer.dumps(checkpoint)
        # Ensure metadata includes the config for retrieval consistency
        metadata_to_save = {**metadata, "config": config}
        serialized_metadata = self.serializer.dumps(metadata_to_save)
        # Calculate TTL if configured
        ttl_timestamp = self._calculate_ttl()
        async with self._get_client() as client:
            # Prepare the main checkpoint item
            checkpoint_item = {
                self.primary_key: thread_id,
                # --- Use self.sort_key_name and the configured value ---
                self.sort_key_name: self.sort_key_checkpoint_value,
                "checkpoint": serialized_checkpoint,
                "ts": checkpoint.ts,  # Store the latest timestamp for reference
            }
            if self.ttl_key and ttl_timestamp is not None:
                checkpoint_item[self.ttl_key] = ttl_timestamp

            # Prepare the metadata item
            metadata_item = {
                self.primary_key: thread_id,
                # --- Use self.sort_key_name and construct the value ---
                self.sort_key_name: f"{self.sort_key_metadata_prefix}{checkpoint.ts}",
                "metadata": serialized_metadata,
            }
            if self.ttl_key and ttl_timestamp is not None:
                metadata_item[self.ttl_key] = ttl_timestamp

            try:
                # Use BatchWriteItem for efficiency (note: it is not atomic;
                # use TransactWriteItems if atomicity across both items is required)
                response = await client.batch_write_item(
                    RequestItems={
                        self.table_name: [
                            {'PutRequest': {'Item': self._serialize_item(checkpoint_item)}},
                            {'PutRequest': {'Item': self._serialize_item(metadata_item)}},
                        ]
                    }
                )
                # Check for unprocessed items (recommended for production)
                unprocessed = response.get('UnprocessedItems', {}).get(self.table_name)
                if unprocessed:
                    print(f"Warning: Failed to process some items for thread {thread_id}: {unprocessed}")
                    # Implement retry logic if needed (see the sketch below)
            except ClientError as e:
                print(f"Error putting checkpoint/metadata for thread {thread_id} using BatchWriteItem: {e}")
                # Consider falling back to individual PutItem calls or re-raising
                raise
            except Exception as e:  # Catch broader exceptions during item preparation
                print(f"Unexpected error preparing items for thread {thread_id}: {e}")
                raise

        return config  # Return the config used for the put operation
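

# A minimal retry sketch for BatchWriteItem's UnprocessedItems, assuming simple
# exponential backoff is acceptable. This helper is illustrative and is not
# wired into AsyncDynamoDBSaver.aput above.
async def retry_unprocessed_items(
    client, table_name: str, unprocessed: List[Dict[str, Any]], attempts: int = 3
) -> None:
    for attempt in range(attempts):
        await asyncio.sleep(0.1 * (2 ** attempt))  # Back off before each retry
        response = await client.batch_write_item(RequestItems={table_name: unprocessed})
        unprocessed = response.get("UnprocessedItems", {}).get(table_name, [])
        if not unprocessed:
            return
    print(f"Warning: {len(unprocessed)} item(s) still unprocessed after {attempts} retries")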


# Example usage (requires an async environment)
async def main():
    # --- CONFIGURE FOR YOUR TABLE ---
    your_table_name = "brie-ml-chatbot-llm-checkpoints"  # From your Terraform
    your_primary_key = "thread_id"  # From your Terraform
    your_sort_key_name = "checkpoint_id"  # From your Terraform

    # --- Choose sort key values ---
    # Decide what values you want to store in the 'checkpoint_id' (sort key) column.
    # Option 1: Keep the original library's convention (recommended for fewer code changes)
    your_sk_checkpoint_value = "checkpoint"
    your_sk_metadata_prefix = "metadata|"
    # Option 2: Use checkpoint.id (if it's guaranteed unique per thread & timestamp).
    # This would require more significant changes to how 'aget_tuple' and 'alist' find items.
    # This example sticks with Option 1.

    print("Configuring saver for:")
    print(f"  Table: {your_table_name}")
    print(f"  PK: {your_primary_key}")
    print(f"  SK Name: {your_sort_key_name}")
    print(f"  SK Checkpoint Value: {your_sk_checkpoint_value}")
    print(f"  SK Metadata Prefix: {your_sk_metadata_prefix}")

    # Instantiate the saver with your specific configuration
    saver = AsyncDynamoDBSaver(
        table_name=your_table_name,
        primary_key=your_primary_key,
        sort_key_name=your_sort_key_name,  # Specify your sort key name
        sort_key_checkpoint_value=your_sk_checkpoint_value,  # Value for the main checkpoint item
        sort_key_metadata_prefix=your_sk_metadata_prefix,  # Prefix for metadata items
        # endpoint_url="http://localhost:8000",  # Uncomment for DynamoDB Local
        # ttl_duration=timedelta(days=30),  # Optional: set TTL
    )
    # Example config and checkpoint data
    thread_config = {"thread_id": "brie-thread-async-02"}
    # Generate a unique ID for the checkpoint state itself
    checkpoint_state_id = f"cp-{int(time.time())}"  # Example ID
    checkpoint_ts = datetime.now(timezone.utc).isoformat()
    initial_checkpoint = Checkpoint(
        v=1,
        ts=checkpoint_ts,
        id=checkpoint_state_id,  # Unique ID for this state
        channel_values={"messages": ["hello from async config test"]},
        channel_versions={"messages": 1},
        seen_recovery_events=set(),
        config=thread_config,  # Config that generated this state
    )
    initial_metadata: CheckpointMetadata = {  # Use the TypedDict for clarity
        "source": "user_input",
        "step": 1,
        "writes": {"chatbot": {"messages": ["hello from async config test"]}},
        # config is added automatically in aput
    }
    # --- Put checkpoint ---
    print(f"\nPutting checkpoint for: {thread_config['thread_id']}")
    # Pass the config that identifies the thread
    await saver.aput(thread_config, initial_checkpoint, initial_metadata)
    print("Put successful.")

    # --- Get checkpoint ---
    print(f"\nGetting checkpoint tuple for: {thread_config['thread_id']}")
    retrieved_tuple = await saver.aget_tuple(thread_config)
    if retrieved_tuple:
        print("Retrieved Checkpoint:", retrieved_tuple.checkpoint)
        print("Retrieved Metadata:", retrieved_tuple.metadata)
        print("Retrieved Parent Config:", retrieved_tuple.parent_config)
        # Verify the sort key name was used correctly (inspect DynamoDB manually if needed)
    else:
        print("Checkpoint not found.")
    # --- Add another checkpoint ---
    await asyncio.sleep(1)  # Ensure the timestamp changes
    next_checkpoint_state_id = f"cp-{int(time.time())}"
    next_checkpoint_ts = datetime.now(timezone.utc).isoformat()
    next_config = {"thread_id": "brie-thread-async-02"}  # Could be the same or an updated config
    next_checkpoint = Checkpoint(
        v=1,
        ts=next_checkpoint_ts,
        id=next_checkpoint_state_id,
        channel_values={"messages": ["hello from async config test", "how are you?"]},
        channel_versions={"messages": 2},
        seen_recovery_events=set(),
        config=next_config,  # Config that generated this state
    )
    next_metadata: CheckpointMetadata = {
        "source": "llm_response",
        "step": 2,
        "writes": {"chatbot": {"messages": ["how are you?"]}},
    }
    print(f"\nPutting next checkpoint for: {thread_config['thread_id']}")
    await saver.aput(next_config, next_checkpoint, next_metadata)  # Use the relevant config
    print("Put successful.")

    # --- Get the updated checkpoint ---
    print(f"\nGetting updated checkpoint tuple for: {thread_config['thread_id']}")
    retrieved_tuple_updated = await saver.aget_tuple(thread_config)  # Use the base thread config to get the latest
    if retrieved_tuple_updated:
        print("Updated Checkpoint:", retrieved_tuple_updated.checkpoint)
        print("Updated Metadata:", retrieved_tuple_updated.metadata)  # Should have two entries now
        print("Updated Parent Config:", retrieved_tuple_updated.parent_config)  # Should be the initial checkpoint's config
        assert retrieved_tuple_updated.checkpoint.ts == next_checkpoint_ts
        assert initial_checkpoint.ts in retrieved_tuple_updated.metadata
    else:
        print("Updated Checkpoint not found.")
    # --- List checkpoints ---
    print(f"\nListing checkpoints for: {thread_config['thread_id']}")
    async for checkpoint_tuple in saver.alist(thread_config, limit=5):
        # Note: the tuple yielded by alist in this implementation contains config/metadata only
        print(f"  - Metadata TS: {checkpoint_tuple.config.get('checkpoint_ts')}, Source: {checkpoint_tuple.metadata.get('source')}")

    # --- List checkpoints before a specific one ---
    # Config representing the state *at* the second checkpoint
    config_at_second_checkpoint = retrieved_tuple_updated.checkpoint.config if retrieved_tuple_updated else None
    if config_at_second_checkpoint:
        print(f"\nListing checkpoints before {next_checkpoint_ts}:")
        # Use the config associated with the second checkpoint to identify the 'before' point
        async for checkpoint_tuple in saver.alist(thread_config, before=config_at_second_checkpoint):
            print(f"  - Metadata TS: {checkpoint_tuple.config.get('checkpoint_ts')}, Source: {checkpoint_tuple.metadata.get('source')}")
        # Should only list the first checkpoint's metadata
    else:
        print("\nSkipping the 'before' test as the second checkpoint wasn't retrieved.")


if __name__ == "__main__":
    # Ensure you have a running DynamoDB instance (local or AWS)
    # and the necessary credentials/region configured.
    # You may need to create the table using the AWS CLI or Terraform matching your definition:
    # aws dynamodb create-table \
    #   --table-name brie-ml-chatbot-llm-checkpoints \
    #   --attribute-definitions AttributeName=thread_id,AttributeType=S AttributeName=checkpoint_id,AttributeType=S \
    #   --key-schema AttributeName=thread_id,KeyType=HASH AttributeName=checkpoint_id,KeyType=RANGE \
    #   --billing-mode PAY_PER_REQUEST
    #   # Add --endpoint-url http://localhost:8000 for DynamoDB Local
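    # If you set ttl_duration, also enable TTL on the table (the attribute name
    # must match the saver's ttl_key, "ttl_timestamp" by default):
    # aws dynamodb update-time-to-live \
    #   --table-name brie-ml-chatbot-llm-checkpoints \
    #   --time-to-live-specification "Enabled=true, AttributeName=ttl_timestamp"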
    try:
        asyncio.run(main())
    except ImportError as e:
        print(f"Please install the necessary libraries: pip install aiobotocore langgraph boto3. Error: {e}")
    except ClientError as e:
        error_code = e.response.get('Error', {}).get('Code')
        if error_code == 'ResourceNotFoundException':
            # Note: the configured table name is local to main(), so it isn't referenced here.
            print("\nError: DynamoDB table not found.")
            print("Please ensure the table exists and matches the expected schema (PK=thread_id, SK=checkpoint_id).")
        elif 'Credentials' in str(e) or error_code == 'UnrecognizedClientException':
            print(f"\nError: AWS credentials not found, invalid, or region misconfigured: {e}")
            print("Ensure your AWS credentials (access key, secret key, region) are configured correctly (e.g., via environment variables, ~/.aws/credentials, or an IAM role).")
        elif error_code == 'ValidationException':
            print(f"\nError: DynamoDB request validation failed: {e}")
            print("This might indicate an issue with the data being sent (e.g., invalid types, missing keys) or a table schema mismatch.")
        else:
            print(f"\nAn AWS ClientError occurred: {e}")
    except Exception as e:
        import traceback
        print(f"\nAn unexpected error occurred: {e}")
        print(traceback.format_exc())


# ****************************************************
### Example implementation with LangGraph
import asyncio
from typing import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.base import Checkpoint

# Assume the AsyncDynamoDBSaver class is defined as in the document
# from your_module import AsyncDynamoDBSaver

# 1. Define your graph state
class AgentState(TypedDict):
    input: str
    output: str


# 2. Define your graph nodes (ensure they are async if doing async work)
async def node_a(state: AgentState):
    print("---Executing Node A---")
    # Simulate async work
    await asyncio.sleep(0.1)
    return {"output": f"Output from A based on '{state['input']}'"}


async def node_b(state: AgentState):
    print("---Executing Node B---")
    await asyncio.sleep(0.1)
    return {"output": state['output'] + " | Output from B"}

# 3. Instantiate the async checkpointer.
# Configure with your table name, region, etc.
# Use endpoint_url for DynamoDB Local if needed.
checkpointer = AsyncDynamoDBSaver(
    table_name="langgraph_checkpoints",
    # aws_region="your-region",  # Optional
    # endpoint_url="http://localhost:8000",  # Optional
)

# 4. Build the graph, passing the checkpointer
builder = StateGraph(AgentState)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", END)
# Pass the single async checkpointer instance here
graph = builder.compile(checkpointer=checkpointer)

# 5. Use the graph asynchronously
async def run_graph():
    thread_config = {"configurable": {"thread_id": "my-thread-async-01"}}
    print("Running graph...")
    async for event in graph.astream_events(
        {"input": "hello async world"}, config=thread_config, version="v1"
    ):
        kind = event["event"]
        if kind == "on_chain_end":
            print("---Graph Ended with Output---")
            print(event["data"]["output"])
        elif kind == "on_checkpoint":
            print("---Checkpoint Saved---")
            # print(event["data"])  # Can inspect checkpoint data

    # You can retrieve the checkpoint later using the same checkpointer instance
    print("\nRetrieving final checkpoint:")
    final_checkpoint_tuple = await checkpointer.aget_tuple(thread_config["configurable"])
    if final_checkpoint_tuple:
        print(final_checkpoint_tuple.checkpoint.channel_values)  # Print the final state


# Run the async function
# asyncio.run(run_graph())