Skip to content

Instantly share code, notes, and snippets.

@Finndersen
Created January 19, 2026 12:19
Show Gist options
  • Select an option

  • Save Finndersen/60568a931109a022c757378cfb4a308b to your computer and use it in GitHub Desktop.

Select an option

Save Finndersen/60568a931109a022c757378cfb4a308b to your computer and use it in GitHub Desktop.
Claude Agent SDK wrapper class that provides a convenient interface over ClaudeSDKClient and supports concurrent tool calling
"""Claude Code client using claude-agent-sdk for Claude Code CLI integration."""
import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass
from typing import Any, TypedDict
import logfire
from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
ClaudeSDKClient,
Message,
ResultMessage,
ToolUseBlock,
create_sdk_mcp_server,
tool,
)
from claude_agent_sdk.types import PermissionMode, SystemPromptPreset
from pydantic_ai import Tool
from pydantic_core import ValidationError
from .utils import describe_message, load_env_from_settings
@dataclass()
class ClaudeCodeResult:
    """Outcome of one ClaudeClient run.

    Attributes are populated from the final ResultMessage of the run.
    """

    # Final text of the run (taken from the ResultMessage's ``result``).
    text_content: str
    # The raw ResultMessage, kept for its metadata.
    result_message: ResultMessage
    # Session identifier that can be passed back in to resume the conversation.
    session_id: str
class ClaudeToolTextBlock(TypedDict, total=True):
    """One text content block inside a Claude SDK tool response."""

    # Block kind; the tool wrappers in this module always emit "text".
    type: str
    # The text payload of the block.
    text: str
class ClaudeToolContent(TypedDict, total=True):
    """Top-level payload a Claude SDK tool handler returns."""

    # Ordered list of text blocks making up the tool's output.
    content: list[ClaudeToolTextBlock]
class ClaudeClient:
    """Low-level client wrapping ClaudeSDKClient for Claude Code CLI integration.

    This client provides a minimal interface to Claude Code's capabilities.

    Supports:
    - Session resumption for continuing previous conversations
    - Custom tool registration via PydanticAI Tool instances
    - Custom system prompts
    - Tool filtering via allowed_builtin_tools
    - Streaming and non-streaming execution modes
    - Parallel tool execution for concurrent tool calls

    With concurrent execution enabled, each custom tool call is started as soon as
    its ``ToolUseBlock`` is seen in an ``AssistantMessage`` (see ``stream``). The
    handler registered on the SDK MCP server does not run the tool itself; it only
    waits for the already-running task. Handler invocations are matched to tool
    calls by order through an asyncio queue, so a single client instance should not
    run several ``stream``/``run`` calls at the same time.
    """

    # Name of the in-process SDK MCP server hosting the custom tools. Custom tools are
    # allowed under the name ``mcp__<server>__<tool>`` (see _build_mcp_servers).
    _MCP_SERVER_NAME = "builtin_tools"

    def __init__(
        self,
        session_id: str | None = None,
        system_prompt: str | None = None,
        include_builtin_system_prompt: bool = False,
        tools: list[Tool] | None = None,
        allowed_builtin_tools: list[str] | None = None,
        model: str | None = None,
        cwd: str | None = None,
        plan_mode: bool = False,
        load_settings: bool = True,
        enable_concurrent_execution: bool = True,
    ):
        """Initialize Claude Code client.

        Args:
            session_id: Optional session ID to resume a previous conversation
            system_prompt: Optional system prompt to include in all runs
            include_builtin_system_prompt: Whether to include the built-in system prompt
            tools: Optional list of PydanticAI Tool instances to expose as tools.
            allowed_builtin_tools: Optional list of allowed tool names (e.g., ["Read", "Bash", "Grep"]).
                The list is copied; the caller's list is never modified.
            model: Optional model to use (e.g., "claude-sonnet-4-5-20250929")
            cwd: Optional working directory for Claude Code operations
            plan_mode: Use the "plan" permission mode instead of "default"/"acceptEdits"
            load_settings: Whether to load local, project and user-level .settings.json and CLAUDE.md files
            enable_concurrent_execution: Enable concurrent execution of multiple tool calls
        """
        self.session_id = session_id
        self._tools = tools or []
        self._enable_concurrent_execution = enable_concurrent_execution

        # Concurrent execution tracking:
        # - cache: tool_use_id -> task running that tool call
        # - queue: (tool_name, tool_use_id) in the order the calls were announced,
        #   consumed by the MCP handlers to find "their" task.
        self._tool_execution_cache: dict[str, asyncio.Future[ClaudeToolContent]] = {}
        self._tool_execution_queue: asyncio.Queue[tuple[str, str]] = asyncio.Queue()

        # Build MCP servers from custom tools if provided
        if tools:
            mcp_servers, custom_tool_names = self._build_mcp_servers(tools)
        else:
            mcp_servers = None
            custom_tool_names = []

        # Build a fresh list: extending allowed_builtin_tools in place would mutate
        # the caller's list.
        all_allowed_tools = [*(allowed_builtin_tools or []), *custom_tool_names]

        # Load environment variables from user settings
        env_vars = load_env_from_settings()

        # Set model name when using AWS Bedrock
        if model and env_vars.get("CLAUDE_CODE_USE_BEDROCK") == "1":
            region_prefix = env_vars.get("AWS_REGION", "us-west-1").split("-")[0]
            model = f"{region_prefix}.anthropic.{model}-v1:0"

        permission_mode: PermissionMode
        if plan_mode:
            permission_mode = "plan"
        elif "Write" in all_allowed_tools:
            permission_mode = "acceptEdits"
        else:
            permission_mode = "default"

        self.options = ClaudeAgentOptions(
            resume=session_id,
            system_prompt=self._build_system_prompt(system_prompt, include_builtin_system_prompt),
            allowed_tools=all_allowed_tools,
            model=model,
            cwd=cwd,
            mcp_servers=mcp_servers,
            permission_mode=permission_mode,
            setting_sources=["local", "project", "user"] if load_settings else None,
            env=env_vars,
        )

    def _build_system_prompt(
        self, system_prompt: str | None, include_builtin_system_prompt: bool
    ) -> SystemPromptPreset | str | None:
        """Build the system prompt option.

        Returns the "claude_code" preset (with ``system_prompt`` appended, if given)
        when ``include_builtin_system_prompt`` is set, otherwise the plain
        ``system_prompt`` (or None when it is empty/None).
        """
        if include_builtin_system_prompt:
            if system_prompt:
                return SystemPromptPreset(type="preset", preset="claude_code", append=system_prompt)
            return SystemPromptPreset(type="preset", preset="claude_code")
        if system_prompt:
            return system_prompt
        return None

    def _build_mcp_servers(
        self,
        tools: list[Tool],
    ) -> tuple[dict[str, Any], list[str]]:
        """Build MCP servers from PydanticAI Tool instances.

        Args:
            tools: List of PydanticAI Tool instances

        Returns:
            Tuple of (mcp_servers dict, list of custom tool names as
            ``mcp__<server>__<tool>``)
        """
        sdk_tools = []
        custom_tool_names = []
        mcp_name = self._MCP_SERVER_NAME

        for pydantic_tool in tools:
            tool_name = pydantic_tool.name

            # In concurrent mode the handler only retrieves a result that stream()
            # already started; otherwise it validates and runs the tool itself.
            if self._enable_concurrent_execution:
                wrapper_func = self._create_tool_result_retrieval_func(pydantic_tool)
            else:
                wrapper_func = self._create_tool_execution_wrapper(pydantic_tool, validate_args=True)

            # Wrap with the SDK's @tool decorator, reusing the PydanticAI schema
            sdk_tool = tool(
                name=tool_name,
                description=pydantic_tool.description,
                input_schema=pydantic_tool.function_schema.json_schema,
            )(wrapper_func)
            sdk_tools.append(sdk_tool)
            custom_tool_names.append(f"mcp__{mcp_name}__{tool_name}")

        # Create SDK MCP server with custom tools
        mcp_server = create_sdk_mcp_server(
            name=mcp_name,
            version="1.0.0",
            tools=sdk_tools,
        )
        mcp_servers = {mcp_name: mcp_server}
        return mcp_servers, custom_tool_names

    def _create_tool_execution_wrapper(
        self,
        pydantic_tool: Tool,
        *,
        validate_args: bool,
    ) -> Callable[[dict[str, Any]], Awaitable[ClaudeToolContent]]:
        """Create a wrapper function that converts a PydanticAI Tool to Claude Code tool format.

        Args:
            pydantic_tool: PydanticAI Tool instance
            validate_args: Validate the raw args with the tool's schema validator before
                calling; pass False when the args were already validated.

        Returns:
            An async function that accepts a dict of arguments and returns ClaudeToolContent
        """

        async def normal_wrapper(args: dict[str, Any]) -> ClaudeToolContent:
            with logfire.span(
                f"Calling tool: {pydantic_tool.name}()",
                tool_name=pydantic_tool.name,
                args=args,
            ):
                if validate_args:
                    # Raises pydantic_core.ValidationError on invalid input
                    validated_args = pydantic_tool.function_schema.validator.validate_python(args)
                else:
                    validated_args = args
                result = await pydantic_tool.function_schema.call(validated_args, ctx=None)

                # Pass through results already in Claude Code format; stringify anything else
                if isinstance(result, dict) and "content" in result:
                    return result  # type: ignore[return-value]
                return {"content": [{"type": "text", "text": str(result)}]}

        return normal_wrapper

    def _create_tool_result_retrieval_func(
        self,
        pydantic_tool: Tool,
    ) -> Callable[[dict[str, Any]], Awaitable[ClaudeToolContent]]:
        """Create an MCP handler that returns the result of an already-started tool task.

        The handler ignores its ``args`` (the tool was already started with validated
        args by ``_execute_concurrent_mcp_tool``) and takes the next queued
        ``(tool_name, tool_use_id)`` to find the matching task.

        Args:
            pydantic_tool: The tool this handler is registered for.

        Returns:
            An async function returning the task's ClaudeToolContent.

        Raises (from the returned function):
            RuntimeError: If the queued call is for a different tool or has no task.
        """

        async def retrieve_tool_result(args: dict[str, Any]) -> ClaudeToolContent:
            # Order-based correlation: blocks until stream() has announced the next call.
            tool_name_expected, tool_use_id = await self._tool_execution_queue.get()
            with logfire.span(
                f"Retrieving result for tool: {pydantic_tool.name}()",
                tool_call_id=tool_use_id,
            ):
                if tool_name_expected != pydantic_tool.name:
                    error_msg = (
                        f"Tool name mismatch in MCP call: expected {tool_name_expected}, got {pydantic_tool.name}"
                    )
                    logfire.error(error_msg, tool_use_id=tool_use_id)
                    raise RuntimeError(error_msg)

                # Remove the entry before awaiting so it is released even if the tool raised.
                task = self._tool_execution_cache.pop(tool_use_id, None)
                if task is None:
                    error_msg = f"No cached result found for tool_use_id {tool_use_id}"
                    logfire.error(error_msg, tool_name=pydantic_tool.name)
                    raise RuntimeError(error_msg)

                # Wait for the result (may already be complete)
                return await task

        return retrieve_tool_result

    def _find_tool_by_name(self, tool_name: str) -> Tool | None:
        """Find a tool by its name, handling MCP prefixes.

        Args:
            tool_name: Tool name, possibly with MCP prefix (e.g., "mcp__builtin_tools__search")

        Returns:
            Tool instance if found, None otherwise
        """
        clean_name = self._get_original_tool_name_from_mcp_tool(tool_name)
        return next((candidate for candidate in self._tools if candidate.name == clean_name), None)

    async def _execute_tool_concurrently(
        self,
        tool: Tool,
        tool_args: dict[str, Any],
    ) -> ClaudeToolContent:
        """Execute a tool so its result can be retrieved later by the MCP handler.

        Args:
            tool: PydanticAI Tool instance to execute
            tool_args: Validated input arguments for the tool

        Returns:
            Tool execution result in Claude Code format
        """
        wrapper = self._create_tool_execution_wrapper(tool, validate_args=False)
        return await wrapper(tool_args)

    async def _execute_concurrent_mcp_tool(self, tool_block: ToolUseBlock) -> None:
        """Launch async execution for a single tool use block.

        If the tool name or arguments are invalid, no task is created, on the
        assumption that the MCP client also rejects the call and never invokes the
        handler.

        Args:
            tool_block: ToolUseBlock to execute concurrently
        """
        pydantic_tool = self._find_tool_by_name(tool_block.name)
        if not pydantic_tool:
            logfire.warn(f"Invalid tool name: '{tool_block.name}'", tool_use_id=tool_block.id)
            return

        try:
            validated_args = pydantic_tool.function_schema.validator.validate_python(tool_block.input)
        except ValidationError as e:
            logfire.warn(f"Invalid arguments for tool '{tool_block.name}'", tool_use_id=tool_block.id, error=e)
            return

        logfire.debug(
            "Launching concurrent execution for tool",
            tool_use_id=tool_block.id,
            tool_name=tool_block.name,
        )

        # Start the tool now; the MCP handler awaits this task later
        task = asyncio.create_task(
            self._execute_tool_concurrently(
                tool=pydantic_tool,
                tool_args=validated_args,
            )
        )
        self._tool_execution_cache[tool_block.id] = task

        # Announce (tool_name, tool_use_id) so the next handler call can find its task
        await self._tool_execution_queue.put(
            (self._get_original_tool_name_from_mcp_tool(tool_block.name), tool_block.id)
        )

    def _discard_pending_tool_executions(self) -> None:
        """Cancel un-retrieved tool tasks and drop queued correlations.

        Without this, entries left over from an aborted or desynchronized run would be
        picked up by the handlers of the next run.
        """
        for task in self._tool_execution_cache.values():
            if not task.done():
                task.cancel()
            elif not task.cancelled():
                # Mark any stored exception as retrieved to avoid asyncio's
                # "Task exception was never retrieved" warning.
                task.exception()
        self._tool_execution_cache.clear()
        while not self._tool_execution_queue.empty():
            self._tool_execution_queue.get_nowait()

    async def run(self, user_query: str) -> ClaudeCodeResult:
        """Execute a query and return a single result.

        This method waits for the complete response and returns a consolidated
        result containing the final text content and metadata.

        Args:
            user_query: The user's query/prompt to send to Claude Code

        Returns:
            ClaudeCodeResult containing the response text, result metadata,
            and session ID for resuming

        Raises:
            RuntimeError: If no ResultMessage with a result was received.
        """
        result_message = None

        # Consume the whole stream (it terminates after the ResultMessage) so cleanup runs.
        async for message in self.stream(user_query):
            # Intermediate AssistantMessages before the final ResultMessage are not captured
            if isinstance(message, ResultMessage):
                result_message = message

        if not result_message or result_message.result is None:
            raise RuntimeError("No ResultMessage received from Claude Code")

        return ClaudeCodeResult(
            text_content=result_message.result,
            result_message=result_message,
            session_id=result_message.session_id,
        )

    async def stream(self, user_query: str) -> AsyncIterator[Message]:
        """Execute a query and stream individual messages as they arrive.

        This method yields messages in real-time, allowing for progressive
        rendering and processing of the response. With concurrent execution
        enabled, custom tool calls in each AssistantMessage are started before
        the message is yielded. Any tool work not collected by the end of the
        stream is discarded.

        Args:
            user_query: The user's query/prompt to send to Claude Code

        Yields:
            Message objects (UserMessage, AssistantMessage, SystemMessage, ResultMessage)
            as they are received from Claude Code
        """
        with logfire.span(
            "claude_client.stream",
            session_id=self.session_id,
            model=self.options.model,
            system_prompt=self.options.system_prompt,
        ):
            try:
                async with ClaudeSDKClient(options=self.options) as client:
                    with logfire.span("claude_client.send_query", query=user_query):
                        await client.query(user_query)

                    async for message in client.receive_response():
                        message_desc = describe_message(message)
                        logfire.info(f"Received message: {message_desc}", message=message)
                        await self._start_running_any_mcp_tools(message)
                        yield message
            finally:
                self._discard_pending_tool_executions()

    async def _start_running_any_mcp_tools(self, message: Message) -> None:
        """Launch concurrent executions for this client's custom-tool calls in a message.

        Only tools on this client's own MCP server are handled. Tools from other MCP
        servers (e.g. loaded from settings) must not be intercepted, or they would
        desynchronize the order-based correlation queue.
        """
        if not (self._enable_concurrent_execution and isinstance(message, AssistantMessage)):
            return
        for block in message.content:
            if isinstance(block, ToolUseBlock) and self._is_own_mcp_tool(block.name):
                await self._execute_concurrent_mcp_tool(block)

    @staticmethod
    def _is_mcp_tool(tool_name: str) -> bool:
        """Return True if ``tool_name`` has the ``mcp__...`` form of any MCP server."""
        return "__" in tool_name and tool_name.split("__")[0] == "mcp"

    @classmethod
    def _is_own_mcp_tool(cls, tool_name: str) -> bool:
        """Return True if ``tool_name`` refers to a tool on this client's MCP server."""
        return tool_name.startswith(f"mcp__{cls._MCP_SERVER_NAME}__")

    @staticmethod
    def _get_original_tool_name_from_mcp_tool(mcp_tool_name: str) -> str:
        """Get the original tool name from an MCP tool name.

        ``mcp__<server>__<tool>`` -> ``<tool>``. Only the prefix is stripped, so tool
        names that themselves contain ``__`` are preserved. Names not of that form
        are returned unchanged.
        """
        parts = mcp_tool_name.split("__", 2)
        if len(parts) == 3 and parts[0] == "mcp":
            return parts[2]
        return mcp_tool_name
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment