Skip to content

Instantly share code, notes, and snippets.

@ergo70
Created August 20, 2025 10:39
Show Gist options
  • Select an option

  • Save ergo70/2999f4578765a4304d951db9ce191acd to your computer and use it in GitHub Desktop.

Select an option

Save ergo70/2999f4578765a4304d951db9ce191acd to your computer and use it in GitHub Desktop.
Symbolic math tool calling with Pydantic AI and SymPy
from __future__ import annotations

import re
from typing import Dict

from pydantic_ai import Agent, ModelRetry
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from sympy import SympifyError, diff, symbols, sympify
def _normalize_expr_text(expr_text: str) -> str:
"""Lightweight normalizer so inputs like "f(x)=x^2." become "x**2" for SymPy.
- Keeps only the RHS if there's an equals sign
- Strips trailing punctuation
- Converts caret-power to Python power (**)
"""
s = expr_text.strip()
if "=" in s:
# take the right-hand side only
s = s.split("=", 1)[1]
# remove f(x) prefix remnants like "f(x) ="
s = re.sub(r"^[a-zA-Z_]\w*\s*\(.*?\)\s*", "", s).strip()
s = s.rstrip(".;: ")
s = s.replace("^", "**")
return s
# Chat model reached through an OpenAI-compatible endpoint on localhost.
# Judging by the variable name and port, this is presumably a local Ollama
# server; the model name must exist on that server.
ollama_model = OpenAIModel(
    model_name="qwen2.5:32b",
    provider=OpenAIProvider(base_url="http://127.0.0.1:11434/v1"),
)
# --- Agent setup ---
# The system prompt forces every differentiation through the SymPy tool
# (registered below) and fixes the notation of the final answer.
agent = Agent(
    ollama_model,
    system_prompt=(
        "You are a calculus helper. "
        "For any differentiation request, you MUST call the `sympy_derivative` tool "
        "rather than doing math yourself.\n"
        "In the final answer: write plain text using '^' for exponents (e.g. x^2) "
        "and '*' for multiplication.\n"
        "If the user defines f(x)=..., mirror that style in the result as f'(x)=...\n"
    ),
)
@agent.tool_plain
def sympy_derivative(expression: str, var: str = "x", order: int = 1) -> Dict[str, str]:
    """Compute the n-th derivative of an expression using SymPy.

    Args:
        expression: The function expression (e.g. "x^2", "f(x)=x^3 + 2*x").
        var: Variable to differentiate with respect to (default: "x").
        order: Derivative order (>=1), default 1.

    Returns:
        A JSON object with keys:
        - derivative: the derivative as a SymPy string (e.g. "2*x")
        - normalized_expression: normalized input expression (e.g. "x**2")
        - var: the variable used
        - order: the order used (as string)

    Raises:
        ModelRetry: if var is not a single variable name, order is below 1,
            or the expression cannot be parsed.
    """
    print(
        f"Calculating {order}-th derivative of '{expression}' with respect to '{var}'...")

    # Validate arguments up front. ModelRetry returns the message to the model
    # so it can call the tool again with corrected arguments, instead of an
    # opaque SymPy error aborting the whole agent run.
    var = var.strip()
    if not var.isidentifier():
        # e.g. "x, y" would make symbols() return a tuple, not one Symbol.
        raise ModelRetry(
            f"'var' must be a single variable name such as 'x', got {var!r}.")
    order = int(order)
    if order < 1:
        raise ModelRetry(f"'order' must be an integer >= 1, got {order}.")

    var_sym = symbols(var)
    rhs = _normalize_expr_text(expression)
    # SECURITY: sympify() parses strings via eval(). `expression` comes from the
    # model, and ultimately from the user prompt, so do not expose this tool to
    # untrusted input without sandboxing or a stricter parser.
    try:
        expr = sympify(rhs, locals={var: var_sym})
    except SympifyError as exc:
        raise ModelRetry(
            f"Could not parse {expression!r} (normalized to {rhs!r}) as a "
            f"SymPy expression: {exc}"
        ) from exc
    deriv = diff(expr, var_sym, order)
    return {
        "derivative": str(deriv),
        "normalized_expression": str(expr),
        "var": var,
        "order": str(order),
    }
if __name__ == "__main__":
    # Demo run: the system prompt should make the model route this request
    # through the `sympy_derivative` tool; print only the final answer text.
    question = "Please give me the first derivative of f(x)=x^3"
    print(agent.run_sync(question).output)
    # Expected: "The first derivative of f(x)=x^3 is f'(x)=3*x^2."
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment