#!/usr/bin/env python3
"""
Fixed @tool decorator implementation that properly respects Optional type annotations.

This implementation checks both default values AND type annotations to determine
whether parameters are required, fixing the bug where Optional types were
incorrectly marked as required.
"""
|
|
|
import ast
import inspect
import json
import sys
import types
from functools import wraps
from typing import (
    Annotated,
    Any,
    Dict,
    List,
    Optional,
    Union,
    get_args,
    get_origin,
)
|
|
|
|
|
def _is_optional_ast_node(node: ast.AST) -> bool:
    """Return True if a parsed annotation expression admits ``None`` at its top level.

    Only the outermost type is inspected, so ``List[int | None]`` is not
    optional while ``int | None`` and ``Union[None, int]`` are.
    """
    # A bare ``None`` literal.
    if isinstance(node, ast.Constant):
        return node.value is None

    # PEP 604 unions: ``A | B | C`` parses as nested BitOr operations.
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _is_optional_ast_node(node.left) or _is_optional_ast_node(node.right)

    if isinstance(node, ast.Subscript):
        target = node.value
        # Accept both ``Optional[...]`` and qualified ``typing.Optional[...]``.
        if isinstance(target, ast.Name):
            head = target.id
        elif isinstance(target, ast.Attribute):
            head = target.attr
        else:
            return False

        if head == "Optional":
            return True

        # Python 3.9+ stores the subscript expression directly in ``slice``.
        inner = node.slice
        members = inner.elts if isinstance(inner, ast.Tuple) else [inner]
        if head == "Union":
            return any(_is_optional_ast_node(member) for member in members)
        if head == "Annotated":
            # Only the first argument of Annotated is the actual type.
            return bool(members) and _is_optional_ast_node(members[0])

    return False


def _is_optional_string_annotation(text: str) -> bool:
    """Decide whether a string (forward-reference) annotation is optional.

    The string is parsed as a Python expression and inspected structurally.
    If it cannot be parsed, the function falls back to plain substring
    heuristics so the check stays best-effort.
    """
    try:
        tree = ast.parse(text.strip(), mode="eval")
    except (SyntaxError, ValueError):
        return (
            text.startswith("Optional[")
            or " | None" in text
            or ("Union[" in text and ", None" in text)
        )
    return _is_optional_ast_node(tree.body)


def is_optional_type(annotation: Any) -> bool:
    """
    Check if a type annotation represents an Optional type.

    Handles multiple forms of optional types:
    - typing.Optional[T] (which is Union[T, None])
    - typing.Union[T, None] (with None in any position)
    - T | None (Python 3.10+)
    - typing.Annotated[<optional type>, ...]
    - None / type(None) themselves
    - string annotations spelling any of the above, e.g. "None | int" or
      "typing.Optional[int]"

    Only the outermost type is considered: List[Optional[int]] is not optional.

    Args:
        annotation: The type annotation to check

    Returns:
        True if the annotation represents an optional type, False otherwise
    """
    # ``None`` used as an annotation means the same as ``type(None)``.
    if annotation is None or annotation is type(None):
        return True

    # String annotations (forward references) are analysed syntactically.
    if isinstance(annotation, str):
        return _is_optional_string_annotation(annotation)

    origin = get_origin(annotation)
    args = get_args(annotation)

    # Annotated[T, ...] is optional exactly when T is.
    if origin is Annotated:
        return is_optional_type(args[0])

    # ``T | None`` reports types.UnionType as its origin on interpreters where
    # that differs from typing.Union; look it up defensively.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        return type(None) in args

    return False
|
|
|
|
|
def extract_parameter_info(signature: inspect.Signature) -> Dict[str, Dict[str, Any]]:
    """
    Extract parameter information from a function signature.

    ``*args`` and ``**kwargs`` parameters are skipped.

    Args:
        signature: Function signature to analyze

    Returns:
        Dictionary mapping parameter names to their info including:
        - type: parameter type annotation
        - required: whether the parameter is required (no default value AND
          not annotated as an optional type)
        - has_default: whether the parameter has a default value
        - default: the default value, or inspect.Parameter.empty if none
        - is_optional_type: whether the annotation is an optional type
    """
    empty = inspect.Parameter.empty
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    params: Dict[str, Dict[str, Any]] = {}
    for name, param in signature.parameters.items():
        # Skip *args and **kwargs
        if param.kind in variadic_kinds:
            continue

        # Compare against the sentinel by identity: ``!=`` would dispatch to
        # the default value's own __ne__, which may return something that is
        # not a plain bool, or raise.
        has_default = param.default is not empty
        is_optional = is_optional_type(param.annotation)

        params[name] = {
            'type': param.annotation,
            'required': not has_default and not is_optional,
            'has_default': has_default,
            # When there is no default this is already the ``empty`` sentinel.
            'default': param.default,
            'is_optional_type': is_optional,
        }

    return params
|
|
|
|
|
# Python container/scalar types and the JSON Schema type names they map to.
# Anything not listed here falls back to "string".
_JSON_SCHEMA_TYPES: Dict[Any, str] = {
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    tuple: "array",
    set: "array",
    frozenset: "array",
    dict: "object",
}


def _json_schema_type(annotation: Any) -> str:
    """Map a Python type annotation to a JSON Schema type name (simplified)."""
    origin = get_origin(annotation)
    args = get_args(annotation)

    # Annotated[T, ...] describes the same value type as T.
    if origin is Annotated:
        return _json_schema_type(args[0])

    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        # Optional[T] / T | None: describe T. Unions with several non-None
        # members have no single JSON type here, so use the fallback.
        members = [arg for arg in args if arg is not type(None)]
        if len(members) == 1:
            return _json_schema_type(members[0])
        return "string"

    # Parameterized generics such as List[int] or dict[str, Any] are looked
    # up by their origin (list, dict, ...).
    base = annotation if origin is None else origin
    try:
        return _JSON_SCHEMA_TYPES.get(base, "string")
    except TypeError:
        # Unhashable annotation object: nothing to look up.
        return "string"


def create_json_schema(func) -> Dict[str, Any]:
    """
    Create a JSON schema for a function that properly handles Optional types.

    Each parameter reported by extract_parameter_info becomes a property whose
    "type" is derived from its annotation: int -> "integer", float -> "number",
    bool -> "boolean", list/tuple/set/frozenset -> "array", dict -> "object",
    including parameterized forms (List[int], dict[str, Any]) and Optional /
    ``| None`` / Annotated wrappers around them. Everything else, including
    str, missing annotations and string annotations, maps to "string".

    Args:
        func: The function to create a schema for

    Returns:
        JSON schema dictionary with "type", "properties" and "required" keys,
        where "required" lists the parameters flagged as required by
        extract_parameter_info
    """
    signature = inspect.signature(func)
    param_info = extract_parameter_info(signature)

    # Only parameters flagged as required end up in the "required" list.
    required_params = [
        name for name, info in param_info.items()
        if info['required']
    ]

    properties = {
        name: {
            "type": _json_schema_type(info['type']),
            "description": f"Parameter {name}",
        }
        for name, info in param_info.items()
    }

    return {
        "type": "object",
        "properties": properties,
        "required": required_params,
    }
|
|
|
|
|
def tool(func):
    """
    Decorate ``func`` so that it carries a JSON-schema tool specification.

    The returned wrapper forwards every call unchanged to ``func`` and exposes
    two extra attributes:
    - ``tool_spec``: the schema built by create_json_schema(func)
    - ``_original_func``: the undecorated function
    """
    # Build the specification from the undecorated function's signature.
    spec = create_json_schema(func)

    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    # Copy name, docstring, etc. from ``func`` onto the wrapper.
    wrapper = wraps(func)(wrapper)
    wrapper.tool_spec = spec
    wrapper._original_func = func
    return wrapper
|
|
|
|
|
# Test functions to verify the fix
|
@tool
def test_optional_without_default_fixed(param1: str, param2: Optional[int]) -> dict:
    """Sample tool: ``param2`` is Optional with no default and should be treated as optional."""
    # Echo the arguments back, keyed by parameter name.
    return dict(param1=param1, param2=param2)
|
|
|
|
|
@tool
def test_union_syntax_fixed(param1: str, param2: int | None) -> dict:
    """Sample tool: ``param2`` uses the ``int | None`` syntax and should be treated as optional."""
    # Echo the arguments back, keyed by parameter name.
    return dict(param1=param1, param2=param2)
|
|
|
|
|
@tool
def test_optional_with_default_fixed(param1: str, param2: Optional[int] = None) -> dict:
    """Sample tool: ``param2`` is Optional with a default and should be treated as optional."""
    # Echo the arguments back, keyed by parameter name.
    return dict(param1=param1, param2=param2)
|
|
|
|
|
@tool
def test_mixed_params_fixed(
    required_str: str,
    required_int: int,
    optional_no_default: Optional[str],
    optional_with_default: Optional[int] = 42,
    union_syntax: str | None = None
) -> dict:
    """Sample tool mixing required parameters with several kinds of optional ones."""
    # Echo every argument back, keyed by parameter name.
    return dict(
        required_str=required_str,
        required_int=required_int,
        optional_no_default=optional_no_default,
        optional_with_default=optional_with_default,
        union_syntax=union_syntax,
    )
|
|
|
|
|
def test_fix():
    """Print the signature and generated schema of each sample @tool function."""
    print("=== Fixed @tool Decorator Test Results ===")
    print()

    decorated_tools = (
        test_optional_without_default_fixed,
        test_union_syntax_fixed,
        test_optional_with_default_fixed,
        test_mixed_params_fixed,
    )
    divider = "-" * 60

    for decorated in decorated_tools:
        print(f"Function: {decorated.__name__}")
        # Show the signature of the undecorated function.
        print(f"Signature: {inspect.signature(decorated._original_func)}")

        spec = decorated.tool_spec
        print(f"Required parameters: {spec['required']}")
        print(f"All parameters: {list(spec['properties'].keys())}")
        print(f"Schema: {json.dumps(spec, indent=2)}")
        print(divider)
|
|
|
|
|
def demonstrate_type_detection():
    """Run is_optional_type over sample annotations and print a pass/fail line for each."""
    print("=== Optional Type Detection Tests ===")
    print()

    # Each case: (annotation, expected result, human-readable label)
    cases = (
        (Optional[int], True, "Optional[int]"),
        (Union[int, None], True, "Union[int, None]"),
        (int, False, "int"),
        (str, False, "str"),
        (Union[str, int], False, "Union[str, int] (no None)"),
        ("Optional[str]", True, "String annotation: Optional[str]"),
        ("str | None", True, "String annotation: str | None"),
        ("int", False, "String annotation: int"),
    )

    for annotation, expected, description in cases:
        actual = is_optional_type(annotation)
        if actual == expected:
            mark = "✓"
        else:
            mark = "✗"
        print(f"{mark} {description}: {actual} (expected: {expected})")

    # Trailing blank line separates this report from later output.
    print()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: type-detection demo first, then the schema report.
    for entry_point in (demonstrate_type_detection, test_fix):
        entry_point()