Created
August 20, 2025 10:39
-
-
Save ergo70/2999f4578765a4304d951db9ce191acd to your computer and use it in GitHub Desktop.
Symbolic math tool calling with Pydantic AI and SymPy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import os
import re
from typing import Dict

from pydantic_ai import Agent, ModelRetry
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from sympy import symbols, sympify, diff
| def _normalize_expr_text(expr_text: str) -> str: | |
| """Lightweight normalizer so inputs like "f(x)=x^2." become "x**2" for SymPy. | |
| - Keeps only the RHS if there's an equals sign | |
| - Strips trailing punctuation | |
| - Converts caret-power to Python power (**) | |
| """ | |
| s = expr_text.strip() | |
| if "=" in s: | |
| # take the right-hand side only | |
| s = s.split("=", 1)[1] | |
| # remove f(x) prefix remnants like "f(x) =" | |
| s = re.sub(r"^[a-zA-Z_]\w*\s*\(.*?\)\s*", "", s).strip() | |
| s = s.rstrip(".;: ") | |
| s = s.replace("^", "**") | |
| return s | |
# Chat model served by a local Ollama instance through its OpenAI-compatible
# endpoint. Both the model name and the server URL can be overridden with the
# OLLAMA_MODEL / OLLAMA_BASE_URL environment variables; the defaults match the
# original hard-coded values.
ollama_model = OpenAIModel(
    model_name=os.environ.get("OLLAMA_MODEL", "qwen2.5:32b"),
    provider=OpenAIProvider(
        base_url=os.environ.get("OLLAMA_BASE_URL", "http://127.0.0.1:11434/v1")
    ),
)
# --- Define the agent ---
# The system prompt forbids the model from doing calculus on its own: every
# differentiation request must be routed through the `sympy_derivative` tool.
# It also pins the notation of the final answer ('^' for powers, '*' for
# products, and an f'(x)=... form when the user wrote f(x)=...).
agent = Agent(
    model=ollama_model,
    system_prompt="".join(
        (
            "You are a calculus helper. For any differentiation request, you MUST call the ",
            "`sympy_derivative` tool rather than doing math yourself.\n",
            "In the final answer: write plain text using '^' for exponents (e.g. x^2) and '*' for multiplication.\n",
            "If the user defines f(x)=..., mirror that style in the result as f'(x)=...\n",
        )
    ),
)
@agent.tool_plain
def sympy_derivative(expression: str, var: str = "x", order: int = 1) -> Dict[str, str]:
    """Compute the n-th derivative of an expression using SymPy.

    Args:
        expression: The function expression (e.g. "x^2", "f(x)=x^3 + 2*x").
        var: Variable to differentiate with respect to (default: "x").
        order: Derivative order (>=1), default 1.

    Returns:
        A JSON object with keys:
        - derivative: the derivative as a SymPy string (e.g. "2*x")
        - normalized_expression: normalized input expression (e.g. "x**2")
        - var: the variable used
        - order: the order used (as string)

    Raises:
        ModelRetry: if `var` is not a single variable name, `order` is below 1,
            or SymPy cannot parse/differentiate `expression`. Pydantic AI sends
            the message back to the model so it can correct its tool call
            instead of the whole run failing.
    """
    print(
        f"Calculating {order}-th derivative of '{expression}' with respect to '{var}'...")

    var = var.strip()
    # symbols() would silently return a tuple for inputs like "x, y".
    if not var.isidentifier():
        raise ModelRetry(
            f"`var` must be a single variable name such as 'x', got {var!r}.")
    n = int(order)
    if n < 1:
        raise ModelRetry(f"`order` must be >= 1, got {order}.")

    var_sym = symbols(var)
    rhs = _normalize_expr_text(expression)
    # SECURITY NOTE: sympify() evaluates its input with eval(), and this text
    # comes from the LLM (and indirectly the user). Only expose this tool in a
    # trusted, local setting.
    try:
        expr = sympify(rhs, locals={var: var_sym})
        deriv = diff(expr, var_sym, n)
    except (TypeError, ValueError) as exc:  # SympifyError subclasses ValueError
        raise ModelRetry(
            f"Could not differentiate {rhs!r} with respect to {var!r}: {exc}") from exc

    return {
        "derivative": str(deriv),
        "normalized_expression": str(expr),
        "var": var,
        "order": str(order),
    }
if __name__ == "__main__":
    # Demo run: the agent should delegate the derivative to `sympy_derivative`
    # and then phrase the result in the notation set by the system prompt.
    result = agent.run_sync("Please give me the first derivative of f(x)=x^3")
    print(result.output)
    # Expected: "The first derivative of f(x)=x^3 is f'(x)=3*x^2."
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment