Last active
September 14, 2025 14:53
-
-
Save dead8309/b276aca04c6f8450bd0b3e597a0c9f9e to your computer and use it in GitHub Desktop.
Scrappy implementation of Vercel's AI SDK in Python, using generics
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Callable | |
| from pydantic import BaseModel, Field | |
class Tool(BaseModel):
    """Untyped tool: name, description, JSON-schema parameters and an async callable.

    Nothing in this class ties ``parameters`` (the advertised schema) to what
    ``fn`` actually accepts or returns; any dict can be passed to ``execute``.
    """

    name: str
    description: str
    # JSON schema describing the arguments (e.g. ``Model.model_json_schema()``).
    parameters: dict[str, Any]
    # Async callable: called with the raw parameter dict, and its result awaited.
    fn: Callable[[dict[str, Any]], Any]

    async def execute(self, parameters: dict[str, Any]) -> Any:
        """Await ``fn(parameters)`` and return its result unchanged.

        The return type is ``Any``, not ``dict``: ``fn`` may return any object
        (for example a pydantic model), and nothing here converts it.
        """
        return await self.fn(parameters)
class WeatherParams(BaseModel):
    # Input for the weather tool; both coordinates are required strings.
    latitude: str = Field(description="latitude of place")
    longitude: str = Field(description="longitude of a place")
class WeatherResult(BaseModel):
    # Output of the weather tool; a single required float field.
    temperature: float = Field(description="current temperature of place")
async def weather_callback(parameters: dict) -> WeatherResult:
    """Validate ``parameters`` as WeatherParams and return a weather result.

    Stub: no lookup is performed; a hard-coded placeholder temperature is
    returned. Invalid input raises from ``model_validate``.
    """
    # NOTE: fetching real data is not implemented yet. The validation call is
    # kept purely for its side effect of rejecting malformed input.
    WeatherParams.model_validate(parameters)
    return WeatherResult(temperature=310)
# Untyped tool definition: the schema is generated from WeatherParams, but
# Tool itself does not tie that schema to weather_callback's signature.
weather_tool = Tool(
    name="get_weather",
    # Was "Gets a weather from some altitudes"; the inputs are latitude and
    # longitude (see WeatherParams), and this text is what the model reads.
    description="Gets the current weather for a latitude/longitude",
    parameters=WeatherParams.model_json_schema(),
    fn=weather_callback,
)
async def main():
    """Demo the untyped tool with a valid payload and then an invalid one."""
    lat_long = WeatherParams(latitude="2.97194", longitude="77.59369")
    r = await weather_tool.execute(lat_long.model_dump())
    print(r.temperature)

    # WARNING: This should not work. Tool.execute accepts any dict, so this
    # call type-checks; it is only rejected at runtime by weather_callback's
    # validation. The original never awaited the coroutine, so the call never
    # ran and Python warned "coroutine ... was never awaited".
    try:
        await weather_tool.execute({"anything": 123})
    except ValueError as err:  # pydantic's ValidationError subclasses ValueError
        print(f"rejected at runtime: {err}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import inspect | |
| from typing import Any, Awaitable, Callable, Generic, TypeVar, cast, get_type_hints | |
| from pydantic import BaseModel, Field | |
# Type of a tool's input (parameters) model; must be a pydantic model.
T = TypeVar("T", bound=BaseModel)
# Type of a tool's output (result) model; must be a pydantic model.
R = TypeVar("R", bound=BaseModel)
class Tool(BaseModel, Generic[T, R]):
    """A named, described async tool whose input and output are pydantic models.

    ``fn`` and ``p_schema`` are excluded from serialization. ``parameters``
    exposes the JSON schema of the input model ``p_schema``.
    """

    name: str = Field(...)
    description: str = Field(...)
    fn: Callable[[T], Awaitable[R]] = Field(..., exclude=True)
    # Input model class: source of the JSON schema and the validator for raw
    # (e.g. JSON-decoded) arguments passed to execute().
    p_schema: type[T] = Field(exclude=True)

    async def execute(self, parameters: "T | dict[str, Any]") -> R:
        """Validate ``parameters`` against ``p_schema`` and run the tool.

        Args:
            parameters: an instance of the input model, or a raw mapping (such
                as tool-call arguments decoded from JSON) to validate into one.

        Returns:
            The awaited result of ``fn``.

        Raises:
            pydantic.ValidationError: if ``parameters`` does not validate.
        """
        # model_validate returns an existing instance unchanged under the
        # default revalidate_instances="never", so typed callers are unaffected;
        # dicts are parsed and checked before they reach fn.
        validated = self.p_schema.model_validate(parameters)
        return await self.fn(validated)

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the input model, for advertising the tool."""
        return self.p_schema.model_json_schema()
def tool(
    name: str,
    description: str,
) -> Callable[[Callable[[T], Awaitable[R]]], Tool[T, R]]:
    """Build a decorator that turns an async function into a ``Tool``.

    The decorated function must take exactly one positional parameter,
    annotated with a pydantic model class. That class becomes the tool's
    parameter schema. The decorated name is bound to the ``Tool`` instance,
    not to the function.

    Args:
        name: tool name stored on the resulting ``Tool``.
        description: description stored on the resulting ``Tool``.

    Returns:
        A decorator mapping the function to a ``Tool`` instance.

    Raises:
        TypeError: (from the decorator) if the function does not take exactly
            one positional parameter, or if that parameter is not annotated
            with a pydantic model class.
    """

    def wrapper(fn: Callable[[T], Awaitable[R]]) -> Tool[T, R]:
        params = list(inspect.signature(fn).parameters.values())
        if len(params) != 1:
            raise TypeError("must take exactly one parameter")
        (param,) = params
        # Tool.execute calls fn(parameters) with one positional argument, so
        # keyword-only, *args or **kwargs parameters would only fail later, at
        # call time; reject them here instead.
        if param.kind not in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            raise TypeError("parameter must be positional")
        # .get(): an unannotated parameter used to surface as a bare KeyError;
        # None now falls through to the pydantic-model check below.
        p_T = get_type_hints(fn).get(param.name)
        if not (isinstance(p_T, type) and issubclass(p_T, BaseModel)):
            raise TypeError("parameter must be a pydantic model")
        typed_fn = cast(Callable[[T], Awaitable[R]], fn)
        typed_schema = cast(type[T], p_T)
        return Tool[T, R](
            name=name, description=description, fn=typed_fn, p_schema=typed_schema
        )

    return wrapper
class WeatherParams(BaseModel):
    # Input model for the weather tool; both coordinates are required strings.
    latitude: str = Field(description="latitude of place")
    longitude: str = Field(description="longitude of a place")
class WeatherResult(BaseModel):
    # Output model for the weather tool; a single required float field.
    temperature: float = Field(description="current temperature of place")
@tool(
    name="get_weather",
    # Was "Gets a weather from some altitudes"; the inputs are latitude and
    # longitude (see WeatherParams), and this text is what the model reads.
    description="Gets the current weather for a latitude/longitude",
)
async def WeatherTool(parameters: WeatherParams) -> WeatherResult:
    """Return the weather at the given coordinates.

    Stub: no lookup is performed; a hard-coded placeholder temperature is
    returned. The ``tool`` decorator binds ``WeatherTool`` to a ``Tool``.
    """
    # NOTE: fetch info...
    return WeatherResult(temperature=310)
# Output model returned by GithubTool.
class GithubUser(BaseModel):
    username: str
    user_id: str  # declared as a string, not an int
    followers: int
    following: int
# Input model for GithubTool; the tool decorator uses it as the parameter schema.
class GithubParams(BaseModel):
    username: str
@tool(name="get_github_info", description="Gets a github user information")
async def GithubTool(p: GithubParams) -> GithubUser:
    """Return a GithubUser for ``p.username``.

    Stub: only the username comes from the input; the id and follower counts
    are hard-coded placeholders. The ``tool`` decorator binds ``GithubTool`` to
    a ``Tool`` instance.
    """
    return GithubUser(
        username=p.username,
        user_id="12345",
        followers=1,
        following=2,
    )
async def main():
    """Demo: run the typed GitHub tool, then print its result and schema."""
    # execute() is typed to return GithubUser; the original bound the result
    # to an unused variable and discarded it.
    user = await GithubTool.execute(GithubParams(username="dead8309"))
    print(user)
    print(GithubTool.parameters)
if __name__ == "__main__":
    # Imported lazily so that importing this module has no asyncio dependency.
    from asyncio import run

    run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment