dynamodb-langgraph-async
import asyncio
import pickle
import time # Needed for TTL calculation example
from datetime import datetime, timezone, timedelta # Needed for TTL calculation example
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional, Sequence, Tuple, Dict, Any, List, Union, TypedDict # Added TypedDict
import aiobotocore.session
from botocore.exceptions import ClientError
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from langgraph.serde.base import BaseSerializer
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointTuple  # CheckpointMetadata is defined locally below
# --- Constants for DynamoDB Table Structure (matching the original library defaults) ---
DEFAULT_TABLE_NAME = "langgraph_checkpoints"
DEFAULT_PK = "thread_id"
# --- NEW: Default name for the Sort Key attribute ---
DEFAULT_SK_NAME = "sk"
DEFAULT_SK_CHECKPOINT_VALUE = "checkpoint" # Default value for the main checkpoint sort key
DEFAULT_SK_METADATA_PREFIX = "metadata|" # Default prefix for metadata sort key values
DEFAULT_TTL_KEY = "ttl_timestamp" # Optional TTL attribute
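# Illustrative item layout under these defaults (a sketch of what aput writes,
# not actual stored data):
#   {"thread_id": "t-1", "sk": "checkpoint",      "checkpoint": <pickled bytes>, "ts": "2025-..."}
#   {"thread_id": "t-1", "sk": "metadata|2025-...", "metadata":  <pickled bytes>}
# i.e. one "latest checkpoint" row per thread plus one metadata row per step,
# all sharing the same partition key.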
class PickleSerializer(BaseSerializer):
    """Serializer that uses pickle."""

    def dumps(self, obj: Any) -> bytes:
        return pickle.dumps(obj)

    def loads(self, data: bytes) -> Any:
        return pickle.loads(data)
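
# Note: pickle executes arbitrary code on load, so it is only safe for data you
# wrote yourself. As a hedged alternative sketch (an assumption, only viable if
# everything you checkpoint is JSON-compatible), the same BaseSerializer
# interface can be backed by JSON:
import json

class JsonSerializer(BaseSerializer):
    """Serializer that uses JSON; safer than pickle but limited to JSON types."""

    def dumps(self, obj: Any) -> bytes:
        # default=str is a lossy fallback for non-JSON types (e.g. sets, datetimes)
        return json.dumps(obj, default=str).encode("utf-8")

    def loads(self, data: bytes) -> Any:
        return json.loads(data.decode("utf-8"))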
# --- Define CheckpointMetadata structure for clarity ---
# Based on typical LangGraph usage
class CheckpointMetadata(TypedDict, total=False):
    source: str  # The source of the write (e.g., "input", "loop", "update")
    step: int  # The step number
    writes: Dict[str, Any]  # The writes made by the step
    score: Optional[int]  # Optional score
    config: Dict[str, Any]  # The config associated with the metadata entry
class AsyncDynamoDBSaver(BaseCheckpointSaver):
    """
    An asynchronous checkpoint saver that stores checkpoints in DynamoDB,
    allowing configuration of the sort key attribute name.

    Args:
        table_name (str): The name of the DynamoDB table. Defaults to "langgraph_checkpoints".
        primary_key (str): The name of the partition key attribute. Defaults to "thread_id".
        sort_key_name (str): The name of the sort key attribute. Defaults to "sk".
        sort_key_checkpoint_value (str): The sort key value for the main checkpoint data. Defaults to "checkpoint".
        sort_key_metadata_prefix (str): The prefix for sort key values storing metadata. Defaults to "metadata|".
        ttl_key (Optional[str]): The name of the attribute to use for TTL. If None, TTL is not used. Defaults to "ttl_timestamp".
        ttl_duration (Optional[timedelta]): Duration for TTL. If set, items expire after this duration from write time. Defaults to None.
        serializer (Optional[BaseSerializer]): The serializer to use for checkpoint data. Defaults to PickleSerializer.
        aws_region (Optional[str]): AWS region name. If None, uses the default from environment/config.
        aws_access_key_id (Optional[str]): AWS access key ID. If None, uses the default from environment/config.
        aws_secret_access_key (Optional[str]): AWS secret access key. If None, uses the default from environment/config.
        endpoint_url (Optional[str]): Custom DynamoDB endpoint URL (e.g., for DynamoDB Local).
    """
    def __init__(
        self,
        *,
        table_name: str = DEFAULT_TABLE_NAME,
        primary_key: str = DEFAULT_PK,
        # --- NEW: Parameter for Sort Key Name ---
        sort_key_name: str = DEFAULT_SK_NAME,
        sort_key_checkpoint_value: str = DEFAULT_SK_CHECKPOINT_VALUE,
        sort_key_metadata_prefix: str = DEFAULT_SK_METADATA_PREFIX,
        ttl_key: Optional[str] = DEFAULT_TTL_KEY,
        ttl_duration: Optional[timedelta] = None,  # Added TTL duration
        serializer: Optional[BaseSerializer] = None,
        aws_region: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        endpoint_url: Optional[str] = None,
    ):
        # BaseCheckpointSaver stores the serializer as `self.serde`
        super().__init__(serde=serializer or PickleSerializer())
        self.table_name = table_name
        self.primary_key = primary_key
        # --- Store the Sort Key Name ---
        self.sort_key_name = sort_key_name
        self.sort_key_checkpoint_value = sort_key_checkpoint_value
        self.sort_key_metadata_prefix = sort_key_metadata_prefix
        self.ttl_key = ttl_key
        self.ttl_duration = ttl_duration  # Store TTL duration
        self.serializer = self.serde
        # --- aiobotocore Session and Client Setup ---
        self.session = aiobotocore.session.get_session()
        self.client_config = {
            "service_name": "dynamodb",
            "region_name": aws_region,
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "endpoint_url": endpoint_url,
        }
        # --- DynamoDB Type Serializer/Deserializer ---
        self._type_serializer = TypeSerializer()
        self._type_deserializer = TypeDeserializer()
    @asynccontextmanager
    async def _get_client(self) -> AsyncGenerator[Any, None]:
        """Provides an asynchronous DynamoDB client context."""
        async with self.session.create_client(**self.client_config) as client:
            yield client

    def _serialize_item(self, item: Dict[str, Any]) -> Dict[str, Dict]:
        """Serializes a Python dict into DynamoDB attribute value format."""
        # Filter out None values so optional attributes are omitted entirely
        # rather than stored as DynamoDB NULL values
        filtered_item = {k: v for k, v in item.items() if v is not None}
        return {k: self._type_serializer.serialize(v) for k, v in filtered_item.items()}

    def _deserialize_item(self, item: Dict[str, Dict]) -> Dict[str, Any]:
        """Deserializes a DynamoDB item into a Python dict."""
        return {k: self._type_deserializer.deserialize(v) for k, v in item.items()}

    def _calculate_ttl(self) -> Optional[int]:
        """Calculates the TTL timestamp if ttl_key and ttl_duration are set."""
        if self.ttl_key and self.ttl_duration:
            expiry_time = datetime.now(timezone.utc) + self.ttl_duration
            return int(expiry_time.timestamp())
        return None
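
    # Worked example (values assumed for illustration): with
    # ttl_duration=timedelta(hours=1), an item written at 2025-01-01T00:00:00Z
    # (epoch 1735689600) gets ttl_timestamp=1735693200; DynamoDB deletes the
    # item some time after that epoch second, once TTL is enabled on the table.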
    async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
        """
        Asynchronously retrieves a checkpoint tuple (checkpoint, metadata, parent_config)
        for a given thread configuration.

        Args:
            config: The configuration identifying the thread (must contain 'thread_id').

        Returns:
            An optional CheckpointTuple if found, otherwise None.
        """
        thread_id = config["thread_id"]
        async with self._get_client() as client:
            try:
                # Query for both checkpoint and metadata items for the thread_id
                response = await client.query(
                    TableName=self.table_name,
                    KeyConditionExpression=f"{self.primary_key} = :pk",
                    ExpressionAttributeValues={
                        ":pk": self._type_serializer.serialize(thread_id),
                    },
                    ConsistentRead=True,  # Ensure we get the latest data
                )
            except ClientError as e:
                print(f"Error querying DynamoDB for thread {thread_id}: {e}")
                return None
            items = [self._deserialize_item(item) for item in response.get("Items", [])]
            checkpoint_item = None
            metadata_items = []
            # Separate checkpoint and metadata items using the configured sort_key_name
            for item in items:
                # --- Use self.sort_key_name ---
                sort_key_value = item.get(self.sort_key_name)
                if sort_key_value == self.sort_key_checkpoint_value:
                    checkpoint_item = item
                elif isinstance(sort_key_value, str) and sort_key_value.startswith(self.sort_key_metadata_prefix):
                    metadata_items.append(item)
            if not checkpoint_item:
                print(f"No checkpoint item found for thread {thread_id} with SK value {self.sort_key_checkpoint_value}")
                return None  # No checkpoint found for this thread_id
            # Deserialize checkpoint data
            try:
                checkpoint_data = checkpoint_item.get("checkpoint")
                if isinstance(checkpoint_data, bytes):
                    checkpoint = self.serializer.loads(checkpoint_data)
                elif checkpoint_data is None:
                    # Create a default empty checkpoint if data is None
                    checkpoint = Checkpoint(v=1, ts="", id="", channel_values={}, channel_versions={}, seen_recovery_events=set(), config={})
                    print(f"Warning: Checkpoint data is None for thread {thread_id}. Using default empty checkpoint.")
                else:
                    print(f"Warning: Unexpected checkpoint data type: {type(checkpoint_data)} for thread {thread_id}")
                    return None
            except (pickle.UnpicklingError, TypeError, EOFError, AttributeError, ModuleNotFoundError) as e:
                # Catch the specific exceptions pickle can raise on bad data
                print(f"Error deserializing checkpoint data for thread {thread_id}: {e}")
                return None  # Corrupted data
            # Construct the metadata dictionary
            metadata: Dict[str, CheckpointMetadata] = {}
            for meta_item in metadata_items:
                # --- Use self.sort_key_name ---
                sort_key_value = meta_item.get(self.sort_key_name)
                if not isinstance(sort_key_value, str):
                    continue  # Skip if the sort key is missing or the wrong type
                try:
                    # Extract ts from the sort key value
                    ts = sort_key_value.split(self.sort_key_metadata_prefix, 1)[1]
                    meta_data_bytes = meta_item.get("metadata")
                    if isinstance(meta_data_bytes, bytes):
                        meta_dict = self.serializer.loads(meta_data_bytes)
                        # Ensure it's a dictionary; default to empty if deserialization yields the wrong type
                        metadata[ts] = meta_dict if isinstance(meta_dict, dict) else {}
                    else:
                        # If metadata is not stored as bytes or is missing, use an empty dict
                        metadata[ts] = {}
                except (IndexError, pickle.UnpicklingError, TypeError, EOFError, AttributeError, ModuleNotFoundError) as e:
                    print(f"Warning: Could not process metadata item {sort_key_value} for thread {thread_id}: {e}")
                    # Assign an empty dict for this potentially corrupted/malformed metadata entry
                    ts_key = sort_key_value.split(self.sort_key_metadata_prefix, 1)[-1]
                    if ts_key != sort_key_value:  # Check if the split actually happened
                        metadata[ts_key] = {}
            # Find the parent config
            parent_config = None
            # Ensure the checkpoint is valid and has a timestamp before proceeding
            if checkpoint and hasattr(checkpoint, 'ts') and checkpoint.ts:
                sorted_metadata_ts = sorted(metadata.keys(), reverse=True)
                try:
                    # Find the index of the current checkpoint's timestamp in the sorted list
                    current_checkpoint_idx = sorted_metadata_ts.index(checkpoint.ts)
                    # The parent is the next one in the sorted list (chronologically older)
                    if current_checkpoint_idx + 1 < len(sorted_metadata_ts):
                        parent_ts = sorted_metadata_ts[current_checkpoint_idx + 1]
                        # Get the config from the parent's metadata entry
                        parent_config = metadata.get(parent_ts, {}).get("config")
                except ValueError:
                    # checkpoint.ts was not found among the metadata keys, which
                    # could indicate an issue with saving or data consistency
                    print(f"Warning: Checkpoint timestamp {checkpoint.ts} not found in metadata keys for thread {thread_id}.")
                    # As a fallback, use the config from the chronologically previous metadata entry, if any
                    if sorted_metadata_ts:
                        # Find the timestamps older than the current checkpoint's ts
                        older_ts = [ts for ts in sorted_metadata_ts if ts < checkpoint.ts]
                        if older_ts:
                            parent_ts = older_ts[0]  # The most recent older timestamp
                            parent_config = metadata.get(parent_ts, {}).get("config")
                        elif sorted_metadata_ts[0] != checkpoint.ts:
                            # Fallback: if the latest metadata isn't the current one, maybe it's the parent?
                            # This logic might be flawed depending on the exact saving order.
                            parent_config = metadata.get(sorted_metadata_ts[0], {}).get("config")
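            # Worked example (hypothetical timestamps): with metadata keys sorted
            # newest-first as ["t3", "t2", "t1"] and checkpoint.ts == "t2", the
            # parent is "t1" (the next, chronologically older entry), so
            # parent_config comes from metadata["t1"]["config"].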
            # Ensure the returned config in the tuple is the one from the input args;
            # checkpoint.config might be stale if not updated correctly during 'put'
            return CheckpointTuple(config=config, checkpoint=checkpoint, metadata=metadata, parent_config=parent_config)
    async def alist(
        self,
        config: Optional[Dict[str, Any]],
        *,
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        """
        Asynchronously lists checkpoints for a given thread configuration,
        optionally filtered and limited.

        Args:
            config: The configuration identifying the thread (must contain 'thread_id').
                If None, attempts to scan the table (use with caution).
            filter: Optional dictionary for filtering based on metadata attributes
                (applied post-retrieval, potentially inefficient).
            before: Optional config to list checkpoints strictly before the one specified
                by this config's timestamp.
            limit: The maximum number of checkpoints to return.

        Yields:
            CheckpointTuple objects containing config and metadata (checkpoint is None).
        """
        if config is None:
            # Delegate to the scan helper if no specific thread config is provided
            print("Warning: Listing checkpoints without a thread_id will perform a DynamoDB Scan, which can be inefficient and costly on large tables.")
            async for item in self._scan_all_checkpoints(filter=filter, limit=limit):
                yield item
            return
        thread_id = config.get("thread_id")
        if not thread_id:
            raise ValueError("Configuration must include 'thread_id' to list checkpoints for a specific thread.")
        async with self._get_client() as client:
            # --- Use self.sort_key_name in KeyConditionExpression ---
            key_condition = f"{self.primary_key} = :pk AND begins_with({self.sort_key_name}, :sk_prefix)"
            expression_values = {
                ":pk": self._type_serializer.serialize(thread_id),
                ":sk_prefix": self._type_serializer.serialize(self.sort_key_metadata_prefix),
            }
            # --- Filtering Logic ---
            if before and before.get("thread_id") == thread_id:
                before_tuple = await self.aget_tuple(before)
                if before_tuple and before_tuple.checkpoint and hasattr(before_tuple.checkpoint, 'ts') and before_tuple.checkpoint.ts:
                    before_ts = before_tuple.checkpoint.ts
                    # --- Use self.sort_key_name ---
                    key_condition = f"{self.primary_key} = :pk AND {self.sort_key_name} < :sk_before"
                    # Use the full metadata sort key value for comparison
                    expression_values[":sk_before"] = self._type_serializer.serialize(f"{self.sort_key_metadata_prefix}{before_ts}")
                    # Remove the sk_prefix, as it's incompatible with the '<' condition on the full key
                    expression_values.pop(":sk_prefix", None)
                else:
                    print(f"Warning: 'before' config {before} did not resolve to a valid checkpoint timestamp. Ignoring 'before' filter.")
            query_kwargs = {
                "TableName": self.table_name,
                "KeyConditionExpression": key_condition,
                "ExpressionAttributeValues": expression_values,
                "ScanIndexForward": False,  # List newest first
                "ConsistentRead": True,
            }
            # --- Pagination and Limit ---
            paginator = client.get_paginator("query")
            pages = paginator.paginate(**query_kwargs)
            count = 0
            async for page in pages:
                items = [self._deserialize_item(item) for item in page.get("Items", [])]
                # Already sorted by DynamoDB (ScanIndexForward=False)
                for meta_item in items:
                    if limit is not None and count >= limit:
                        return
                    # --- Use self.sort_key_name ---
                    sort_key_value = meta_item.get(self.sort_key_name)
                    if not isinstance(sort_key_value, str):
                        continue  # Skip malformed items
                    if not sort_key_value.startswith(self.sort_key_metadata_prefix):
                        # The '<' condition used for 'before' can also match the main
                        # checkpoint item; skip anything that is not a metadata row
                        continue
                    try:
                        ts = sort_key_value.split(self.sort_key_metadata_prefix, 1)[1]
                        meta_data_bytes = meta_item.get("metadata")
                        if isinstance(meta_data_bytes, bytes):
                            metadata_entry: CheckpointMetadata = self.serializer.loads(meta_data_bytes)
                            if not isinstance(metadata_entry, dict):
                                metadata_entry = {}
                        else:
                            metadata_entry = {}
                        # --- Apply Post-Retrieval Filter ---
                        if filter:
                            match = True
                            for key, value in filter.items():
                                if metadata_entry.get(key) != value:
                                    match = False
                                    break
                            if not match:
                                continue
                        # Construct a config representing this specific point in time.
                        # Use the config stored within the metadata if available, otherwise fall back
                        list_item_config = metadata_entry.get("config", config)
                        # Add the timestamp to the config for clarity
                        list_item_config = {**list_item_config, "checkpoint_ts": ts}
                        # Yield config/metadata only, as per the BaseCheckpointSaver.alist expectation
                        yield CheckpointTuple(config=list_item_config, checkpoint=None, metadata=metadata_entry, parent_config=None)
                        count += 1
                    except (IndexError, pickle.UnpicklingError, TypeError, EOFError, AttributeError, ModuleNotFoundError) as e:
                        print(f"Warning: Could not process metadata item {sort_key_value} during list for thread {thread_id}: {e}")
                        continue
                if limit is not None and count >= limit:
                    break
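
    # Example usage (hypothetical saver instance and thread id): list only the
    # entries written by the LLM step for one thread —
    #   async for t in saver.alist({"thread_id": "t-1"}, filter={"source": "llm_response"}):
    #       print(t.config.get("checkpoint_ts"), t.metadata.get("source"))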
    async def _scan_all_checkpoints(
        self,
        *,
        filter: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
    ) -> AsyncGenerator[CheckpointTuple, None]:
        """Helper to scan the entire table. Yields the latest CheckpointTuple per thread_id found. Use with caution."""
        async with self._get_client() as client:
            paginator = client.get_paginator("scan")
            scan_kwargs = {"TableName": self.table_name}
            count = 0
            seen_thread_ids = set()
            # Store items per thread_id to process after scanning each page
            page_thread_data: Dict[str, Dict[str, Any]] = {}
            async for page in paginator.paginate(**scan_kwargs):
                items = [self._deserialize_item(item) for item in page.get("Items", [])]
                page_thread_data.clear()
                # Group items by thread_id found in the current page
                for item in items:
                    thread_id = item.get(self.primary_key)
                    # --- Use self.sort_key_name ---
                    sort_key = item.get(self.sort_key_name)
                    if not thread_id or not sort_key:
                        continue
                    if thread_id not in page_thread_data:
                        page_thread_data[thread_id] = {"checkpoint_item": None, "metadata_items": []}
                    if sort_key == self.sort_key_checkpoint_value:
                        page_thread_data[thread_id]["checkpoint_item"] = item
                    elif isinstance(sort_key, str) and sort_key.startswith(self.sort_key_metadata_prefix):
                        page_thread_data[thread_id]["metadata_items"].append(item)
                # Process each thread found in the page
                for thread_id, data in page_thread_data.items():
                    if thread_id in seen_thread_ids:
                        continue
                    if limit is not None and count >= limit:
                        return
                    checkpoint_item = data["checkpoint_item"]
                    metadata_items = data["metadata_items"]
                    if not checkpoint_item:
                        # If only metadata was found for a thread_id in this page, skip for now;
                        # we need the main checkpoint item to construct the full tuple
                        continue
                    # Deserialize the checkpoint
                    try:
                        checkpoint_data = checkpoint_item.get("checkpoint")
                        if isinstance(checkpoint_data, bytes):
                            checkpoint = self.serializer.loads(checkpoint_data)
                        elif checkpoint_data is None:
                            checkpoint = Checkpoint(v=1, ts="", id="", channel_values={}, channel_versions={}, seen_recovery_events=set(), config={})
                        else:
                            print(f"Warning: Invalid checkpoint data type during scan for thread {thread_id}")
                            continue
                    except Exception as e:
                        print(f"Error deserializing checkpoint data during scan for thread {thread_id}: {e}")
                        continue
                    # Construct the metadata dictionary for this thread
                    metadata: Dict[str, CheckpointMetadata] = {}
                    for meta_item in metadata_items:
                        # --- Use self.sort_key_name ---
                        sort_key_value = meta_item.get(self.sort_key_name)
                        if not isinstance(sort_key_value, str):
                            continue
                        try:
                            ts = sort_key_value.split(self.sort_key_metadata_prefix, 1)[1]
                            meta_data_bytes = meta_item.get("metadata")
                            if isinstance(meta_data_bytes, bytes):
                                meta_dict = self.serializer.loads(meta_data_bytes)
                                metadata[ts] = meta_dict if isinstance(meta_dict, dict) else {}
                            else:
                                metadata[ts] = {}
                        except Exception as e:
                            print(f"Warning: Could not process metadata item {sort_key_value} during scan for thread {thread_id}: {e}")
                    # --- Apply Post-Retrieval Filter (on the latest checkpoint's metadata) ---
                    latest_ts = checkpoint.ts if hasattr(checkpoint, 'ts') else None
                    latest_metadata_entry = metadata.get(latest_ts, {}) if latest_ts else {}
                    if filter:
                        match = True
                        for key, value in filter.items():
                            if latest_metadata_entry.get(key) != value:
                                match = False
                                break
                        if not match:
                            continue  # Skip the thread if the filter doesn't match
                    # Determine the parent config (same logic as in aget_tuple)
                    parent_config = None
                    if latest_ts:
                        sorted_metadata_ts = sorted(metadata.keys(), reverse=True)
                        try:
                            current_checkpoint_idx = sorted_metadata_ts.index(latest_ts)
                            if current_checkpoint_idx + 1 < len(sorted_metadata_ts):
                                parent_ts = sorted_metadata_ts[current_checkpoint_idx + 1]
                                parent_config = metadata.get(parent_ts, {}).get("config")
                        except ValueError:
                            pass  # ts not found in metadata
                    # Use the config stored within the latest checkpoint
                    config = checkpoint.config if hasattr(checkpoint, 'config') else {}
                    yield CheckpointTuple(
                        config=config,
                        checkpoint=checkpoint,
                        metadata=metadata,
                        parent_config=parent_config,
                    )
                    seen_thread_ids.add(thread_id)
                    count += 1
                if limit is not None and count >= limit:
                    return
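
    # Design note: Scan reads every item in the table (and is billed accordingly);
    # for frequent cross-thread listing, the usual DynamoDB approach would be a
    # GSI keyed for that access pattern rather than scanning.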
    async def aput(
        self, config: Dict[str, Any], checkpoint: Checkpoint, metadata: CheckpointMetadata
    ) -> Dict[str, Any]:
        """
        Asynchronously saves a checkpoint and its metadata to DynamoDB.

        Args:
            config: The configuration identifying the thread (must contain 'thread_id').
            checkpoint: The checkpoint object to save.
            metadata: The metadata associated with the checkpoint.

        Returns:
            The original config dictionary used for putting the checkpoint.
        """
        thread_id = config["thread_id"]
        # Ensure the checkpoint has a timestamp; generate one if missing (though LangGraph usually provides it)
        if not hasattr(checkpoint, 'ts') or not checkpoint.ts:
            print(f"Warning: Checkpoint for thread {thread_id} missing timestamp. Generating one.")
            checkpoint.ts = datetime.now(timezone.utc).isoformat()
        serialized_checkpoint = self.serializer.dumps(checkpoint)
        # Ensure metadata includes the config for retrieval consistency
        metadata_to_save = {**metadata, "config": config}
        serialized_metadata = self.serializer.dumps(metadata_to_save)
        # Calculate TTL if configured
        ttl_timestamp = self._calculate_ttl()
        async with self._get_client() as client:
            # Prepare the main checkpoint item
            checkpoint_item = {
                self.primary_key: thread_id,
                # --- Use self.sort_key_name and the configured value ---
                self.sort_key_name: self.sort_key_checkpoint_value,
                "checkpoint": serialized_checkpoint,
                "ts": checkpoint.ts,  # Store the latest timestamp for reference
            }
            if self.ttl_key and ttl_timestamp is not None:
                checkpoint_item[self.ttl_key] = ttl_timestamp
            # Prepare the metadata item
            metadata_item = {
                self.primary_key: thread_id,
                # --- Use self.sort_key_name and construct the value ---
                self.sort_key_name: f"{self.sort_key_metadata_prefix}{checkpoint.ts}",
                "metadata": serialized_metadata,
            }
            if self.ttl_key and ttl_timestamp is not None:
                metadata_item[self.ttl_key] = ttl_timestamp
            try:
                # Use BatchWriteItem to write both items in one request. Note that,
                # unlike TransactWriteItems, BatchWriteItem is NOT atomic.
                response = await client.batch_write_item(
                    RequestItems={
                        self.table_name: [
                            {'PutRequest': {'Item': self._serialize_item(checkpoint_item)}},
                            {'PutRequest': {'Item': self._serialize_item(metadata_item)}},
                        ]
                    }
                )
                # Check for unprocessed items (recommended for production;
                # retry logic is left out here)
                unprocessed = response.get('UnprocessedItems', {}).get(self.table_name)
                if unprocessed:
                    print(f"Warning: Failed to process some items for thread {thread_id}: {unprocessed}")
            except ClientError as e:
                print(f"Error putting checkpoint/metadata for thread {thread_id} using BatchWriteItem: {e}")
                # Consider falling back to individual PutItem calls, or re-raise
                raise
            except Exception as e:  # Catch broader exceptions during item preparation
                print(f"Unexpected error preparing items for thread {thread_id}: {e}")
                raise
        return config  # Return the config used for the put operation
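
# DynamoDB only honors the ttl_timestamp attribute once TTL is enabled on the
# table. A minimal sketch using the real UpdateTimeToLive API (the helper name
# and usage are assumptions, not part of the saver class):
async def enable_ttl(saver: AsyncDynamoDBSaver) -> None:
    """Enable TTL on the saver's table for its configured ttl_key attribute."""
    if not saver.ttl_key:
        return  # TTL is disabled for this saver
    async with saver._get_client() as client:
        await client.update_time_to_live(
            TableName=saver.table_name,
            TimeToLiveSpecification={"Enabled": True, "AttributeName": saver.ttl_key},
        )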
# Example Usage (requires async environment)
async def main():
    # --- CONFIGURE FOR YOUR TABLE ---
    your_table_name = "brie-ml-chatbot-llm-checkpoints"  # From your Terraform
    your_primary_key = "thread_id"  # From your Terraform
    your_sort_key_name = "checkpoint_id"  # From your Terraform
    # --- Choose Sort Key Values ---
    # Decide what values you want to store in the 'checkpoint_id' (Sort Key) column.
    # Option 1: Keep the original library's convention (recommended for less code change)
    your_sk_checkpoint_value = "checkpoint"
    your_sk_metadata_prefix = "metadata|"
    # Option 2: Use checkpoint.id (if it's guaranteed unique per thread & timestamp).
    # This would require more significant changes in how 'aget_tuple' and 'alist' find items.
    # Let's stick with Option 1 for this example.
    print("Configuring saver for:")
    print(f"  Table: {your_table_name}")
    print(f"  PK: {your_primary_key}")
    print(f"  SK Name: {your_sort_key_name}")
    print(f"  SK Checkpoint Value: {your_sk_checkpoint_value}")
    print(f"  SK Metadata Prefix: {your_sk_metadata_prefix}")
    # Instantiate the saver with your specific configuration
    saver = AsyncDynamoDBSaver(
        table_name=your_table_name,
        primary_key=your_primary_key,
        sort_key_name=your_sort_key_name,  # Specify your sort key name
        sort_key_checkpoint_value=your_sk_checkpoint_value,  # Specify the value for the main checkpoint
        sort_key_metadata_prefix=your_sk_metadata_prefix,  # Specify the prefix for metadata items
        # endpoint_url="http://localhost:8000",  # Uncomment for DynamoDB Local
        # ttl_duration=timedelta(days=30),  # Optional: Set TTL
    )
    # Example config and checkpoint data
    thread_config = {"thread_id": "brie-thread-async-02"}
    # Generate a unique ID for the checkpoint state itself
    checkpoint_state_id = f"cp-{int(time.time())}"  # Example ID
    checkpoint_ts = datetime.now(timezone.utc).isoformat()
    initial_checkpoint = Checkpoint(
        v=1,
        ts=checkpoint_ts,
        id=checkpoint_state_id,  # Unique ID for this state
        channel_values={"messages": ["hello from async config test"]},
        channel_versions={"messages": 1},
        seen_recovery_events=set(),
        config=thread_config,  # Config that generated this state
    )
    initial_metadata: CheckpointMetadata = {  # Use the TypedDict for clarity
        "source": "user_input",
        "step": 1,
        "writes": {"chatbot": {"messages": ["hello from async config test"]}},
        # config is added automatically in aput
    }
    # --- Put Checkpoint ---
    print(f"\nPutting checkpoint for: {thread_config['thread_id']}")
    # Pass the config that identifies the thread
    await saver.aput(thread_config, initial_checkpoint, initial_metadata)
    print("Put successful.")
    # --- Get Checkpoint ---
    print(f"\nGetting checkpoint tuple for: {thread_config['thread_id']}")
    retrieved_tuple = await saver.aget_tuple(thread_config)
    if retrieved_tuple:
        print("Retrieved Checkpoint:", retrieved_tuple.checkpoint)
        print("Retrieved Metadata:", retrieved_tuple.metadata)
        print("Retrieved Parent Config:", retrieved_tuple.parent_config)
        # Verify the sort key name was used correctly (inspect DynamoDB manually if needed)
    else:
        print("Checkpoint not found.")
    # --- Add another checkpoint ---
    await asyncio.sleep(1)  # Ensure the timestamp changes
    next_checkpoint_state_id = f"cp-{int(time.time())}"
    next_checkpoint_ts = datetime.now(timezone.utc).isoformat()
    next_config = {"thread_id": "brie-thread-async-02"}  # Could be the same or an updated config
    next_checkpoint = Checkpoint(
        v=1,
        ts=next_checkpoint_ts,
        id=next_checkpoint_state_id,
        channel_values={"messages": ["hello from async config test", "how are you?"]},
        channel_versions={"messages": 2},
        seen_recovery_events=set(),
        config=next_config,  # Config that generated this state
    )
    next_metadata: CheckpointMetadata = {
        "source": "llm_response",
        "step": 2,
        "writes": {"chatbot": {"messages": ["how are you?"]}},
    }
    print(f"\nPutting next checkpoint for: {thread_config['thread_id']}")
    await saver.aput(next_config, next_checkpoint, next_metadata)  # Use the relevant config
    print("Put successful.")
    # --- Get Updated Checkpoint ---
    print(f"\nGetting updated checkpoint tuple for: {thread_config['thread_id']}")
    retrieved_tuple_updated = await saver.aget_tuple(thread_config)  # Use the base thread config to get the latest
    if retrieved_tuple_updated:
        print("Updated Checkpoint:", retrieved_tuple_updated.checkpoint)
        print("Updated Metadata:", retrieved_tuple_updated.metadata)  # Should have two entries now
        print("Updated Parent Config:", retrieved_tuple_updated.parent_config)  # Should be initial_checkpoint's config
        assert retrieved_tuple_updated.checkpoint.ts == next_checkpoint_ts
        assert initial_checkpoint.ts in retrieved_tuple_updated.metadata
    else:
        print("Updated Checkpoint not found.")
    # --- List Checkpoints ---
    print(f"\nListing checkpoints for: {thread_config['thread_id']}")
    async for checkpoint_tuple in saver.alist(thread_config, limit=5):
        # Note: the tuple yielded by alist in this implementation contains config/metadata only
        print(f"  - Metadata TS: {checkpoint_tuple.config.get('checkpoint_ts')}, Source: {checkpoint_tuple.metadata.get('source')}")
    # --- List Checkpoints Before a specific one ---
    # Config representing the state *at* the second checkpoint
    config_at_second_checkpoint = retrieved_tuple_updated.checkpoint.config if retrieved_tuple_updated else None
    if config_at_second_checkpoint:
        print(f"\nListing checkpoints before {next_checkpoint_ts}:")
        # Use the config associated with the second checkpoint to identify the 'before' point
        async for checkpoint_tuple in saver.alist(thread_config, before=config_at_second_checkpoint):
            print(f"  - Metadata TS: {checkpoint_tuple.config.get('checkpoint_ts')}, Source: {checkpoint_tuple.metadata.get('source')}")
        # Should only list the first checkpoint's metadata
    else:
        print("\nSkipping 'before' test as the second checkpoint wasn't retrieved.")
if __name__ == "__main__":
    # Ensure you have a running DynamoDB instance (local or AWS)
    # and the necessary credentials/region configured.
    # You may need to create the table using the AWS CLI or Terraform matching your definition:
    # aws dynamodb create-table \
    #   --table-name brie-ml-chatbot-llm-checkpoints \
    #   --attribute-definitions AttributeName=thread_id,AttributeType=S AttributeName=checkpoint_id,AttributeType=S \
    #   --key-schema AttributeName=thread_id,KeyType=HASH AttributeName=checkpoint_id,KeyType=RANGE \
    #   --billing-mode PAY_PER_REQUEST
    # Append --endpoint-url http://localhost:8000 for DynamoDB Local.
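    # If you use ttl_duration, also enable TTL on the table (the attribute name is
    # assumed here to match the saver's ttl_key, "ttl_timestamp" by default):
    # aws dynamodb update-time-to-live \
    #   --table-name brie-ml-chatbot-llm-checkpoints \
    #   --time-to-live-specification "Enabled=true, AttributeName=ttl_timestamp"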
    try:
        asyncio.run(main())
    except ImportError as e:
        print(f"Please install the necessary libraries: pip install aiobotocore langgraph boto3. Error: {e}")
    except ClientError as e:
        error_code = e.response.get('Error', {}).get('Code')
        if error_code == 'ResourceNotFoundException':
            print("\nError: the configured DynamoDB table was not found.")
            print("Please ensure the table exists and matches the expected schema (PK=thread_id, SK=checkpoint_id).")
        elif 'Credentials' in str(e) or error_code == 'UnrecognizedClientException':
            print(f"\nError: AWS credentials not found, invalid, or region misconfigured: {e}")
            print("Ensure your AWS credentials (access key, secret key, region) are configured correctly (e.g., via environment variables, ~/.aws/credentials, IAM role).")
        elif error_code == 'ValidationException':
            print(f"\nError: DynamoDB request validation failed: {e}")
            print("This might indicate an issue with the data being sent (e.g., invalid types, missing keys) or a table schema mismatch.")
        else:
            print(f"\nAn AWS ClientError occurred: {e}")
    except Exception as e:
        import traceback
        print(f"\nAn unexpected error occurred: {e}")
        print(traceback.format_exc())
# ****************************************************
### Example implementation with LangGraph
import asyncio
from typing import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.base import Checkpoint
# Assume the AsyncDynamoDBSaver class is defined as in the document
# from your_module import AsyncDynamoDBSaver
# 1. Define your graph state
class AgentState(TypedDict):
    input: str
    output: str

# 2. Define your graph nodes (ensure they are async if doing async work)
async def node_a(state: AgentState):
    print("---Executing Node A---")
    # Simulate async work
    await asyncio.sleep(0.1)
    return {"output": f"Output from A based on '{state['input']}'"}

async def node_b(state: AgentState):
    print("---Executing Node B---")
    await asyncio.sleep(0.1)
    return {"output": state['output'] + " | Output from B"}
# 3. Instantiate the async checkpointer.
# Configure with your table name, region, etc.
# Use endpoint_url for DynamoDB Local if needed.
checkpointer = AsyncDynamoDBSaver(
    table_name="langgraph_checkpoints",
    # aws_region="your-region",  # Optional
    # endpoint_url="http://localhost:8000",  # Optional
)

# 4. Build the graph, passing the checkpointer
builder = StateGraph(AgentState)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("b", END)
# Pass the single async checkpointer instance here
graph = builder.compile(checkpointer=checkpointer)

# 5. Use the graph asynchronously
async def run_graph():
    thread_config = {"configurable": {"thread_id": "my-thread-async-01"}}
    print("Running graph...")
    async for event in graph.astream_events(
        {"input": "hello async world"}, config=thread_config, version="v1"
    ):
        kind = event["event"]
        if kind == "on_chain_end":
            print("---Graph Ended with Output---")
            print(event["data"]["output"])
        elif kind == "on_checkpoint":
            print("---Checkpoint Saved---")
            # print(event["data"])  # Can inspect the checkpoint data
    # You can retrieve the checkpoint later using the same checkpointer instance
    print("\nRetrieving final checkpoint:")
    final_checkpoint_tuple = await checkpointer.aget_tuple(thread_config["configurable"])
    if final_checkpoint_tuple:
        print(final_checkpoint_tuple.checkpoint.channel_values)  # Print the final state

# Run the async function
# asyncio.run(run_graph())
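
# For a single-shot run without streaming, graph.ainvoke is a simpler entry
# point (a minimal sketch using the same thread_config convention as above):
async def run_graph_once() -> None:
    thread_config = {"configurable": {"thread_id": "my-thread-async-01"}}
    result = await graph.ainvoke({"input": "hello async world"}, config=thread_config)
    print(result["output"])

# asyncio.run(run_graph_once())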