@intellectronica
Last active December 26, 2024 19:08
pydantic-ai-openai-strict.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"authorship_tag": "ABX9TyP3kCA/NmPkzL7PebGZ35z1",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/intellectronica/9b190aca94bf4372c4b08e8b016922ec/pydantic-ai-openai-strict.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tCSjwdW6sbf1"
},
"outputs": [],
"source": [
"%pip install pydantic_ai asyncpg opentelemetry-instrumentation-asyncpg logfire[asyncpg] nest_asyncio\n",
"from IPython.display import clear_output ; clear_output()"
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"from google.colab import userdata\n",
"\n",
"for secret in ['OPENAI_API_KEY', 'LOGFIRE_TOKEN']:\n",
" os.environ[secret] = userdata.get(secret)"
],
"metadata": {
"id": "02-od_5vsiOD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import nest_asyncio\n",
"\n",
"nest_asyncio.apply()"
],
"metadata": {
"id": "Jk0utnZitSKG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from openai import AsyncOpenAI\n",
"import logfire\n",
"\n",
"openai_client = AsyncOpenAI()\n",
"\n",
"_ = logfire.configure(console=False)\n",
"_ = logfire.instrument_openai(openai_client)\n",
"_ = logfire.instrument_asyncpg()"
],
"metadata": {
"id": "cY3Q1t9qt_pf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"##########################################################################\n",
"# Hack to get Pydantic AI using OpenAI structured outputs in strict mode #\n",
"##########################################################################\n",
"\n",
"from pydantic_ai.models.openai import OpenAIModel\n",
"from pydantic_ai.tools import ToolDefinition\n",
"from openai.types.chat import ChatCompletionToolParam\n",
"\n",
"class StrictOpenAIModel(OpenAIModel):\n",
" \"\"\"OpenAIModel with strict mode enabled.\n",
"\n",
" This class can be used instead of OpenAIModel to enable strict mode\n",
" for all tool calls, including any typed results.\n",
" \"\"\"\n",
" @staticmethod\n",
" def _map_tool_definition(f: ToolDefinition) -> ChatCompletionToolParam:\n",
" \"\"\"Redefinition of _map_tool_definition to enable strict mode.\n",
"\n",
" Function calls can use strict mode (guaraneeing adherence to the\n",
" defined schema) by specifying `strict: true` for the function, and by\n",
" ensuring that the parameters for the function explicitly include\n",
" `additionalProperties: false`.\n",
" \"\"\"\n",
" tool_def = OpenAIModel._map_tool_definition(f)\n",
" tool_def['function']['strict'] = True\n",
" tool_def['function']['parameters']['additionalProperties'] = False\n",
" return tool_def"
],
"metadata": {
"id": "QOXKqis10o2R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from textwrap import wrap\n",
"\n",
"from pydantic import BaseModel, Field\n",
"from pydantic_ai import Agent\n",
"\n",
"\n",
"class ElaborateResponse(BaseModel):\n",
" thoughts: str = Field(..., description=(\n",
" \"Step-by-step thought process and reasoning. \"\n",
" \"Approximately 300 tokens.\"),\n",
" )\n",
" answer: str = Field(\n",
" ..., description=\"Final answer to the question.\",\n",
" )\n",
"\n",
"openai_model = StrictOpenAIModel('gpt-4o', openai_client=openai_client)\n",
"agent = Agent(openai_model, result_type=ElaborateResponse)\n",
"result = await agent.run(\"How many 'r's are in the word strawberry?\")\n",
"\n",
"print('Answer:', result.data.answer)\n",
"print('\\n---\\n')\n",
"print('\\n'.join(wrap(result.data.thoughts)))"
],
"metadata": {
"id": "vYws9wUGuD-R"
},
"execution_count": null,
"outputs": []
}
]
}
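
To make the effect of the strict-mode override above concrete, here is a rough sketch of the tool definition shape it produces. The tool name, description, and schema fields below are illustrative placeholders, not output captured from the notebook; the strict and additionalProperties keys are the ones the override adds.

# Illustrative ChatCompletionToolParam produced by StrictOpenAIModel.
# The actual schema comes from the agent's result model; the name and
# description here are placeholders.
strict_tool_param = {
    'type': 'function',
    'function': {
        'name': 'final_result',
        'description': 'Final typed result (placeholder).',
        'strict': True,  # added by the override
        'parameters': {
            'type': 'object',
            'properties': {'answer': {'type': 'string'}},
            'required': ['answer'],
            'additionalProperties': False,  # added by the override
        },
    },
}
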
@samuelcolvin

Nice! You can probably simplify _map_tool_definition to

def _map_tool_definition(f: ToolDefinition) -> ChatCompletionToolParam:
    tool_param = super()._map_tool_definition(f)
    tool_param['function']['strict'] = True
    return tool_param

@intellectronica
Copy link
Author

@samuelcolvin Almost. super() can't be used like this, because _map_tool_definition is a staticmethod and zero-argument super() has no instance to bind to (see the short sketch after the snippet below). Instead we can call the method on the parent class explicitly. And we also need to set additionalProperties to false.

This works:

  @staticmethod
  def _map_tool_definition(f: ToolDefinition) -> ChatCompletionToolParam:
      tool_def = OpenAIModel._map_tool_definition(f)
      tool_def['function']['strict'] = True
      tool_def['function']['parameters']['additionalProperties'] = False
      return tool_def
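
For illustration, here is a minimal sketch of the distinction (the Base/StrictSub names are hypothetical, not from the gist): zero-argument super() needs the method's first argument to be an instance or subclass of the defining class, which a staticmethod that only receives the tool definition can't supply, so calling the parent class directly is the simplest fix.

class Base:
    @staticmethod
    def describe(x: int) -> str:
        return f'value={x}'

class StrictSub(Base):
    @staticmethod
    def describe(x: int) -> str:
        # super().describe(x) would raise TypeError here: zero-argument
        # super() looks for an instance (or subclass) of StrictSub in the
        # method's first argument, and this staticmethod only receives x.
        # Calling the parent class explicitly works:
        return Base.describe(x) + ' [strict]'

print(StrictSub.describe(3))  # value=3 [strict]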

@intellectronica
Author

Gist updated. Thanks for the suggestion; this is more elegant and less duplicative.
