**OpenAI JSON Schema Sanitizer for Pydantic Models** - A production-ready function that transforms any Pydantic model into an OpenAI Structured Outputs-compatible JSON schema, handling optionals, unions, recursion detection, numeric constraints, and additionalProperties issues that cause API failures. Includes comprehensive test suite covering a…
"""
File: gv/ai/common/llm/json_schema.py
Author: Aviad Rozenhek
OpenAI Structured Outputs (`response_format={"type":"json_schema"}`) supports only a subset of JSON Schema.
Many perfectly valid Pydantic constructs won't fly as-is. Use these patterns:
-------------------------------------------------------------------------
1) Optional / nullable / Default fields
-------------------------------------------------------------------------
❌ Pydantic:
class M(BaseModel):
    title: Optional[str] = None  # -> anyOf: [{"type": "string"}, {"type": "null"}]
✅ OpenAI-friendly:
# Make it required, represent "no value" explicitly (e.g., empty string).
class MDTO(BaseModel):
title: str # required in schema; post-parse, treat "" as "missing"
# OR split presence into a separate boolean or sentinel:
class MDTO(BaseModel):
has_title: bool
title: str # required; if has_title is False, expect "" and fix app-side
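# App-side sketch (assumption: "" is the agreed "missing" sentinel):
def title_or_none(dto: MDTO) -> str | None:
    return dto.title or None  # map the "" sentinel back to None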
-------------------------------------------------------------------------
2) Unions / anyOf / oneOf / allOf
-------------------------------------------------------------------------
❌ Pydantic:
class M(BaseModel):
score: int | float # -> anyOf: [{"type":"integer"},{"type":"number"}]
payload: User | Org # -> oneOf/allOf + $ref
✅ OpenAI-friendly:
# Pick a single, widest type and normalize app-side.
class MDTO(BaseModel):
score: float
payload_kind: Literal["user","org"]
payload_user: UserDTO # keep one branch
# (or flatten the shared fields; resolve shape app-side)
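# App-side sketch for the tagged shape above (names illustrative):
def resolve_payload(dto: MDTO):
    if dto.payload_kind == "user":
        return dto.payload_user
    # the "org" branch was dropped from the DTO; rebuild Org from shared fields here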
-------------------------------------------------------------------------
3) Numeric bounds (ge / le / conint / confloat)
-------------------------------------------------------------------------
❌ Pydantic:
class M(BaseModel):
frac: condecimal(ge=0, le=1)
rating: float = Field(ge=0, le=5)
✅ OpenAI-friendly:
class MDTO(BaseModel):
frac: float # add "0–1" note in description; clamp/validate after parsing
rating: float
# The sanitizer moves min/max into the description and strips the keys.
# Enforce ranges in Python after you get the model's output.
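# App-side sketch: clamp/validate once the model's output is parsed.
def clamp(x: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, x))
# frac = clamp(dto.frac, 0.0, 1.0); rating = clamp(dto.rating, 0.0, 5.0)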
-------------------------------------------------------------------------
4) Enums
-------------------------------------------------------------------------
✅ Supported *if* they are simple string/number enums **inline**:
class MDTO(BaseModel):
status: Literal["OPEN","CLOSED","PENDING"]
⚠️ Avoid deep `$ref` indirection for enums if possible; inline them (the sanitizer inlines local $refs).
-------------------------------------------------------------------------
5) Additional properties (free-form maps)
-------------------------------------------------------------------------
✅ Strict by default (recommended for LLMs):
# Sanitizer can set: "additionalProperties": false (opt-in flag)
# This prevents the model from inventing fields.
✅ If you truly need a bag/dict:
class MDTO(BaseModel):
metadata: Dict[str, Any] # allow extras under this object
# Use sanitizer options/exceptions to keep additionalProperties allowed for this one field.
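# Behavior of this sanitizer: with disallow_additional_props=True, structured
# objects get "additionalProperties": false, while a free-form Dict field keeps
# a typed additionalProperties schema instead (see the walker below).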
-------------------------------------------------------------------------
How this function helps - it transforms a schema to be OpenAI-friendly:
------------------------------------------------------------------------
- Represents optionals as required + nullable (and, with require_all=True, forces every property required), avoiding nullable/union traps.
- Strips defaults/bounds and moves bounds into descriptions.
- Inlines local $refs and collapses anyOf/oneOf/allOf to a single branch.
- (Optionally) sets additionalProperties:false to curb hallucinated keys.
- Preserves Pydantic field order (properties and 'required' list).
- ERRORS on recursive models (OpenAI doesn't support them)
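Minimal call (sketch):
    schema = sanitize_for_openai_schema(MyModel, disallow_additional_props=True)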
"""
from copy import deepcopy
from typing import Any, Dict, get_args, get_origin
from loguru import logger
from pydantic import BaseModel
DEBUG = 0
if DEBUG:
logger_trace = logger.debug
else:
logger_trace = lambda *args, **kwargs: None # noqa: E731
# ----------------------------- Preflight helpers -----------------------------
def _is_optional_annotation(tp) -> bool:
"""
True for Optional[T] or (T | None) on all Python versions.
Works for typing.Union and PEP 604 unions (types.UnionType).
"""
try:
return type(None) in get_args(tp)
except Exception:
return False
def _iter_submodels_from_annotation(tp, visited=None):
"""
Yield BaseModel subclasses found inside typing annotations (recurse Lists, Unions, etc.).
"""
if visited is None:
visited = set()
if tp is None:
return
# Prevent infinite recursion for self-referential types
type_id = id(tp)
if type_id in visited:
return
visited.add(type_id)
origin = get_origin(tp)
if origin is None:
# bare type
if isinstance(tp, type) and issubclass(tp, BaseModel):
yield tp
return
args = get_args(tp)
for a in args:
if isinstance(a, type) and issubclass(a, BaseModel):
yield a
else:
yield from _iter_submodels_from_annotation(a, visited)
def _check_for_recursion(model_class, visited=None, path=None):
"""
Check if a model contains recursive references.
Raises ValueError if recursion is detected.
"""
if visited is None:
visited = set()
if path is None:
path = []
model_id = id(model_class)
if model_id in visited:
# Found recursion - build error message
cycle_path = " -> ".join(path + [model_class.__name__])
raise ValueError(
f"Recursive model detected: {model_class.__name__}\n"
f"Recursion path: {cycle_path}\n"
f"OpenAI's structured outputs do NOT support recursive schemas or $ref.\n"
f"Consider redesigning your model to avoid recursion, such as:\n"
f" - Using a flat structure with IDs to reference related objects\n"
f" - Limiting nesting to a fixed depth with different model types\n"
f" - Using a string field to store serialized nested data"
)
visited.add(model_id)
new_path = path + [model_class.__name__]
for field_name, field_info in model_class.model_fields.items():
annotation = field_info.annotation
# Check if the field references other BaseModel subclasses
for submodel in _iter_submodels_from_annotation(annotation):
_check_for_recursion(submodel, visited.copy(), new_path)
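# Hedged sketch of the "flat structure with IDs" alternative suggested in the
# error message above (illustrative names, not part of this module):
#
#   class FlatNode(BaseModel):
#       id: str
#       parent_id: str  # "" for the root; rebuild the tree app-side
#       value: str
#
#   class FlatTree(BaseModel):
#       nodes: list[FlatNode]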
def _preflight_model_consistency(
model: type[BaseModel],
*,
errors: list[str],
warnings: list[str],
path: str = "",
warn_on_nullable_description_mismatch: bool = True,
    visited: set | None = None,
) -> None:
"""
Collect preflight issues that cannot be fully 'fixed' by schema sanitization alone.
- ERROR: Non-optional type with default=None
- WARN : 'if ... then ... else' description (OpenAI subset can't encode)
- WARN : description mentions 'null/none' but field is non-optional (likely non-nullable in strict schema)
Recurse into nested BaseModels discovered via annotations.
"""
if visited is None:
visited = set()
# Prevent infinite recursion on self-referential models
if model in visited:
return
visited.add(model)
prefix = f"{path}." if path else ""
for name, f in model.model_fields.items():
ann = getattr(f, "annotation", None)
default = getattr(f, "default", ...)
desc = (getattr(f, "description", "") or "").lower()
# 1) ERROR: Non-optional annotation but default=None (Pydantic will reject None at runtime)
if default is None and ann is not None and not _is_optional_annotation(ann):
errors.append(
f"{prefix}{model.__name__}.{name}: non-optional type {getattr(ann, '__name__', ann)} "
f"has default=None (unsatisfiable without changing model or post-processing)."
)
# 2) WARN: if/then/else prose requirements
if " if " in f" {desc} " and " then " in f" {desc} " and " else " in f" {desc} ":
warnings.append(
f"{prefix}{model.__name__}.{name}: description encodes conditional requirement "
f"('if/then/else'); OpenAI strict JSON Schema cannot express that."
)
# 3) WARN: 'null' or 'none' mentioned in description but field is non-optional
if warn_on_nullable_description_mismatch and ((" null" in f" {desc}") or (" none" in f" {desc}")):
if ann is not None and not _is_optional_annotation(ann):
warnings.append(
f"{prefix}{model.__name__}.{name}: description mentions null/none but annotation is non-optional; "
f"field will be non-nullable in strict schema."
)
# Recurse into submodels referenced by this field's type
for sub in _iter_submodels_from_annotation(ann):
_preflight_model_consistency(
sub,
errors=errors,
warnings=warnings,
path=f"{prefix}{model.__name__}.{name}",
warn_on_nullable_description_mismatch=warn_on_nullable_description_mismatch,
visited=visited,
)
# --------------------------- Schema sanitizer core ---------------------------
def sanitize_for_openai_schema(
model: type[BaseModel],
*,
require_all: bool = False,
make_optionals_required_nullable: bool = True,
strip_defaults: bool = True,
strip_numeric_bounds: bool = True,
move_bounds_to_description: bool = True,
inline_refs: bool = True,
disallow_additional_props: bool = True,
# NEW: diagnostics
error_on_preflight: bool = True,
warn_on_nullable_description_mismatch: bool = True,
# NEW: behavior for enum unions
merge_enums_even_with_string: bool = True,
) -> Dict[str, Any]:
"""
Produce an OpenAI-structured-outputs-friendly JSON Schema from a Pydantic model.
Enhancements:
- Preflight: errors/warnings for unsatisfiable or risky patterns (see parameters).
- Optionals → required + nullable (OpenAI strict style).
- Union collapse: merges enum branches; if a 'string' branch exists and merge_enums_even_with_string=True,
we still merge enums and discard the free-form string branch (preserving nullability), and annotate.
- Inlines local $refs and collapses anyOf/oneOf/allOf to a single branch.
- (Optionally) sets additionalProperties:false to curb hallucinated keys.
- Preserves Pydantic field order (properties and 'required' list).
- ERRORS on recursive models (OpenAI doesn't support them)
"""
# Check for recursive models FIRST - fail fast
try:
_check_for_recursion(model)
except ValueError as e:
logger.error(f"Schema validation failed: {e}")
raise
# -------- Preflight consistency checks --------
errors: list[str] = []
warnings: list[str] = []
_preflight_model_consistency(
model,
errors=errors,
warnings=warnings,
warn_on_nullable_description_mismatch=warn_on_nullable_description_mismatch,
)
for w in warnings:
logger.warning(f"[schema preflight] {w}")
if errors:
for e in errors:
logger.error(f"[schema preflight] {e}")
if error_on_preflight:
raise ValueError("Schema/model consistency issues:\n- " + "\n- ".join(errors))
# -------- Generate raw schema and sanitize --------
raw = model.model_json_schema()
schema = deepcopy(raw)
defs = schema.get("$defs") or schema.get("definitions") or {}
def _append_desc(node: Dict[str, Any], extra: str) -> None:
if not extra:
return
if "description" in node and isinstance(node["description"], str):
node["description"] = node["description"].rstrip() + f"\n\n{extra}"
else:
node["description"] = extra
def _strip_bounds_and_move_description(node: Dict[str, Any]) -> None:
if not strip_numeric_bounds:
return
string_constraints = []
numeric_constraints = []
# Strip string constraints (pattern, minLength, maxLength)
for k in ("pattern", "minLength", "maxLength"):
if k in node:
val = node.pop(k)
string_constraints.append(f"{k}={val}")
# Numeric constraints
for k in ("minimum", "exclusiveMinimum", "maximum", "exclusiveMaximum"):
if k in node:
val = node.pop(k)
numeric_constraints.append(f"{k.replace('exclusive', 'exclusive ')}={val}")
# Add descriptions for different types of constraints
if move_bounds_to_description:
if string_constraints:
_append_desc(node, "String constraints (enforced app-side): " + ", ".join(string_constraints) + ".")
if numeric_constraints:
_append_desc(node, "Numeric constraints (enforced app-side): " + ", ".join(numeric_constraints) + ".")
def _ensure_nullable_type(node: Dict[str, Any]) -> None:
"""
Ensure node['type'] includes 'null' (do NOT add null into 'enum').
"""
t = node.get("type")
if "enum" in node and not t:
# infer primitive type for enum
if all(isinstance(v, (int, float)) for v in node["enum"]):
node["type"] = "number"
else:
node["type"] = "string"
t = node["type"]
if isinstance(t, list):
if "null" not in t:
node["type"] = t + ["null"]
elif isinstance(t, str):
if t != "null":
node["type"] = [t, "null"]
else:
node["type"] = ["string", "null"]
else:
node["type"] = ["string", "null"]
def _get_enum_type(enum_values):
"""
Determine the type of enum values.
"""
if all(isinstance(v, bool) for v in enum_values):
return "boolean"
elif all(isinstance(v, int) and not isinstance(v, bool) for v in enum_values):
return "integer"
elif all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in enum_values):
return "number"
else:
return "string"
def _inline_ref(node: Dict[str, Any], visited_refs=None) -> Dict[str, Any]:
if visited_refs is None:
visited_refs = set()
ref = node.get("$ref")
if not ref or not inline_refs:
return node
# We should never hit recursive refs because we checked earlier
# But if we somehow do, error out
if ref in visited_refs:
raise ValueError(
f"Unexpected recursive reference found during inlining: {ref}\n"
f"This should have been caught earlier. The model likely has recursive structure "
f"which is not supported by OpenAI's structured outputs."
)
if not ref.startswith("#/"):
node.pop("$ref", None) # external refs unsupported
return node
visited_refs_copy = visited_refs.copy()
visited_refs_copy.add(ref)
parts = ref.lstrip("#/").split("/")
cur: Any = schema
for p in parts:
if isinstance(cur, dict) and p in cur:
cur = cur[p]
else:
cur = defs.get(p, {})
inlined = deepcopy(cur) if isinstance(cur, dict) else {}
# Preserve any additional properties from the original node (like description)
for key in node:
if key != "$ref" and key not in inlined:
inlined[key] = node[key]
return _walk(inlined, visited_refs=visited_refs_copy)
def _collapse_alternatives(node: Dict[str, Any]) -> None:
"""
Handle anyOf/oneOf/allOf:
- Merge enum branches into a single enum (even if a 'string' branch is present, when enabled).
- Preserve nullability if any alt is {"type":"null"}.
- Otherwise keep first alternative and annotate.
"""
for key in ("anyOf", "oneOf", "allOf"):
alts = node.get(key)
if not (isinstance(alts, list) and alts):
continue
merged_enums: list[Any] = []
enum_type: str | None = None
saw_null = False
saw_string_branch = False
resolved: list[Dict[str, Any]] = []
def _resolve(a: Dict[str, Any]) -> Dict[str, Any]:
return _inline_ref(a) if (isinstance(a, dict) and "$ref" in a) else a
for a in alts:
a = a if isinstance(a, dict) else {}
a = _resolve(a)
resolved.append(a)
t = a.get("type")
if t == "null":
saw_null = True
continue
if "enum" in a:
current_enum_type = _get_enum_type(a["enum"])
# Only merge enums if they're compatible types or if we haven't set a type yet
if not enum_type:
enum_type = current_enum_type
merged_enums.extend(a["enum"])
elif enum_type == current_enum_type:
merged_enums.extend(a["enum"])
elif enum_type in ["integer", "number"] and current_enum_type in ["integer", "number"]:
# Allow merging integer and number enums, use number as the type
enum_type = "number"
merged_enums.extend(a["enum"])
# Otherwise skip incompatible enum types
continue
if t == "string":
saw_string_branch = True
# non-enum branch remains
node.pop(key, None)
if merged_enums:
node.pop("$ref", None)
# If there's a string branch and the toggle is on, we still merge enums and discard free-form string
node["type"] = enum_type or "string"
# Don't sort enums, preserve order and uniqueness
seen = set()
unique_vals = []
for v in merged_enums:
if v not in seen:
seen.add(v)
unique_vals.append(v)
node["enum"] = unique_vals
if saw_null:
_ensure_nullable_type(node)
if saw_string_branch and merge_enums_even_with_string:
_append_desc(
node,
"Note: union included a free-form string branch; sanitizer merged enum branches "
"and discarded the open string to keep values constrained.",
)
return
# Fallback: keep first resolved alt (and preserve nullability if present)
first = resolved[0] if resolved else {}
for k, v in first.items():
if k not in node:
node[k] = v
if any(a.get("type") == "null" for a in resolved):
_ensure_nullable_type(node)
_append_desc(
node,
f"Note: original schema had `{key}` with {len(alts)} alternatives; "
f"sanitizer kept the first and preserved nullability.",
)
return
# Track warnings for "null/none mentioned but field is non-nullable" during walking
def _walk(node: Any, path: str = "<root>", visited_refs=None) -> Any:
if visited_refs is None:
visited_refs = set()
logger_trace(f"[ENTER] {path}: dict={isinstance(node, dict)}")
if isinstance(node, dict):
if "$ref" in node and inline_refs:
logger_trace(f"[{path}] Resolving $ref: {node['$ref']}")
return _inline_ref(node, visited_refs)
# Strip defaults / format / examples etc.
if strip_defaults and "default" in node:
node.pop("default", None)
if "format" in node:
node.pop("format", None)
for k in ("examples", "example", "readOnly", "writeOnly", "deprecated"):
node.pop(k, None)
# Convert "const" to "enum" for compatibility
if "const" in node:
const_value = node.pop("const")
node["enum"] = [const_value]
# Ensure type is set if not already
if "type" not in node:
if isinstance(const_value, bool):
node["type"] = "boolean"
elif isinstance(const_value, int):
node["type"] = "integer"
elif isinstance(const_value, float):
node["type"] = "number"
else:
node["type"] = "string"
_collapse_alternatives(node) # may rewrite node in place
_strip_bounds_and_move_description(node)
# Check if this is an array type (including nullable arrays)
node_type = node.get("type")
is_array = False
if node_type == "array":
is_array = True
elif isinstance(node_type, list) and "array" in node_type:
is_array = True
if is_array:
if "items" in node:
node["items"] = _walk(node["items"], f"{path}.items", visited_refs)
node.setdefault("minItems", 0)
# Don't return here, continue processing
# Objects - handle both pure object and nullable object types
is_object = False
is_nullable_object = False
if node_type == "object":
is_object = True
elif isinstance(node_type, list) and "object" in node_type:
is_object = True
is_nullable_object = "null" in node_type
if is_object:
props = node.get("properties", {}) or {}
# Preserve original required (from Pydantic) to judge optionality
original_required = set(node.get("required", []))
# Walk children first
for k in list(props.keys()):
props[k] = _walk(props[k], f"{path}.properties[{k}]", visited_refs)
node["properties"] = props
# Warn if description mentions null/none for a non-optional field (simple heuristic)
if warn_on_nullable_description_mismatch and original_required:
for k in original_required:
child = props.get(k)
if not isinstance(child, dict):
continue
desc = (child.get("description") or "").lower()
if (" null" in f" {desc}") or (" none" in f" {desc}"):
# This field is non-optional in Pydantic, so we won't add nullability
logger.warning(
f"[schema warning] {path}.properties[{k}]: description mentions null/none "
f"but field is required by Pydantic (non-nullable in strict schema)."
)
# Start from Pydantic's required list
required = list(original_required)
# Emulate optionals as required+nullable
if make_optionals_required_nullable and props:
for k, p in props.items():
if k not in original_required:
required.append(k)
_ensure_nullable_type(p)
# If global require_all, override
if require_all and props:
required = list(props.keys())
if required:
ordered = [k for k in props.keys() if k in required]
node["required"] = ordered
# Handle additionalProperties - CRITICAL FIX for nullable objects
if "additionalProperties" in node:
add_props = node["additionalProperties"]
# If it's a schema (dict), ensure it has proper handling
if isinstance(add_props, dict):
# Walk the additionalProperties schema first
add_props = _walk(add_props, f"{path}.additionalProperties", visited_refs)
# For nullable object types, OpenAI requires additionalProperties: false
if is_nullable_object:
node["additionalProperties"] = False
else:
# Handle different cases for additionalProperties
if not add_props:
# Empty dict - needs a type for OpenAI
add_props = {"type": "string"}
elif "type" not in add_props and not add_props.get("$ref"):
# Has content but no type - infer or default
if "properties" in add_props:
add_props["type"] = "object"
elif "items" in add_props:
add_props["type"] = "array"
elif "enum" in add_props:
# Infer type from enum values
add_props["type"] = _get_enum_type(add_props["enum"])
else:
# Default to string for any untyped schema
add_props["type"] = "string"
node["additionalProperties"] = add_props
elif add_props is True:
# For nullable objects, must be false
if is_nullable_object:
node["additionalProperties"] = False
else:
# Convert true to a schema with type
node["additionalProperties"] = {"type": "string"}
# If false, leave as-is
elif disallow_additional_props:
# Always set to false for nullable objects
if is_nullable_object:
node["additionalProperties"] = False
elif props:
# Structured object with properties
node["additionalProperties"] = False
else:
# No properties means this is a free-form dict like Dict[str, Any]
# For non-nullable free-form dicts, allow with typed schema
node["additionalProperties"] = {"type": "string"}
return node
# Recurse other dict fields
for k, v in list(node.items()):
node[k] = _walk(v, f"{path}.{k}", visited_refs)
return node
elif isinstance(node, list):
return [_walk(v, f"{path}[{i}]", visited_refs) for i, v in enumerate(node)]
return node
sanitized = _walk(schema)
# Remove defs after inlining
sanitized.pop("$defs", None)
sanitized.pop("definitions", None)
# Top-level object finalization + root-level required+nullable catch-up
if sanitized.get("type") == "object":
props = sanitized.get("properties", {}) or {}
if require_all and props:
sanitized["required"] = list(props.keys())
else:
if make_optionals_required_nullable and props:
req = list(sanitized.get("required", []))
present = set(req)
changed = False
for k in props.keys():
if k not in present:
req.append(k)
_ensure_nullable_type(props[k])
changed = True
if changed:
sanitized["required"] = req
# Final handling of additionalProperties at root level
if "additionalProperties" in sanitized:
add_props = sanitized["additionalProperties"]
# Ensure it has proper type if it's a dict
if isinstance(add_props, dict):
if not add_props:
# Empty dict needs a type
sanitized["additionalProperties"] = {"type": "string"}
elif "type" not in add_props and not add_props.get("$ref"):
# Has content but no type
if "properties" in add_props:
add_props["type"] = "object"
else:
add_props["type"] = "string"
sanitized["additionalProperties"] = add_props
elif disallow_additional_props:
if props:
sanitized["additionalProperties"] = False
else:
# Empty model needs typed additionalProperties for OpenAI
sanitized["additionalProperties"] = {"type": "string"}
return sanitized
# -------- Usage examples --------
# 1) OpenAI SDK (Chat Completions with structured outputs)
# schema = sanitize_for_openai_schema(MessageQualityPromptSchema)
# client.chat.completions.create(
#     model="gpt-4o-mini",
#     messages=[...],  # your prompt/messages
#     response_format={"type": "json_schema",
#                      "json_schema": {"name": "MessageQualityPromptSchema",
#                                      "schema": schema}},
# )
# 2) LangChain: bind the raw response_format
# from langchain_openai import ChatOpenAI
# llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# schema = sanitize_for_openai_schema(MessageQualityPromptSchema)
# llm = llm.bind(response_format={"type":"json_schema",
# "json_schema":{"name":"MessageQualityPromptSchema","schema":schema}})
# result = llm.invoke("...your input...")
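# 3) Post-parse enforcement (sketch): the sanitizer strips bounds/defaults, so
#    re-validate the model's JSON against the ORIGINAL Pydantic model to
#    re-apply ge/le/pattern constraints app-side:
# import json
# data = json.loads(response_text)
# obj = MessageQualityPromptSchema.model_validate(data)  # raises if out of range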
# tests/test_sanitizer_and_safe_structured_real.py
import enum
import os
from typing import Any, Dict, List, Literal, Optional, Union
import pytest
import pytest_asyncio
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from gv.ai.common.llm.json_schema import sanitize_for_openai_schema
from gv.ai.common.llm.llm_chain_base import safe_with_structured_output
from gv.ai.common.llm.llm_client_builder import get_langchain_model_and_token_counter
load_dotenv()
# ------------------ Small helpers for strict JSON Schema expectations ------------------
def _types(node: Dict[str, Any]) -> set[str]:
t = node.get("type")
return set(t) if isinstance(t, list) else {t} if t is not None else set()
def _assert_base_type(node: Dict[str, Any], expected: str):
# Accept required+nullable by ignoring "null"
assert expected in (_types(node) - {"null"}), f"expected base type {expected}, got {_types(node)}"
def _assert_nullable(node: Dict[str, Any], should_be_nullable: bool = True):
has_null = "null" in _types(node)
assert has_null if should_be_nullable else not has_null
# ------------------ Models ------------------
class MyEnum(str, enum.Enum):
"""
Example enum with two values.
"""
ALPHA = "alpha"
BETA = "beta"
class InnerModel(BaseModel):
x: int = Field(..., description="X coordinate in integer form.")
y: float = Field(..., ge=0.0, le=1.0, description="Y coordinate (normalized between 0.0 and 1.0).")
class BadModelWithEnumAndInner(BaseModel):
a: Optional[str] = Field(None, description="Optional text field to test nullable handling.")
e: MyEnum = Field(..., description="An enum field that should retain allowed values.")
inner: InnerModel = Field(..., description="A nested InnerModel with numeric constraints.")
c: int | float = Field(0, description="A number that could be int or float (union).")
action: Literal["X", "Y", "Z"] = Field("X", description="A literal value that must be one of: 'X', 'Y', or 'Z'.")
# ------------------ Offline schema tests ------------------
def test_sanitizer_handles_enum_inner_and_bounds():
schema = sanitize_for_openai_schema(BadModelWithEnumAndInner)
props = schema["properties"]
# All fields required (optionals become required+nullable)
assert schema["required"] == ["a", "e", "inner", "c", "action"]
# Enum preserved, and since 'e' is required (no optional/default), it should NOT be nullable
assert "enum" in props["e"]
_assert_base_type(props["e"], "string")
_assert_nullable(props["e"], False)
# Inner model flattened OK, bounds removed but described
inner_props = props["inner"]["properties"]
assert list(inner_props.keys()) == ["x", "y"]
assert "minimum" not in inner_props["y"]
assert "maximum" not in inner_props["y"]
# FIX: Check for "constraints" (case-insensitive) instead of "numeric constraints"
desc_lower = inner_props["y"]["description"].lower()
assert "constraints" in desc_lower, f"Expected 'constraints' in description: {inner_props['y']['description']}"
# ------------------ Real OpenAI invocation test ------------------
need_openai = pytest.mark.skipif(
not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set; skipping real OpenAI test."
)
@pytest_asyncio.fixture(scope="module")
@need_openai
async def real_chain():
"""
Create a real OpenAI chain using safe_with_structured_output.
"""
model, _ = get_langchain_model_and_token_counter("gpt-4o-mini")
chain = safe_with_structured_output(model, BadModelWithEnumAndInner, prefer_json_schema=True)
return chain
@pytest.mark.integration_test
@pytest.mark.asyncio
@need_openai
async def test_real_openai_invocation(real_chain):
"""
Sends a prompt that should produce a valid BadModelWithEnumAndInner and ensures the output parses.
"""
prompt = (
"Return a JSON object matching the following:\n"
"- a: some string\n"
"- e: 'beta'\n"
"- inner: {x: 42, y: 0.5}\n"
"- c: 3.14\n"
"- action: 'Y'\n"
)
out = await real_chain.ainvoke(prompt)
assert isinstance(out, BadModelWithEnumAndInner)
assert out.a
assert out.e == MyEnum.BETA
assert out.inner.x == 42
assert 0 <= out.inner.y <= 1
assert isinstance(out.c, (int, float))
assert out.action in {"X", "Y", "Z"}
# ------------------ Models for edge cases ------------------
class TitleEnum(str, enum.Enum):
"""
Short enum summary for the model (title helps the LLM).
"""
ONE = "one"
TWO = "two"
class InnerEdge(BaseModel):
first: int = Field(..., description="First comes first.")
second: float = Field(0.5, ge=0.0, le=1.0, description="Second is a normalized float. Defaults handled app-side.")
class EdgeCaseModel(BaseModel):
# Optionals in every flavor
opt_none_str: Optional[str] = Field(None, description="Optional string, may be null.")
opt_with_default_num: Optional[int] = Field(5, description="Optional int with non-None default (5).")
opt_factory_list: Optional[List[int]] = Field(
default_factory=list, description="Optional list with default_factory."
)
# Defaults of various types
d_str: str = Field("", description="Empty string default (apply app-side).")
d_int: int = Field(0, description="Zero default (apply app-side).")
d_float: float = Field(0.0, description="Zero float default (apply app-side).")
d_bool: bool = Field(False, description="False default (apply app-side).")
d_list: List[str] = Field(default_factory=list, description="List default via factory.")
d_dict: Dict[str, Any] = Field(default_factory=dict, description="Dict default via factory.")
# Enums and Literals
e1: TitleEnum = Field(TitleEnum.ONE, description="Enum field.")
e2: Literal["A", "B", "C"] = Field("A", description="Literal field with three values.")
# Nested & unions
inner: InnerEdge = Field(..., description="Nested model with constraints.")
union_scalar: Union[int, float] = Field(1, description="Union of number types (collapses to a single type).")
# Free-form maps + typed maps
metadata_loose: Dict[str, Any] = Field(default_factory=dict, description="Free-form map, typed 'any'.")
metadata_typed: Dict[str, int] = Field(default_factory=dict, description="Typed map of string->int.")
# Arrays with constraints (Pydantic emits some)
tags: List[str] = Field(default_factory=list, description="An array of tags.")
# ------------------ Sanitizer: Field order & required ------------------
def test_field_order_and_required_preserved_default():
sch = sanitize_for_openai_schema(EdgeCaseModel)
props = sch["properties"]
assert list(props.keys()) == [
"opt_none_str",
"opt_with_default_num",
"opt_factory_list",
"d_str",
"d_int",
"d_float",
"d_bool",
"d_list",
"d_dict",
"e1",
"e2",
"inner",
"union_scalar",
"metadata_loose",
"metadata_typed",
"tags",
], "Property order must match class declaration"
assert sch["required"] == list(props.keys()), "All required in same order"
def test_nested_field_order_and_required_preserved():
sch = sanitize_for_openai_schema(EdgeCaseModel)
inner = sch["properties"]["inner"]
assert inner["type"] == "object"
assert list(inner["properties"].keys()) == ["first", "second"]
assert inner["required"] == ["first", "second"]
# ------------------ Optionals & nullability notes ------------------
def test_optionals_become_required_and_nullable_noted():
sch = sanitize_for_openai_schema(EdgeCaseModel)
p = sch["properties"]["opt_none_str"]
# Required at schema-level
assert "opt_none_str" in sch["required"]
# Type should be required+nullable (OpenAI strict); base type string
_assert_base_type(p, "string")
_assert_nullable(p, True)
# (Optional) description guidance is nice-to-have but not strictly required by the sanitizer
def test_optionals_with_non_none_defaults_and_factory():
sch = sanitize_for_openai_schema(EdgeCaseModel)
p_num = sch["properties"]["opt_with_default_num"]
p_list = sch["properties"]["opt_factory_list"]
# No 'default' should remain
assert "default" not in p_num and "default" not in p_list
# Required+nullable with correct base types
_assert_base_type(p_num, "integer") # or "number" - integer is fine since it's a subset
_assert_nullable(p_num, True)
_assert_base_type(p_list, "array")
_assert_nullable(p_list, True)
# ------------------ Defaults of all kinds removed ------------------
@pytest.mark.parametrize("name", ["d_str", "d_int", "d_float", "d_bool", "d_list", "d_dict"])
def test_defaults_removed_for_various_types(name):
sch = sanitize_for_openai_schema(EdgeCaseModel)
assert "default" not in sch["properties"][name], f"default must be stripped for {name}"
# ------------------ Enums & titles kept, literals ok ------------------
def test_enums_and_literals_inline_and_titles_present():
sch = sanitize_for_openai_schema(EdgeCaseModel)
e1 = sch["properties"]["e1"]
e2 = sch["properties"]["e2"]
# These have defaults in Pydantic → optional → required+nullable in strict mode
_assert_base_type(e1, "string")
_assert_nullable(e1, True)
assert "enum" in e1
_assert_base_type(e2, "string")
_assert_nullable(e2, True)
assert "enum" in e2
# Titles are produced by Pydantic - sanitizer should not remove them
assert "title" in e1 and "title" in e2
# ------------------ Numeric bounds stripped & noted ------------------
def test_numeric_bounds_stripped_and_documented():
sch = sanitize_for_openai_schema(EdgeCaseModel)
inner_y = sch["properties"]["inner"]["properties"]["second"]
assert "minimum" not in inner_y and "maximum" not in inner_y
# FIX: Check for "constraints" (case-insensitive) instead of "numeric constraints"
desc_lower = inner_y.get("description", "").lower()
assert "constraints" in desc_lower, f"Expected 'constraints' in description: {inner_y.get('description', '')}"
# ------------------ Union collapse ------------------
def test_union_collapse_scalar():
sch = sanitize_for_openai_schema(EdgeCaseModel)
u = sch["properties"]["union_scalar"]
assert "anyOf" not in u and "oneOf" not in u and "allOf" not in u
assert "type" in u # keep a single usable type
# This field has a default (optional) → required+nullable number (int branch is fine)
_assert_base_type(u, "integer") # or "number"
_assert_nullable(u, True)
# ------------------ $ref inlining & $defs removal ------------------
def test_refs_inlined_and_defs_removed():
sch = sanitize_for_openai_schema(EdgeCaseModel)
assert "$defs" not in sch and "definitions" not in sch
# inner should be an object with properties (not a $ref)
inner = sch["properties"]["inner"]
assert inner.get("type") == "object" and "properties" in inner
# ------------------ additionalProperties handling ------------------
def test_additional_properties_false_when_enabled_top_and_nested():
sch = sanitize_for_openai_schema(EdgeCaseModel, disallow_additional_props=True)
# Top-level strict
assert sch.get("additionalProperties") is False
# Typed map → object schema with additionalProperties as a dict schema (NOT False)
typed = sch["properties"]["metadata_typed"]
_assert_base_type(typed, "object")
_assert_nullable(typed, True)
assert isinstance(typed.get("additionalProperties"), dict)
# Loose map → also a free-form object - sanitizer leaves additionalProperties as {} (dict), not False
loose = sch["properties"]["metadata_loose"]
_assert_base_type(loose, "object")
_assert_nullable(loose, True)
assert isinstance(loose.get("additionalProperties"), dict) # accept {} or a schema dict
# ------------------ Arrays: basic sanity ------------------
def test_arrays_items_sane_and_minimal():
sch = sanitize_for_openai_schema(EdgeCaseModel)
tags = sch["properties"]["tags"]
_assert_base_type(tags, "array")
_assert_nullable(tags, True) # default_factory → optional → required+nullable
assert "items" in tags and tags["items"].get("type") == "string"
assert tags.get("minItems", 0) == 0
# ------------------ Real invocation with OpenAI (skip if no key) ------------------
@pytest_asyncio.fixture(scope="module")
@pytest.mark.timeout(10)
@need_openai
async def real_chain_edge():
model, _ = get_langchain_model_and_token_counter("gpt-4o-mini")
# Use json_schema path
chain = safe_with_structured_output(model, EdgeCaseModel, prefer_json_schema=True)
return chain
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_edge(real_chain_edge):
prompt = (
"Return JSON with all fields:\n"
"- opt_none_str: 'hello'\n"
"- opt_with_default_num: 9\n"
"- opt_factory_list: [1,2,3]\n"
"- d_str: 's'\n"
"- d_int: 1\n"
"- d_float: 0.25\n"
"- d_bool: true\n"
"- d_list: ['a','b']\n"
"- d_dict: {k: 'v'}\n"
"- e1: 'two'\n"
"- e2: 'B'\n"
"- inner: {first: 7, second: 0.7}\n"
"- union_scalar: 3.14\n"
"- metadata_loose: {anyKey: 123}\n"
"- metadata_typed: {count: 2}\n"
"- tags: ['x','y']\n"
)
out = await real_chain_edge.ainvoke(prompt)
assert isinstance(out, EdgeCaseModel)
assert out.opt_none_str == "hello"
assert out.e1 == TitleEnum.TWO
assert out.inner.first == 7 and 0 <= out.inner.second <= 1
assert isinstance(out.union_scalar, (int, float))
assert isinstance(out.metadata_loose, dict)
assert isinstance(out.metadata_typed.get("count"), int)
assert out.tags == ["x", "y"]
"""
Additional test cases for edge cases and OpenAI JSON Schema requirements that the
main test suite does not cover.
"""
import json
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
import pytest
from pydantic import (
BaseModel,
Field,
computed_field,
confloat,
conint,
constr,
create_model,
field_validator,
model_validator,
)
from gv.ai.common.llm.llm_chain_base import safe_with_structured_output
from gv.ai.common.llm.llm_client_builder import get_langchain_model_and_token_counter
from .test_json_schema import need_openai
INTEGRATION_MODELS = ["gpt-4.1-nano", "gpt-5-nano"]
@pytest.fixture
def sanitizer():
"""
Fixture that returns the sanitizer function.
"""
def _sanitize(model, **kwargs):
from gv.ai.common.llm.json_schema import sanitize_for_openai_schema
return sanitize_for_openai_schema(
model,
require_all=False,
make_optionals_required_nullable=True,
inline_refs=True,
disallow_additional_props=True,
**kwargs,
)
return _sanitize
# ============== 1. ENUM EDGE CASES ==============
class MixedTypeEnum(str, Enum):
"""
OpenAI requires consistent enum types.
"""
STRING_VAL = "text"
# This would be problematic if we had mixed types:
# NUMBER_VAL = 123 # Can't mix str and int in same enum!
class IntEnum(int, Enum):
"""
Integer enums need special handling.
"""
POSITIVE = 1
NEGATIVE = -1
ZERO = 0
class FloatEnum(float, Enum):
"""
Float enums are tricky.
"""
PI = 3.14159
E = 2.71828
class EnumEdgeCases(BaseModel):
status: MixedTypeEnum
score: IntEnum
ratio: Optional[FloatEnum] = None
def test_enum_type_consistency(sanitizer):
"""
OpenAI requires all enum values to be same type.
"""
schema = sanitizer(EnumEdgeCases)
# Integer enum should have "type": "integer" or ["integer", ...]
score = schema["properties"]["score"]
assert "integer" in score.get("type", []) or score.get("type") == "integer"
# Float enum should be "number"
ratio = schema["properties"]["ratio"]
types = ratio.get("type", [])
assert "number" in types or types == "number"
# ============== 2. LITERAL TYPES ==============
class LiteralModel(BaseModel):
# Literals are like single-value enums
version: Literal["v1"]
mode: Literal["production", "development"]
# Complex literal union
status: Optional[Literal["active", "inactive", "pending"]] = None
def test_literals_become_enums_simple(sanitizer):
"""
Literals should convert to enum in schema.
"""
schema = sanitizer(LiteralModel)
# Check what Pydantic actually generates for Literal types
version_field = schema["properties"]["version"]
# Pydantic might not always convert Literal to enum, it depends on the version
# Let's check what we actually get
if "enum" in version_field:
assert version_field["enum"] == ["v1"]
elif "const" in version_field:
# Pydantic might use "const" for single-value Literals
assert version_field["const"] == "v1"
else:
# If neither enum nor const, check the type at least
assert version_field.get("type") == "string"
# Log what we actually got for debugging
print(f"Version field schema: {version_field}")
# For multiple literals, Pydantic should create an enum
mode_field = schema["properties"]["mode"]
if "enum" in mode_field:
assert set(mode_field["enum"]) == {"production", "development"}
# Optional literals should be nullable
status_field = schema["properties"]["status"]
if "enum" in status_field:
assert set(status_field["enum"]) >= {"active", "inactive", "pending"}
assert "null" in status_field.get("type", [])
def test_literals_become_enums(sanitizer):
"""
Literals should convert to enum in schema.
Pydantic uses 'const' for single-value Literals and 'enum' for multi-value. Our sanitizer should convert 'const' to
'enum' for consistency.
"""
schema = sanitizer(LiteralModel)
# Single literal: Pydantic generates "const", sanitizer should convert to "enum"
version_field = schema["properties"]["version"]
assert "enum" in version_field, f"Expected enum, got: {version_field}"
assert version_field["enum"] == ["v1"]
assert version_field.get("type") == "string"
# Multiple literals: Pydantic generates "enum" directly
mode_field = schema["properties"]["mode"]
assert "enum" in mode_field
assert set(mode_field["enum"]) == {"production", "development"}
assert mode_field.get("type") == "string"
# Optional literals: Pydantic generates anyOf[enum, null]
# After sanitization, should be a single enum with nullable type
status_field = schema["properties"]["status"]
assert "enum" in status_field
assert set(status_field["enum"]) == {"active", "inactive", "pending"}
# Should be nullable since it's optional
assert isinstance(status_field.get("type"), list)
assert "string" in status_field["type"]
assert "null" in status_field["type"]
# ============== 3. NESTED OPTIONALS & DEEP STRUCTURES ==============
class DeeplyNested(BaseModel):
level1: Optional[Dict[str, Optional[List[Optional[str]]]]] = None
# This is pathological but legal in Pydantic!
class RecursiveModel(BaseModel):
"""
Self-referential models need special care.
"""
name: str
children: Optional[List["RecursiveModel"]] = None
# Update forward refs
RecursiveModel.model_rebuild()
class TestRecursiveModelCases:
"""
Test cases for recursive model handling.
"""
def test_recursive_models_are_rejected(self, sanitizer):
"""
Recursive models should be rejected with a clear error message since OpenAI doesn't support them.
"""
with pytest.raises(ValueError) as exc_info:
sanitizer(RecursiveModel)
# Check that the error message is helpful
error_msg = str(exc_info.value)
assert "Recursive model detected" in error_msg
assert "RecursiveModel" in error_msg
assert "OpenAI's structured outputs do NOT support recursive schemas" in error_msg
assert "Consider redesigning your model" in error_msg
def test_recursive_model_error_provides_alternatives(self, sanitizer):
"""
When rejecting recursive models, the error should suggest alternatives.
"""
with pytest.raises(ValueError) as exc_info:
sanitizer(RecursiveModel)
error_msg = str(exc_info.value)
# Check that alternatives are suggested
assert "flat structure with IDs" in error_msg
assert "fixed depth" in error_msg
assert "serialized nested data" in error_msg
def test_deeply_nested_but_not_recursive_is_allowed(self, sanitizer):
"""
Models that are deeply nested but not recursive should work fine.
"""
# Create a deeply nested but non-recursive model
class Level3(BaseModel):
value: str
class Level2(BaseModel):
level3: Level3
data: str
class Level1(BaseModel):
level2: Level2
name: str
# This should work fine - no recursion
schema = sanitizer(Level1)
# Verify the schema was generated
assert schema["type"] == "object"
assert "properties" in schema
assert "level2" in schema["properties"]
# Check that nesting is properly handled
level2_schema = schema["properties"]["level2"]
assert "properties" in level2_schema
assert "level3" in level2_schema["properties"]
def test_mutual_recursion_is_detected(self, sanitizer):
"""
Mutual recursion (A -> B -> A) should also be detected and rejected.
"""
class NodeA(BaseModel):
name: str
b_node: Optional["NodeB"] = None
class NodeB(BaseModel):
name: str
a_node: Optional[NodeA] = None
# Update forward refs
NodeA.model_rebuild()
NodeB.model_rebuild()
# Both should be rejected
with pytest.raises(ValueError) as exc_info:
sanitizer(NodeA)
assert "Recursive model detected" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
sanitizer(NodeB)
assert "Recursive model detected" in str(exc_info.value)
def test_self_referential_list_is_detected(self, sanitizer):
"""
Self-referential models through lists should be detected.
"""
class TreeNode(BaseModel):
value: int
children: List["TreeNode"] = []
TreeNode.model_rebuild()
with pytest.raises(ValueError) as exc_info:
sanitizer(TreeNode)
error_msg = str(exc_info.value)
assert "Recursive model detected" in error_msg
assert "TreeNode" in error_msg
def test_optional_self_reference_is_detected(self, sanitizer):
"""
Even optional self-references should be detected.
"""
class LinkedListNode(BaseModel):
data: str
next: Optional["LinkedListNode"] = None
LinkedListNode.model_rebuild()
with pytest.raises(ValueError) as exc_info:
sanitizer(LinkedListNode)
error_msg = str(exc_info.value)
assert "Recursive model detected" in error_msg
assert "LinkedListNode" in error_msg
def test_recursion_through_union_is_detected(self, sanitizer):
"""
Recursion through Union types should be detected.
"""
class UnionNode(BaseModel):
value: str
child: Union[str, "UnionNode", None] = None
UnionNode.model_rebuild()
with pytest.raises(ValueError) as exc_info:
sanitizer(UnionNode)
error_msg = str(exc_info.value)
assert "Recursive model detected" in error_msg
def test_non_recursive_complex_model_works(self, sanitizer):
"""
Complex models that look recursive but aren't should work.
"""
class Address(BaseModel):
street: str
city: str
class Person(BaseModel):
name: str
address: Address
class Company(BaseModel):
name: str
employees: List[Person]
headquarters: Address
# This should work - no actual recursion
schema = sanitizer(Company)
assert schema["type"] == "object"
assert "employees" in schema["properties"]
assert "headquarters" in schema["properties"]
# Check that the list of persons is properly structured
employees = schema["properties"]["employees"]
assert employees["type"] == "array"
assert "items" in employees
assert employees["items"]["type"] == "object"
assert "name" in employees["items"]["properties"]
assert "address" in employees["items"]["properties"]
# ============== 4. DISCRIMINATED UNIONS (Tagged Unions) ==============
class Cat(BaseModel):
pet_type: Literal["cat"]
meows: int
class Dog(BaseModel):
pet_type: Literal["dog"]
barks: int
class DiscriminatedUnionModel(BaseModel):
# Pydantic supports discriminated unions, OpenAI doesn't
pet: Union[Cat, Dog] = Field(..., discriminator="pet_type")
def test_discriminated_unions_collapse(sanitizer):
"""
Discriminated unions must collapse to single schema.
"""
schema = sanitizer(DiscriminatedUnionModel)
# Should not have anyOf/oneOf
pet = schema["properties"]["pet"]
assert "anyOf" not in pet
assert "oneOf" not in pet
# Should collapse to first type or merge fields
assert pet["type"] == "object"
# ============== 5. CONSTRAINED TYPES ==============
class ConstrainedModel(BaseModel):
# String constraints
username: constr(min_length=3, max_length=20, pattern=r"^[a-zA-Z0-9_]+$")
email: constr(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
# Numeric constraints
age: conint(ge=0, le=150)
percentage: confloat(ge=0.0, le=100.0)
# Array constraints
    tags: List[str] = Field(..., min_length=1, max_length=10)  # Pydantic v2 names
def test_constraints_moved_to_description(sanitizer):
"""
All constraints should move to description, not schema.
"""
schema = sanitizer(ConstrainedModel)
username = schema["properties"]["username"]
# Pattern should be stripped and moved to description
assert "pattern" not in username
assert "minLength" not in username
assert "maxLength" not in username
# But info should be in description
assert (
"min_length" in username.get("description", "").lower()
or "constraints" in username.get("description", "").lower()
)
age = schema["properties"]["age"]
assert "minimum" not in age
assert "maximum" not in age
# ============== 6. DATETIME AND SPECIAL TYPES ==============
class DateTimeModel(BaseModel):
created_at: datetime
birthday: Optional[date] = None
price: Decimal
metadata: Dict[str, Any] # Free-form dict
raw_json: Optional[Dict[Any, Any]] = None
def test_datetime_handling(sanitizer):
"""
DateTime/date should become strings with format info.
"""
schema = sanitizer(DateTimeModel)
# DateTime becomes string
created = schema["properties"]["created_at"]
assert created["type"] == "string"
# Format might be stripped, check description
assert "format" not in created or created["format"] == "date-time"
# Decimal becomes number
price = schema["properties"]["price"]
assert price["type"] == "number"
# Free-form dicts should keep additionalProperties open
metadata = schema["properties"]["metadata"]
assert metadata["type"] == "object"
# This one specifically should NOT have additionalProperties: false
# since it's meant to be a free-form dict
# ============== 7. EMPTY MODELS & EDGE CASES ==============
class EmptyModel(BaseModel):
"""
Model with no fields.
"""
pass
class AllOptionalModel(BaseModel):
"""
Every field is optional.
"""
a: Optional[str] = None
b: Optional[int] = None
c: Optional[bool] = None
def test_empty_and_all_optional_models(sanitizer):
"""
Edge cases with empty or all-optional models.
"""
# Empty model
empty_schema = sanitizer(EmptyModel)
assert empty_schema["type"] == "object"
assert empty_schema["properties"] == {}
assert empty_schema.get("required", []) == []
# All optional -> all required+nullable
optional_schema = sanitizer(AllOptionalModel)
assert set(optional_schema["required"]) == {"a", "b", "c"}
for field in ["a", "b", "c"]:
assert "null" in optional_schema["properties"][field]["type"]
# ============== 8. COMPLEX UNION COLLAPSES ==============
class ComplexUnionModel(BaseModel):
# Union of different types (not just enums)
value: Union[str, int, List[str], Dict[str, int]]
# Union with None in middle (not at end)
weird: Union[str, None, int] = None
# Multiple enums + types
mixed: Union[MixedTypeEnum, IntEnum, str, int] = Field(...)
def test_complex_union_handling(sanitizer):
"""
Complex unions should collapse sensibly.
"""
schema = sanitizer(ComplexUnionModel)
# These should all collapse to first type or have a strategy
value = schema["properties"]["value"]
assert "anyOf" not in value
assert "oneOf" not in value
# Should pick a type (probably string as most general)
assert value.get("type") is not None
# ============== 9. DEFAULT VALUES EDGE CASES ==============
class DefaultsModel(BaseModel):
# Non-None defaults on optional fields
name: Optional[str] = "Anonymous"
count: Optional[int] = 0
# Callable defaults
created: datetime = Field(default_factory=datetime.now)
# Mutable defaults (bad practice but legal)
tags: List[str] = Field(default_factory=list)
def test_defaults_handling(sanitizer):
"""
Various default value patterns.
"""
schema = sanitizer(DefaultsModel)
# Defaults should be stripped
for prop in schema["properties"].values():
assert "default" not in prop
# But fields should still be required (OpenAI style)
assert "name" in schema["required"]
assert "count" in schema["required"]
# ============== 10. VALIDATION THAT BREAKS STRUCTURED OUTPUT ==============
class ProblematicModel(BaseModel):
# This has a validator that depends on another field
password: str
confirm_password: str
# Pydantic v2 style validator
@field_validator("confirm_password")
@classmethod
def passwords_match(cls, v, info):
if "password" in info.data and v != info.data["password"]:
raise ValueError("passwords do not match")
return v
# Computed field (Pydantic v2 style)
@computed_field
@property
def password_strength(self) -> str:
return "weak" if len(self.password) < 8 else "strong"
# Alternative: Using model_validator for cross-field validation
class ProblematicModelAlt(BaseModel):
password: str
confirm_password: str
@model_validator(mode="after")
def check_passwords_match(self):
if self.password != self.confirm_password:
raise ValueError("passwords do not match")
return self
def test_validators_dont_affect_schema(sanitizer):
"""
Validators and computed fields shouldn't affect schema.
"""
schema = sanitizer(ProblematicModel)
# Should only have actual fields
assert set(schema["properties"].keys()) == {"password", "confirm_password"}
# Computed property should not appear
assert "password_strength" not in schema["properties"]
# ============== 11. PERFORMANCE & SCALE ==============
def test_large_model_performance(sanitizer):
"""
Test with a model that has many fields.
"""
# Create a model with 100 fields dynamically
fields = {f"field_{i}": (Optional[str], None) for i in range(100)}
LargeModel = create_model("LargeModel", **fields)
import time
start = time.time()
schema = sanitizer(LargeModel)
duration = time.time() - start
# Should complete in reasonable time
assert duration < 1.0 # Less than 1 second for 100 fields
# All fields should be required
assert len(schema["required"]) == 100
# ============== 12. BROKEN MODELS THAT SHOULD ERROR ==============
class UnsatisfiableModel(BaseModel):
# Required field with None default (your preflight check catches this!)
required_field: str = None # This is broken!
# Union of incompatible types that can't be resolved
impossible: Union[type, object, type(None)]
def test_unsatisfiable_models_raise_errors(sanitizer):
"""
Models that can't be satisfied should raise clear errors.
"""
with pytest.raises(ValueError) as exc_info:
sanitizer(UnsatisfiableModel, error_on_preflight=True)
assert "unsatisfiable" in str(exc_info.value).lower()
class LiteralTestModel(BaseModel):
single_literal: Literal["v1"]
multi_literal: Literal["production", "development"]
optional_literal: Optional[Literal["active", "inactive", "pending"]] = None
literal_with_field: Literal["test"] = Field(description="A literal field")
def test_literals():
# Generate and inspect the schema
model = LiteralTestModel
schema = model.model_json_schema()
print("Raw Pydantic schema for Literal types:")
print(json.dumps(schema, indent=2))
# Check what's in each field
for field_name, field_info in model.model_fields.items():
print(f"\n{field_name}:")
field_schema = schema["properties"].get(field_name, {})
if "$ref" in field_schema:
# Follow the reference
ref_path = field_schema["$ref"].split("/")[-1]
if "$defs" in schema and ref_path in schema["$defs"]:
field_schema = schema["$defs"][ref_path]
print(f" Referenced schema: {field_schema}")
else:
print(f" Direct schema: {field_schema}")
# Check for enum
if "enum" in field_schema:
print(f" Has enum: {field_schema['enum']}")
elif "const" in field_schema:
print(f" Has const: {field_schema['const']}")
elif "anyOf" in field_schema or "allOf" in field_schema:
print(f" Has union: {field_schema}")
else:
print(f" Type only: {field_schema.get('type', 'no type')}")
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_EnumEdgeCases(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, EnumEdgeCases, prefer_json_schema=True)
prompt = "Return JSON with:\n- status: 'text'\n- score: -1 (must be 1, -1, or 0)\n- ratio: 3.14159 (or null)\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, EnumEdgeCases)
assert result.status == MixedTypeEnum.STRING_VAL
assert result.score in [IntEnum.POSITIVE, IntEnum.NEGATIVE, IntEnum.ZERO]
assert result.ratio in [FloatEnum.PI, FloatEnum.E, None]
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_LiteralModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, LiteralModel, prefer_json_schema=True)
prompt = "Return JSON with:\n- version: 'v1'\n- mode: 'production'\n- status: 'active' (or null)\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, LiteralModel)
assert result.version == "v1"
assert result.mode in ["production", "development"]
assert result.status in ["active", "inactive", "pending", None]
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DeeplyNested(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, DeeplyNested, prefer_json_schema=True)
prompt = "Return JSON with:\n- level1: {'key1': ['value1', null, 'value3'], 'key2': null} (or entirely null)\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, DeeplyNested)
assert isinstance(result.level1, (dict, type(None)))
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DiscriminatedUnionModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, DiscriminatedUnionModel, prefer_json_schema=True)
prompt = "Return JSON with:\n- pet: {pet_type: 'cat', meows: 5}\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, DiscriminatedUnionModel)
assert isinstance(result.pet, (Cat, Dog))
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ConstrainedModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, ConstrainedModel, prefer_json_schema=True)
prompt = (
"Return JSON with:\n"
"- username: 'john_doe123'\n"
"- email: 'john@example.com'\n"
"- age: 25\n"
"- percentage: 75.5\n"
"- tags: ['python', 'ai', 'testing']\n"
)
result = await chain.ainvoke(prompt)
assert isinstance(result, ConstrainedModel)
assert isinstance(result.username, str)
    assert 3 <= len(result.username) <= 20
assert isinstance(result.email, str)
assert isinstance(result.age, int)
assert 0 <= result.age <= 150
assert isinstance(result.percentage, float)
assert 0.0 <= result.percentage <= 100.0
assert isinstance(result.tags, list)
assert 1 <= len(result.tags) <= 10
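
# The schema sent over the wire may not carry these numeric/length bounds, so
# the assertions above double as app-side validation. A minimal clamp helper
# (illustrative sketch, not part of this module) could look like:
#
#   def clamp(value: float, lo: float, hi: float) -> float:
#       return max(lo, min(hi, value))
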
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DateTimeModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, DateTimeModel, prefer_json_schema=True)
prompt = (
"Return JSON with:\n"
"- created_at: '2024-01-15T10:30:00Z'\n"
"- birthday: '1990-05-20' (or null)\n"
"- price: 99.99\n"
"- metadata: {'version': '1.0', 'author': 'test'}\n"
"- raw_json: {'any': 'data', 'nested': {'key': 'value'}} (or null)\n"
)
result = await chain.ainvoke(prompt)
assert isinstance(result, DateTimeModel)
assert isinstance(result.created_at, datetime)
assert isinstance(result.birthday, (date, type(None)))
    assert isinstance(result.price, Decimal)  # Pydantic coerces the JSON number to Decimal
assert isinstance(result.metadata, dict)
assert isinstance(result.raw_json, (dict, type(None)))
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_EmptyModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, EmptyModel, prefer_json_schema=True)
prompt = "Return an empty JSON object: {}"
result = await chain.ainvoke(prompt)
assert isinstance(result, EmptyModel)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_AllOptionalModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, AllOptionalModel, prefer_json_schema=True)
prompt = "Return JSON with:\n- a: 'text' (or null)\n- b: 42 (or null)\n- c: true (or null)\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, AllOptionalModel)
assert isinstance(result.a, (str, type(None)))
assert isinstance(result.b, (int, type(None)))
assert isinstance(result.c, (bool, type(None)))
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ComplexUnionModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, ComplexUnionModel, prefer_json_schema=True)
prompt = (
"Return JSON with:\n- value: 'simple string'\n- weird: 42 (or null)\n- mixed: 'text' (from MixedTypeEnum)\n"
)
result = await chain.ainvoke(prompt)
assert isinstance(result, ComplexUnionModel)
assert result.value is not None
assert isinstance(result.weird, (str, int, type(None)))
assert result.mixed is not None
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DefaultsModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, DefaultsModel, prefer_json_schema=True)
prompt = (
"Return JSON with:\n"
"- name: 'Alice' (or null)\n"
"- count: 10 (or null)\n"
"- created: '2024-01-15T12:00:00Z'\n"
"- tags: ['tag1', 'tag2']\n"
)
result = await chain.ainvoke(prompt)
assert isinstance(result, DefaultsModel)
assert isinstance(result.name, (str, type(None)))
assert isinstance(result.count, (int, type(None)))
assert isinstance(result.created, datetime)
assert isinstance(result.tags, list)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ProblematicModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, ProblematicModel, prefer_json_schema=True)
prompt = "Return JSON with:\n- password: 'secret123'\n- confirm_password: 'secret123'\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, ProblematicModel)
assert isinstance(result.password, str)
assert isinstance(result.confirm_password, str)
assert result.password == result.confirm_password
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ProblematicModelAlt(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, ProblematicModelAlt, prefer_json_schema=True)
prompt = "Return JSON with:\n- password: 'mypass456'\n- confirm_password: 'mypass456'\n"
result = await chain.ainvoke(prompt)
assert isinstance(result, ProblematicModelAlt)
assert isinstance(result.password, str)
assert isinstance(result.confirm_password, str)
assert result.password == result.confirm_password
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_LiteralTestModel(model_name):
model, _ = get_langchain_model_and_token_counter(model_name)
chain = safe_with_structured_output(model, LiteralTestModel, prefer_json_schema=True)
prompt = (
"Return JSON with:\n"
"- single_literal: 'v1'\n"
"- multi_literal: 'development'\n"
"- optional_literal: 'inactive' (or null)\n"
"- literal_with_field: 'test'\n"
)
result = await chain.ainvoke(prompt)
assert isinstance(result, LiteralTestModel)
assert result.single_literal == "v1"
assert result.multi_literal in ["production", "development"]
assert result.optional_literal in ["active", "inactive", "pending", None]
assert result.literal_with_field == "test"
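
# To run the integration tests above locally, something like the following
# should work (assuming the `integration_test` marker is registered in the
# pytest config and an OpenAI key is present; @need_openai should skip otherwise):
#
#   OPENAI_API_KEY=sk-... pytest -m integration_test -k real_openai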