Last active
November 17, 2025 07:02
-
-
Save aviadr1/2d1186625d67fba9c8f421d273bf7a53 to your computer and use it in GitHub Desktop.
**OpenAI JSON Schema Sanitizer for Pydantic Models** - A production-ready function that transforms any Pydantic model into an OpenAI Structured Outputs-compatible JSON schema, handling optionals, unions, recursion detection, numeric constraints, and additionalProperties issues that cause API failures. Includes comprehensive test suite covering a…
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| File: gv/ai/common/llm/json_schema.py | |
| Author: Aviad Rozenhek | |
| OpenAI Structured Outputs (`response_format={"type":"json_schema"}`) supports only a subset of JSON Schema. | |
| Many perfectly valid Pydantic constructs won't fly as-is. Use these patterns: | |
| ------------------------------------------------------------------------- | |
| 1) Optional / nullable / Default fields | |
| ------------------------------------------------------------------------- | |
| ❌ Pydantic: | |
| class M(BaseModel): | |
| title: Optional[str] = None # -> type: ["string", "null"] | |
| ✅ OpenAI-friendly: | |
| # Make it required, represent "no value" explicitly (e.g., empty string). | |
| class MDTO(BaseModel): | |
| title: str # required in schema; post-parse, treat "" as "missing" | |
| # OR split presence into a separate boolean or sentinel: | |
| class MDTO(BaseModel): | |
| has_title: bool | |
| title: str # required; if has_title is False, expect "" and fix app-side | |
| ------------------------------------------------------------------------- | |
| 2) Unions / anyOf / oneOf / allOf | |
| ------------------------------------------------------------------------- | |
| ❌ Pydantic: | |
| class M(BaseModel): | |
| score: int | float # -> anyOf: [{"type":"integer"},{"type":"number"}] | |
| payload: User | Org # -> oneOf/allOf + $ref | |
| ✅ OpenAI-friendly: | |
| # Pick a single, widest type and normalize app-side. | |
| class MDTO(BaseModel): | |
| score: float | |
| payload_kind: Literal["user","org"] | |
| payload_user: UserDTO # keep one branch | |
| # (or flatten the shared fields; resolve shape app-side) | |
| ------------------------------------------------------------------------- | |
| 3) Numeric bounds (ge / le / conint / confloat) | |
| ------------------------------------------------------------------------- | |
| ❌ Pydantic: | |
| class M(BaseModel): | |
| frac: condecimal(ge=0, le=1) | |
| rating: float = Field(ge=0, le=5) | |
| ✅ OpenAI-friendly: | |
| class MDTO(BaseModel): | |
| frac: float # add "0–1" note in description; clamp/validate after parsing | |
| rating: float | |
| # The sanitizer moves min/max into the description and strips the keys. | |
| # Enforce ranges in Python after you get the model's output. | |
| ------------------------------------------------------------------------- | |
| 4) Enums | |
| ------------------------------------------------------------------------- | |
| ✅ Supported *if* they are simple string/number enums **inline**: | |
| class MDTO(BaseModel): | |
| status: Literal["OPEN","CLOSED","PENDING"] | |
| ⚠️ Avoid deep `$ref` indirection for enums if possible; inline them (the sanitizer inlines local $refs). | |
| ------------------------------------------------------------------------- | |
| 5) Additional properties (free-form maps) | |
| ------------------------------------------------------------------------- | |
| ✅ Strict by default (recommended for LLMs): | |
| # Sanitizer sets "additionalProperties": false by default (disallow_additional_props=True; pass False to opt out) | |
| # This prevents the model from inventing fields. | |
| ✅ If you truly need a bag/dict: | |
| class MDTO(BaseModel): | |
| metadata: Dict[str, Any] # allow extras under this object | |
| # Use sanitizer options/exceptions to keep additionalProperties allowed for this one field. | |
| ------------------------------------------------------------------------- | |
| How this function helps - it transforms a schema to be OpenAI-friendly: | |
| ----------------------- | |
| - Makes every property required (optional fields become required + nullable), avoiding nullable/union traps. | |
| - Strips defaults/bounds and moves bounds into descriptions. | |
| - Inlines local $refs and collapses anyOf/oneOf/allOf to a single branch. | |
| - (Optionally) sets additionalProperties:false to curb hallucinated keys. | |
| - Preserves Pydantic field order (properties and 'required' list). | |
| - ERRORS on recursive models (OpenAI doesn't support them) | |
| """ | |
| from copy import deepcopy | |
| from typing import Any, Dict, get_args, get_origin | |
| from loguru import logger | |
| from pydantic import BaseModel | |
# Flip to a truthy value to route sanitizer trace messages to loguru's debug sink.
DEBUG = 0

# Trace hook used while walking schemas: forwards to logger.debug when DEBUG is set,
# otherwise a no-op that accepts (and ignores) any arguments.
logger_trace = logger.debug if DEBUG else (lambda *args, **kwargs: None)  # noqa: E731
| # ----------------------------- Preflight helpers ----------------------------- | |
| def _is_optional_annotation(tp) -> bool: | |
| """ | |
| True for Optional[T] or (T | None) on all Python versions. | |
| Works for typing.Union and PEP 604 unions (types.UnionType). | |
| """ | |
| try: | |
| return type(None) in get_args(tp) | |
| except Exception: | |
| return False | |
| def _iter_submodels_from_annotation(tp, visited=None): | |
| """ | |
| Yield BaseModel subclasses found inside typing annotations (recurse Lists, Unions, etc.). | |
| """ | |
| if visited is None: | |
| visited = set() | |
| if tp is None: | |
| return | |
| # Prevent infinite recursion for self-referential types | |
| type_id = id(tp) | |
| if type_id in visited: | |
| return | |
| visited.add(type_id) | |
| origin = get_origin(tp) | |
| if origin is None: | |
| # bare type | |
| if isinstance(tp, type) and issubclass(tp, BaseModel): | |
| yield tp | |
| return | |
| args = get_args(tp) | |
| for a in args: | |
| if isinstance(a, type) and issubclass(a, BaseModel): | |
| yield a | |
| else: | |
| yield from _iter_submodels_from_annotation(a, visited) | |
| def _check_for_recursion(model_class, visited=None, path=None): | |
| """ | |
| Check if a model contains recursive references. | |
| Raises ValueError if recursion is detected. | |
| """ | |
| if visited is None: | |
| visited = set() | |
| if path is None: | |
| path = [] | |
| model_id = id(model_class) | |
| if model_id in visited: | |
| # Found recursion - build error message | |
| cycle_path = " -> ".join(path + [model_class.__name__]) | |
| raise ValueError( | |
| f"Recursive model detected: {model_class.__name__}\n" | |
| f"Recursion path: {cycle_path}\n" | |
| f"OpenAI's structured outputs do NOT support recursive schemas or $ref.\n" | |
| f"Consider redesigning your model to avoid recursion, such as:\n" | |
| f" - Using a flat structure with IDs to reference related objects\n" | |
| f" - Limiting nesting to a fixed depth with different model types\n" | |
| f" - Using a string field to store serialized nested data" | |
| ) | |
| visited.add(model_id) | |
| new_path = path + [model_class.__name__] | |
| for field_name, field_info in model_class.model_fields.items(): | |
| annotation = field_info.annotation | |
| # Check if the field references other BaseModel subclasses | |
| for submodel in _iter_submodels_from_annotation(annotation): | |
| _check_for_recursion(submodel, visited.copy(), new_path) | |
def _preflight_model_consistency(
    model: type[BaseModel],
    *,
    errors: list[str],
    warnings: list[str],
    path: str = "",
    warn_on_nullable_description_mismatch: bool = True,
    visited: set = None,
) -> None:
    """
    Collect preflight issues that cannot be fully 'fixed' by schema sanitization alone.

    Findings are appended to the caller-supplied lists:
      - ERROR: annotation that does not accept None, but default=None
      - WARN : 'if ... then ... else' description (OpenAI subset can't encode)
      - WARN : description mentions 'null/none' but the annotation does not accept None

    An annotation is considered to accept None when it is Optional / a union
    with None, ``typing.Any`` or ``NoneType``; such fields are never reported
    for the two None-related checks.

    Args:
        model: Pydantic model class to inspect.
        errors: output list receiving error messages.
        warnings: output list receiving warning messages.
        path: dotted location of ``model`` inside its parent, used as a message prefix.
        warn_on_nullable_description_mismatch: enables the null/none description check.
        visited: models already inspected; guards against self-referential models.

    Recurses into nested BaseModels discovered via field annotations.
    """
    if visited is None:
        visited = set()

    # Prevent infinite recursion on self-referential models
    if model in visited:
        return
    visited.add(model)

    prefix = f"{path}." if path else ""

    for name, f in model.model_fields.items():
        ann = getattr(f, "annotation", None)
        default = getattr(f, "default", ...)
        desc = (getattr(f, "description", "") or "").lower()

        # A missing annotation is never reported; Any and NoneType genuinely
        # accept None, so flagging them would be a false positive.
        accepts_none = ann is None or ann is Any or ann is type(None) or _is_optional_annotation(ann)

        # 1) ERROR: annotation rejects None but default=None
        if default is None and not accepts_none:
            errors.append(
                f"{prefix}{model.__name__}.{name}: non-optional type {getattr(ann, '__name__', ann)} "
                f"has default=None (unsatisfiable without changing model or post-processing)."
            )

        # 2) WARN: if/then/else prose requirements (whole-word match via space padding)
        if " if " in f" {desc} " and " then " in f" {desc} " and " else " in f" {desc} ":
            warnings.append(
                f"{prefix}{model.__name__}.{name}: description encodes conditional requirement "
                f"('if/then/else'); OpenAI strict JSON Schema cannot express that."
            )

        # 3) WARN: 'null' or 'none' mentioned in description but field cannot be None
        if warn_on_nullable_description_mismatch and ((" null" in f" {desc}") or (" none" in f" {desc}")):
            if not accepts_none:
                warnings.append(
                    f"{prefix}{model.__name__}.{name}: description mentions null/none but annotation is non-optional; "
                    f"field will be non-nullable in strict schema."
                )

        # Recurse into submodels referenced by this field's type
        for sub in _iter_submodels_from_annotation(ann):
            _preflight_model_consistency(
                sub,
                errors=errors,
                warnings=warnings,
                path=f"{prefix}{model.__name__}.{name}",
                warn_on_nullable_description_mismatch=warn_on_nullable_description_mismatch,
                visited=visited,
            )
| # --------------------------- Schema sanitizer core --------------------------- | |
def sanitize_for_openai_schema(
    model: type[BaseModel],
    *,
    require_all: bool = False,
    make_optionals_required_nullable: bool = True,
    strip_defaults: bool = True,
    strip_numeric_bounds: bool = True,
    move_bounds_to_description: bool = True,
    inline_refs: bool = True,
    disallow_additional_props: bool = True,
    # Diagnostics
    error_on_preflight: bool = True,
    warn_on_nullable_description_mismatch: bool = True,
    # Behavior for enum unions
    merge_enums_even_with_string: bool = True,
) -> Dict[str, Any]:
    """
    Produce an OpenAI-structured-outputs-friendly JSON Schema from a Pydantic model.

    Pipeline: recursion check -> preflight diagnostics -> ``model.model_json_schema()``
    -> in-place sanitizing walk -> root-level finalization.

    Args:
        model: Pydantic model class to convert.
        require_all: list every property of each object in ``required`` (no nullability added).
        make_optionals_required_nullable: properties missing from Pydantic's ``required``
            list are added to it and their ``type`` gains ``"null"``.
        strip_defaults: drop ``default`` keys.
        strip_numeric_bounds: drop ``minimum``/``maximum``/``exclusive*`` and also
            ``pattern``/``minLength``/``maxLength``.
        move_bounds_to_description: append the dropped constraints to ``description``.
        inline_refs: replace local ``$ref`` nodes with a copy of their target;
            non-local refs are simply removed.
        disallow_additional_props: for objects without an explicit ``additionalProperties``,
            set it to ``False`` (objects with properties / nullable objects) or to
            ``{"type": "string"}`` (property-less objects).
        error_on_preflight: raise when the preflight pass reports errors.
        warn_on_nullable_description_mismatch: log a warning when a required field's
            description mentions null/none.
        merge_enums_even_with_string: when a union mixes enum branches with a plain
            string branch the enums are merged and the string branch dropped; this
            flag only controls whether an explanatory note is added to the description.

    Returns:
        The sanitized schema dict, with ``$defs``/``definitions`` removed.

    Raises:
        ValueError: for recursive models, for preflight errors (when
            ``error_on_preflight`` is set), or for a recursive ``$ref`` met while inlining.

    Field order of ``properties`` and ``required`` follows the Pydantic model.
    """
    # Check for recursive models FIRST - fail fast
    try:
        _check_for_recursion(model)
    except ValueError as e:
        logger.error(f"Schema validation failed: {e}")
        raise

    # -------- Preflight consistency checks --------
    errors: list[str] = []
    warnings: list[str] = []
    _preflight_model_consistency(
        model,
        errors=errors,
        warnings=warnings,
        warn_on_nullable_description_mismatch=warn_on_nullable_description_mismatch,
    )
    for w in warnings:
        logger.warning(f"[schema preflight] {w}")
    if errors:
        for e in errors:
            logger.error(f"[schema preflight] {e}")
        if error_on_preflight:
            raise ValueError("Schema/model consistency issues:\n- " + "\n- ".join(errors))

    # -------- Generate raw schema and sanitize --------
    raw = model.model_json_schema()
    schema = deepcopy(raw)
    defs = schema.get("$defs") or schema.get("definitions") or {}

    def _append_desc(node: Dict[str, Any], extra: str) -> None:
        """Append ``extra`` as a new paragraph of the node's description (or create it)."""
        if not extra:
            return
        if "description" in node and isinstance(node["description"], str):
            node["description"] = node["description"].rstrip() + f"\n\n{extra}"
        else:
            node["description"] = extra

    def _strip_bounds_and_move_description(node: Dict[str, Any]) -> None:
        """Remove string/numeric constraint keywords and optionally describe them in prose."""
        if not strip_numeric_bounds:
            return
        string_constraints = []
        numeric_constraints = []
        # Strip string constraints (pattern, minLength, maxLength)
        for k in ("pattern", "minLength", "maxLength"):
            if k in node:
                val = node.pop(k)
                string_constraints.append(f"{k}={val}")
        # Numeric constraints
        for k in ("minimum", "exclusiveMinimum", "maximum", "exclusiveMaximum"):
            if k in node:
                val = node.pop(k)
                numeric_constraints.append(f"{k.replace('exclusive', 'exclusive ')}={val}")
        # Add descriptions for different types of constraints
        if move_bounds_to_description:
            if string_constraints:
                _append_desc(node, "String constraints (enforced app-side): " + ", ".join(string_constraints) + ".")
            if numeric_constraints:
                _append_desc(node, "Numeric constraints (enforced app-side): " + ", ".join(numeric_constraints) + ".")

    def _get_enum_type(enum_values):
        """Return the JSON Schema primitive type shared by ``enum_values`` ("string" as fallback)."""
        if all(isinstance(v, bool) for v in enum_values):
            return "boolean"
        elif all(isinstance(v, int) and not isinstance(v, bool) for v in enum_values):
            return "integer"
        elif all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in enum_values):
            return "number"
        else:
            return "string"

    def _ensure_nullable_type(node: Dict[str, Any]) -> None:
        """
        Ensure node['type'] includes 'null' (do NOT add null into 'enum').
        """
        if not isinstance(node, dict):
            # Boolean / malformed property schemas cannot carry a type list.
            return
        t = node.get("type")
        if "enum" in node and not t:
            # Infer the primitive type the same way enum merging does.
            t = node["type"] = _get_enum_type(node["enum"])
        if isinstance(t, list):
            if "null" not in t:
                node["type"] = t + ["null"]
        elif isinstance(t, str):
            if t != "null":
                node["type"] = [t, "null"]
            else:
                node["type"] = ["string", "null"]
        else:
            node["type"] = ["string", "null"]

    def _inline_ref(node: Dict[str, Any], visited_refs=None) -> Dict[str, Any]:
        """Replace a local ``$ref`` node by a sanitized deep copy of its target."""
        if visited_refs is None:
            visited_refs = set()
        ref = node.get("$ref")
        if not ref or not inline_refs:
            return node
        # We should never hit recursive refs because we checked earlier
        # But if we somehow do, error out
        if ref in visited_refs:
            raise ValueError(
                f"Unexpected recursive reference found during inlining: {ref}\n"
                f"This should have been caught earlier. The model likely has recursive structure "
                f"which is not supported by OpenAI's structured outputs."
            )
        if not ref.startswith("#/"):
            node.pop("$ref", None)  # external refs unsupported
            return node
        # Copy so sibling branches can reference the same definition independently.
        visited_refs_copy = visited_refs.copy()
        visited_refs_copy.add(ref)
        # Drop exactly the "#/" prefix (lstrip would strip a character set instead).
        parts = ref[2:].split("/")
        cur: Any = schema
        for p in parts:
            if isinstance(cur, dict) and p in cur:
                cur = cur[p]
            else:
                cur = defs.get(p, {})
        inlined = deepcopy(cur) if isinstance(cur, dict) else {}
        # Preserve any additional properties from the original node (like description)
        for key in node:
            if key != "$ref" and key not in inlined:
                inlined[key] = node[key]
        return _walk(inlined, visited_refs=visited_refs_copy)

    def _collapse_alternatives(node: Dict[str, Any], visited_refs=None) -> None:
        """
        Handle anyOf/oneOf/allOf (only the first of these keywords that is present):
        - Merge enum branches into a single enum (a plain 'string' branch is dropped).
        - Preserve nullability if any alt is {"type":"null"}.
        - Otherwise keep the first non-null alternative and annotate.
        """
        for key in ("anyOf", "oneOf", "allOf"):
            alts = node.get(key)
            if not (isinstance(alts, list) and alts):
                continue
            merged_enums: list[Any] = []
            enum_type = None
            saw_null = False
            saw_string_branch = False
            resolved: list[Dict[str, Any]] = []
            for a in alts:
                a = a if isinstance(a, dict) else {}
                if "$ref" in a:
                    # Thread visited_refs so the recursion guard also covers refs inside unions.
                    a = _inline_ref(a, visited_refs)
                resolved.append(a)
                t = a.get("type")
                if t == "null":
                    saw_null = True
                    continue
                if "enum" in a:
                    current_enum_type = _get_enum_type(a["enum"])
                    # Only merge enums if they're compatible types or if we haven't set a type yet
                    if not enum_type:
                        enum_type = current_enum_type
                        merged_enums.extend(a["enum"])
                    elif enum_type == current_enum_type:
                        merged_enums.extend(a["enum"])
                    elif enum_type in ["integer", "number"] and current_enum_type in ["integer", "number"]:
                        # Allow merging integer and number enums, use number as the type
                        enum_type = "number"
                        merged_enums.extend(a["enum"])
                    # Otherwise skip incompatible enum types
                    continue
                if t == "string":
                    saw_string_branch = True
                # non-enum branch remains
            node.pop(key, None)
            if merged_enums:
                node.pop("$ref", None)
                # A free-form string branch (if any) is discarded in favour of the merged enum.
                node["type"] = enum_type or "string"
                # Don't sort enums, preserve order and uniqueness
                seen = set()
                unique_vals = []
                for v in merged_enums:
                    if v not in seen:
                        seen.add(v)
                        unique_vals.append(v)
                node["enum"] = unique_vals
                if saw_null:
                    _ensure_nullable_type(node)
                if saw_string_branch and merge_enums_even_with_string:
                    _append_desc(
                        node,
                        "Note: union included a free-form string branch; sanitizer merged enum branches "
                        "and discarded the open string to keep values constrained.",
                    )
                return
            # Fallback: keep the first non-null alt so a leading {"type": "null"} does not
            # erase the real type; nullability is re-applied just below.
            first = next((a for a in resolved if a.get("type") != "null"), resolved[0])
            for k, v in first.items():
                if k not in node:
                    node[k] = v
            if saw_null:
                _ensure_nullable_type(node)
            _append_desc(
                node,
                f"Note: original schema had `{key}` with {len(alts)} alternatives; "
                f"sanitizer kept the first and preserved nullability.",
            )
            return

    def _walk(node: Any, path: str = "<root>", visited_refs=None) -> Any:
        """Recursively sanitize ``node`` in place and return the (possibly replaced) node."""
        if visited_refs is None:
            visited_refs = set()
        logger_trace(f"[ENTER] {path}: dict={isinstance(node, dict)}")
        if isinstance(node, dict):
            if "$ref" in node and inline_refs:
                logger_trace(f"[{path}] Resolving $ref: {node['$ref']}")
                return _inline_ref(node, visited_refs)

            # Collapse unions first: keys copied in from the kept alternative
            # (format, const, default, ...) must go through the stripping below too.
            _collapse_alternatives(node, visited_refs)  # may rewrite node in place

            # Strip defaults / format / examples etc.
            if strip_defaults and "default" in node:
                node.pop("default", None)
            if "format" in node:
                node.pop("format", None)
            for k in ("examples", "example", "readOnly", "writeOnly", "deprecated"):
                node.pop(k, None)

            # Convert "const" to "enum" for compatibility
            if "const" in node:
                const_value = node.pop("const")
                node["enum"] = [const_value]
                # Ensure type is set if not already
                if "type" not in node:
                    if isinstance(const_value, bool):
                        node["type"] = "boolean"
                    elif isinstance(const_value, int):
                        node["type"] = "integer"
                    elif isinstance(const_value, float):
                        node["type"] = "number"
                    else:
                        node["type"] = "string"

            _strip_bounds_and_move_description(node)

            # Check if this is an array type (including nullable arrays)
            node_type = node.get("type")
            is_array = node_type == "array" or (isinstance(node_type, list) and "array" in node_type)
            if is_array:
                if "items" in node:
                    node["items"] = _walk(node["items"], f"{path}.items", visited_refs)
                node.setdefault("minItems", 0)
                # Don't return here, continue processing

            # Objects - handle both pure object and nullable object types
            is_object = False
            is_nullable_object = False
            if node_type == "object":
                is_object = True
            elif isinstance(node_type, list) and "object" in node_type:
                is_object = True
                is_nullable_object = "null" in node_type

            if is_object:
                props = node.get("properties", {}) or {}
                # Preserve original required (from Pydantic) to judge optionality
                original_required = set(node.get("required", []))

                # Walk children first
                for k in list(props.keys()):
                    props[k] = _walk(props[k], f"{path}.properties[{k}]", visited_refs)
                node["properties"] = props

                # Warn if description mentions null/none for a non-optional field (simple heuristic).
                # Iterating props (not the set) keeps the log order deterministic.
                if warn_on_nullable_description_mismatch and original_required:
                    for k, child in props.items():
                        if k not in original_required or not isinstance(child, dict):
                            continue
                        desc = (child.get("description") or "").lower()
                        if (" null" in f" {desc}") or (" none" in f" {desc}"):
                            # This field is non-optional in Pydantic, so we won't add nullability
                            logger.warning(
                                f"[schema warning] {path}.properties[{k}]: description mentions null/none "
                                f"but field is required by Pydantic (non-nullable in strict schema)."
                            )

                # Start from Pydantic's required list
                required = list(original_required)

                # Emulate optionals as required+nullable
                if make_optionals_required_nullable and props:
                    for k, p in props.items():
                        if k not in original_required:
                            required.append(k)
                            _ensure_nullable_type(p)

                # If global require_all, override
                if require_all and props:
                    required = list(props.keys())

                if required:
                    # Re-emit in property order (Pydantic field order).
                    ordered = [k for k in props.keys() if k in required]
                    node["required"] = ordered

                # Handle additionalProperties - nullable objects must end up with False
                if "additionalProperties" in node:
                    add_props = node["additionalProperties"]
                    # If it's a schema (dict), ensure it has proper handling
                    if isinstance(add_props, dict):
                        # Walk the additionalProperties schema first
                        add_props = _walk(add_props, f"{path}.additionalProperties", visited_refs)
                        # For nullable object types, OpenAI requires additionalProperties: false
                        if is_nullable_object:
                            node["additionalProperties"] = False
                        else:
                            # Handle different cases for additionalProperties
                            if not add_props:
                                # Empty dict - needs a type for OpenAI
                                add_props = {"type": "string"}
                            elif "type" not in add_props and not add_props.get("$ref"):
                                # Has content but no type - infer or default
                                if "properties" in add_props:
                                    add_props["type"] = "object"
                                elif "items" in add_props:
                                    add_props["type"] = "array"
                                elif "enum" in add_props:
                                    # Infer type from enum values
                                    add_props["type"] = _get_enum_type(add_props["enum"])
                                else:
                                    # Default to string for any untyped schema
                                    add_props["type"] = "string"
                            node["additionalProperties"] = add_props
                    elif add_props is True:
                        # For nullable objects, must be false
                        if is_nullable_object:
                            node["additionalProperties"] = False
                        else:
                            # Convert true to a schema with type
                            node["additionalProperties"] = {"type": "string"}
                    # If false, leave as-is
                elif disallow_additional_props:
                    # Always set to false for nullable objects
                    if is_nullable_object:
                        node["additionalProperties"] = False
                    elif props:
                        # Structured object with properties
                        node["additionalProperties"] = False
                    else:
                        # No properties means this is a free-form dict like Dict[str, Any]
                        # For non-nullable free-form dicts, allow with typed schema
                        node["additionalProperties"] = {"type": "string"}
                return node

            # Recurse other dict fields
            for k, v in list(node.items()):
                if is_array and k == "items":
                    # Already sanitized above; a second pass would only repeat warnings.
                    continue
                node[k] = _walk(v, f"{path}.{k}", visited_refs)
            return node
        elif isinstance(node, list):
            return [_walk(v, f"{path}[{i}]", visited_refs) for i, v in enumerate(node)]
        return node

    sanitized = _walk(schema)

    # Remove defs after inlining
    sanitized.pop("$defs", None)
    sanitized.pop("definitions", None)

    # Top-level object finalization + root-level required+nullable catch-up
    if sanitized.get("type") == "object":
        props = sanitized.get("properties", {}) or {}
        if require_all and props:
            sanitized["required"] = list(props.keys())
        else:
            if make_optionals_required_nullable and props:
                req = list(sanitized.get("required", []))
                present = set(req)
                changed = False
                for k in props.keys():
                    if k not in present:
                        req.append(k)
                        _ensure_nullable_type(props[k])
                        changed = True
                if changed:
                    sanitized["required"] = req

        # Final handling of additionalProperties at root level
        if "additionalProperties" in sanitized:
            add_props = sanitized["additionalProperties"]
            # Ensure it has proper type if it's a dict
            if isinstance(add_props, dict):
                if not add_props:
                    # Empty dict needs a type
                    sanitized["additionalProperties"] = {"type": "string"}
                elif "type" not in add_props and not add_props.get("$ref"):
                    # Has content but no type
                    if "properties" in add_props:
                        add_props["type"] = "object"
                    else:
                        add_props["type"] = "string"
                    sanitized["additionalProperties"] = add_props
        elif disallow_additional_props:
            if props:
                sanitized["additionalProperties"] = False
            else:
                # Empty model needs typed additionalProperties for OpenAI
                sanitized["additionalProperties"] = {"type": "string"}

    return sanitized
| # -------- Usage examples -------- | |
| # 1) OpenAI SDK (Chat Completions with structured outputs; the Responses API takes the schema under `text.format` instead) | |
| # schema = sanitize_for_openai_schema(MessageQualityPromptSchema) | |
| # client.chat.completions.create( | |
| # model="gpt-4o-mini", | |
| # messages=[...], # your prompt/messages | |
| # response_format={"type": "json_schema", | |
| # "json_schema": {"name": "MessageQualityPromptSchema", | |
| # "schema": schema}}, | |
| # ) | |
| # 2) LangChain: bind the raw response_format | |
| # from langchain_openai import ChatOpenAI | |
| # llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) | |
| # schema = sanitize_for_openai_schema(MessageQualityPromptSchema) | |
| # llm = llm.bind(response_format={"type":"json_schema", | |
| # "json_schema":{"name":"MessageQualityPromptSchema","schema":schema}}) | |
| # result = llm.invoke("...your input...") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # tests/test_sanitizer_and_safe_structured_real.py | |
| import enum | |
| import os | |
| from typing import Any, Dict, List, Literal, Optional, Union | |
| import pytest | |
| import pytest_asyncio | |
| from dotenv import load_dotenv | |
| from pydantic import BaseModel, Field | |
| from gv.ai.common.llm.json_schema import sanitize_for_openai_schema | |
| from gv.ai.common.llm.llm_chain_base import safe_with_structured_output | |
| from gv.ai.common.llm.llm_client_builder import get_langchain_model_and_token_counter | |
# Runs at import time, before the OPENAI_API_KEY check used by the integration test below.
# NOTE(review): presumably this populates os.environ from a local .env file — confirm.
load_dotenv()
| # ------------------ Small helpers for strict JSON Schema expectations ------------------ | |
| def _types(node: Dict[str, Any]) -> set[str]: | |
| t = node.get("type") | |
| return set(t) if isinstance(t, list) else {t} if t is not None else set() | |
def _assert_base_type(node: Dict[str, Any], expected: str):
    """Assert that ``expected`` is among the node's types, ignoring "null" (required+nullable fields)."""
    non_null_types = _types(node) - {"null"}
    assert expected in non_null_types, f"expected base type {expected}, got {_types(node)}"
def _assert_nullable(node: Dict[str, Any], should_be_nullable: bool = True):
    """Assert that "null" is (or, with ``should_be_nullable`` falsy, is not) among the node's types."""
    null_allowed = "null" in _types(node)
    if should_be_nullable:
        assert null_allowed
    else:
        assert not null_allowed
| # ------------------ Models ------------------ | |
# String-valued enum used as the `e` field of BadModelWithEnumAndInner.
# NOTE(review): Pydantic appears to surface an enum's docstring as the schema
# "description", so the docstring wording is left untouched — confirm before rewording.
class MyEnum(str, enum.Enum):
    """
    Example enum with two values.
    """

    ALPHA = "alpha"
    BETA = "beta"
# Nested model for the sanitizer tests; `y` carries ge/le bounds so the test can check
# that numeric bounds are stripped from the schema and mentioned in the description.
# (Documented with comments rather than a docstring: a model docstring presumably ends up
# in the generated schema — confirm.)
class InnerModel(BaseModel):
    x: int = Field(..., description="X coordinate in integer form.")
    y: float = Field(..., ge=0.0, le=1.0, description="Y coordinate (normalized between 0.0 and 1.0).")
# Deliberately "unfriendly" model combining the constructs the sanitizer has to rewrite.
# Field order matters: the offline test asserts `required` == ["a", "e", "inner", "c", "action"].
class BadModelWithEnumAndInner(BaseModel):
    # Optional with default None -> expected to become required + nullable.
    a: Optional[str] = Field(None, description="Optional text field to test nullable handling.")
    # Required enum -> must keep its allowed values and stay non-nullable.
    e: MyEnum = Field(..., description="An enum field that should retain allowed values.")
    # Nested model with numeric bounds.
    inner: InnerModel = Field(..., description="A nested InnerModel with numeric constraints.")
    # int | float union with a default.
    c: int | float = Field(0, description="A number that could be int or float (union).")
    # Literal with a default.
    action: Literal["X", "Y", "Z"] = Field("X", description="A literal value that must be one of: 'X', 'Y', or 'Z'.")
| # ------------------ Offline schema tests ------------------ | |
def test_sanitizer_handles_enum_inner_and_bounds():
    """Offline check: required list, enum preservation, nested inlining and bound stripping."""
    sanitized = sanitize_for_openai_schema(BadModelWithEnumAndInner)
    properties = sanitized["properties"]

    # Every field ends up required, in model order (optionals become required+nullable).
    assert sanitized["required"] == ["a", "e", "inner", "c", "action"]

    # 'e' has no default and is not Optional: enum kept, string-typed, not nullable.
    enum_field = properties["e"]
    assert "enum" in enum_field
    _assert_base_type(enum_field, "string")
    _assert_nullable(enum_field, False)

    # The nested model is inlined with its fields in order.
    nested = properties["inner"]["properties"]
    assert list(nested.keys()) == ["x", "y"]

    # Bounds are removed from 'y' ...
    y_schema = nested["y"]
    for bound in ("minimum", "maximum"):
        assert bound not in y_schema

    # ... and mentioned in its description instead (case-insensitive "constraints").
    description_lowered = y_schema["description"].lower()
    assert "constraints" in description_lowered, f"Expected 'constraints' in description: {y_schema['description']}"
| # ------------------ Real OpenAI invocation test ------------------ | |
# Reusable skip marker: tests decorated with it are skipped when the OPENAI_API_KEY
# environment variable is unset or empty.
need_openai = pytest.mark.skipif(
    not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set; skipping real OpenAI test."
)
# No skip mark here: pytest ignores marks applied to fixtures (and deprecates doing so).
# The tests that request this fixture carry `need_openai`, so the fixture is never set up
# when the API key is missing.
@pytest_asyncio.fixture(scope="module")
async def real_chain():
    """
    Create a real OpenAI chain using safe_with_structured_output.

    Module-scoped: the chain is built once and shared by the tests in this module.
    """
    model, _ = get_langchain_model_and_token_counter("gpt-4o-mini")
    chain = safe_with_structured_output(model, BadModelWithEnumAndInner, prefer_json_schema=True)
    return chain
@pytest.mark.integration_test
@pytest.mark.asyncio
@need_openai
async def test_real_openai_invocation(real_chain):
    """
    Ask the live model for specific field values and verify the reply parses into
    BadModelWithEnumAndInner with those values.
    """
    request_lines = [
        "Return a JSON object matching the following:",
        "- a: some string",
        "- e: 'beta'",
        "- inner: {x: 42, y: 0.5}",
        "- c: 3.14",
        "- action: 'Y'",
    ]
    # Every line, including the last, is newline-terminated.
    prompt = "\n".join(request_lines) + "\n"

    result = await real_chain.ainvoke(prompt)

    assert isinstance(result, BadModelWithEnumAndInner)
    assert result.a
    assert result.e == MyEnum.BETA

    nested = result.inner
    assert nested.x == 42
    assert 0 <= nested.y <= 1

    assert isinstance(result.c, (int, float))
    assert result.action in {"X", "Y", "Z"}
| # ------------------ Models for edge cases ------------------ | |
class TitleEnum(str, enum.Enum):
    """
    Short enum summary for the model (title helps the LLM).
    """

    # String-valued members; the values are what appears in JSON payloads.
    ONE = "one"
    TWO = "two"
class InnerEdge(BaseModel):
    # Nested model used by EdgeCaseModel: one required integer and one bounded float with a default.
    first: int = Field(..., description="First comes first.")
    second: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Second is a normalized float. Defaults handled app-side.",
    )
class EdgeCaseModel(BaseModel):
    # NOTE: declaration order matters - the field-order tests compare against it.

    # --- Optionals in every flavor ---
    opt_none_str: Optional[str] = Field(default=None, description="Optional string, may be null.")
    opt_with_default_num: Optional[int] = Field(
        default=5,
        description="Optional int with non-None default (5).",
    )
    opt_factory_list: Optional[List[int]] = Field(
        default_factory=list,
        description="Optional list with default_factory.",
    )

    # --- Defaults of various types ---
    d_str: str = Field(default="", description="Empty string default (apply app-side).")
    d_int: int = Field(default=0, description="Zero default (apply app-side).")
    d_float: float = Field(default=0.0, description="Zero float default (apply app-side).")
    d_bool: bool = Field(default=False, description="False default (apply app-side).")
    d_list: List[str] = Field(default_factory=list, description="List default via factory.")
    d_dict: Dict[str, Any] = Field(default_factory=dict, description="Dict default via factory.")

    # --- Enums and Literals ---
    e1: TitleEnum = Field(default=TitleEnum.ONE, description="Enum field.")
    e2: Literal["A", "B", "C"] = Field(default="A", description="Literal field with three values.")

    # --- Nested & unions ---
    inner: InnerEdge = Field(..., description="Nested model with constraints.")
    union_scalar: Union[int, float] = Field(
        default=1,
        description="Union of number types (collapses to a single type).",
    )

    # --- Free-form maps + typed maps ---
    metadata_loose: Dict[str, Any] = Field(default_factory=dict, description="Free-form map, typed 'any'.")
    metadata_typed: Dict[str, int] = Field(default_factory=dict, description="Typed map of string->int.")

    # --- Arrays with constraints (Pydantic emits some) ---
    tags: List[str] = Field(default_factory=list, description="An array of tags.")
| # ------------------ Sanitizer: Field order & required ------------------ | |
def test_field_order_and_required_preserved_default():
    """Properties keep class-declaration order and ``required`` lists them all in that order."""
    declared_order = [
        "opt_none_str",
        "opt_with_default_num",
        "opt_factory_list",
        "d_str",
        "d_int",
        "d_float",
        "d_bool",
        "d_list",
        "d_dict",
        "e1",
        "e2",
        "inner",
        "union_scalar",
        "metadata_loose",
        "metadata_typed",
        "tags",
    ]

    sch = sanitize_for_openai_schema(EdgeCaseModel)
    property_names = list(sch["properties"].keys())

    assert property_names == declared_order, "Property order must match class declaration"
    assert sch["required"] == property_names, "All required in same order"
def test_nested_field_order_and_required_preserved():
    """The inlined nested model is an object whose fields and ``required`` keep declaration order."""
    nested = sanitize_for_openai_schema(EdgeCaseModel)["properties"]["inner"]
    expected_fields = ["first", "second"]

    assert nested["type"] == "object"
    assert list(nested["properties"].keys()) == expected_fields
    assert nested["required"] == expected_fields
| # ------------------ Optionals & nullability notes ------------------ | |
def test_optionals_become_required_and_nullable_noted():
    """An ``Optional[str] = None`` field ends up required and typed as a nullable string."""
    sch = sanitize_for_openai_schema(EdgeCaseModel)
    field_name = "opt_none_str"
    prop = sch["properties"][field_name]

    # Listed as required at schema level.
    assert field_name in sch["required"]

    # Required + nullable (OpenAI strict mode) with a string base type.
    _assert_base_type(prop, "string")
    _assert_nullable(prop, True)
    # Description guidance would be nice-to-have, but the sanitizer is not required to add it.
def test_optionals_with_non_none_defaults_and_factory():
    """Optionals with a non-None default or a default_factory lose ``default`` and become nullable."""
    properties = sanitize_for_openai_schema(EdgeCaseModel)["properties"]
    number_prop = properties["opt_with_default_num"]
    list_prop = properties["opt_factory_list"]

    # Neither property may keep a 'default' key.
    assert "default" not in number_prop
    assert "default" not in list_prop

    # Required + nullable with the expected base types
    # ("integer" is acceptable where "number" would be: it is a subset).
    _assert_base_type(number_prop, "integer")
    _assert_nullable(number_prop, True)
    _assert_base_type(list_prop, "array")
    _assert_nullable(list_prop, True)
| # ------------------ Defaults of all kinds removed ------------------ | |
@pytest.mark.parametrize("name", ["d_str", "d_int", "d_float", "d_bool", "d_list", "d_dict"])
def test_defaults_removed_for_various_types(name):
    """No defaulted field, whatever its type, may keep a ``default`` key after sanitizing."""
    prop = sanitize_for_openai_schema(EdgeCaseModel)["properties"][name]
    assert "default" not in prop, f"default must be stripped for {name}"
| # ------------------ Enums & titles kept, literals ok ------------------ | |
def test_enums_and_literals_inline_and_titles_present():
    """Enum and Literal fields stay inline enums, become nullable, and keep their titles."""
    properties = sanitize_for_openai_schema(EdgeCaseModel)["properties"]
    enum_prop = properties["e1"]
    literal_prop = properties["e2"]

    # Both fields have Pydantic defaults -> optional -> required + nullable in strict mode.
    for prop in (enum_prop, literal_prop):
        _assert_base_type(prop, "string")
        _assert_nullable(prop, True)
        assert "enum" in prop

    # Titles come from Pydantic; the sanitizer should not remove them.
    assert "title" in enum_prop
    assert "title" in literal_prop
| # ------------------ Numeric bounds stripped & noted ------------------ | |
def test_numeric_bounds_stripped_and_documented():
    """``ge``/``le`` bounds on the nested float leave the schema and are noted in the description."""
    sch = sanitize_for_openai_schema(EdgeCaseModel)
    bounded = sch["properties"]["inner"]["properties"]["second"]

    assert "minimum" not in bounded
    assert "maximum" not in bounded

    # Match "constraints" case-insensitively rather than the exact phrase "numeric constraints".
    description = bounded.get("description", "")
    assert "constraints" in description.lower(), f"Expected 'constraints' in description: {description}"
| # ------------------ Union collapse ------------------ | |
def test_union_collapse_scalar():
    """``Union[int, float]`` collapses to one concrete type with no combinator keywords left."""
    collapsed = sanitize_for_openai_schema(EdgeCaseModel)["properties"]["union_scalar"]

    # None of the combinator keywords may survive.
    assert not any(keyword in collapsed for keyword in ("anyOf", "oneOf", "allOf"))
    # A single usable type must remain.
    assert "type" in collapsed

    # The field has a default (optional) -> required + nullable; the int branch is accepted.
    _assert_base_type(collapsed, "integer")
    _assert_nullable(collapsed, True)
| # ------------------ $ref inlining & $defs removal ------------------ | |
def test_refs_inlined_and_defs_removed():
    """Definition tables are dropped and the nested model is inlined instead of ``$ref``-ed."""
    sch = sanitize_for_openai_schema(EdgeCaseModel)

    assert "$defs" not in sch
    assert "definitions" not in sch

    # 'inner' must be a concrete object schema with its own properties.
    nested = sch["properties"]["inner"]
    assert nested.get("type") == "object"
    assert "properties" in nested
| # ------------------ additionalProperties handling ------------------ | |
def test_additional_properties_false_when_enabled_top_and_nested():
    """With ``disallow_additional_props=True`` the top level is closed but map fields stay open."""
    sch = sanitize_for_openai_schema(EdgeCaseModel, disallow_additional_props=True)

    # The top-level object is strict.
    assert sch.get("additionalProperties") is False

    # Both the typed map and the free-form map stay nullable objects whose
    # additionalProperties is a dict ({} or a schema dict), never False.
    for map_field in ("metadata_typed", "metadata_loose"):
        map_prop = sch["properties"][map_field]
        _assert_base_type(map_prop, "object")
        _assert_nullable(map_prop, True)
        assert isinstance(map_prop.get("additionalProperties"), dict)
| # ------------------ Arrays: basic sanity ------------------ | |
def test_arrays_items_sane_and_minimal():
    """A ``List[str]`` field is a nullable array of strings with no positive ``minItems``."""
    array_prop = sanitize_for_openai_schema(EdgeCaseModel)["properties"]["tags"]

    _assert_base_type(array_prop, "array")
    # default_factory -> optional -> required + nullable
    _assert_nullable(array_prop, True)

    assert "items" in array_prop
    assert array_prop["items"].get("type") == "string"
    assert array_prop.get("minItems", 0) == 0
| # ------------------ Real invocation with OpenAI (skip if no key) ------------------ | |
@pytest_asyncio.fixture(scope="module")
async def real_chain_edge():
    """
    Create a real OpenAI chain for ``EdgeCaseModel`` on the json_schema path.

    The requesting test is skipped when ``OPENAI_API_KEY`` is not set.
    """
    # pytest marks (skipif, timeout, ...) have no effect when applied to a fixture function,
    # so the missing-key case is handled with an explicit skip inside the fixture.
    if not os.getenv("OPENAI_API_KEY"):
        pytest.skip("OPENAI_API_KEY not set; skipping real OpenAI test.")

    model, _ = get_langchain_model_and_token_counter("gpt-4o-mini")
    # Use json_schema path
    return safe_with_structured_output(model, EdgeCaseModel, prefer_json_schema=True)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_edge(real_chain_edge):
    """Ask the live model for a fully populated ``EdgeCaseModel`` and spot-check the parsed fields."""
    requested_fields = [
        "- opt_none_str: 'hello'",
        "- opt_with_default_num: 9",
        "- opt_factory_list: [1,2,3]",
        "- d_str: 's'",
        "- d_int: 1",
        "- d_float: 0.25",
        "- d_bool: true",
        "- d_list: ['a','b']",
        "- d_dict: {k: 'v'}",
        "- e1: 'two'",
        "- e2: 'B'",
        "- inner: {first: 7, second: 0.7}",
        "- union_scalar: 3.14",
        "- metadata_loose: {anyKey: 123}",
        "- metadata_typed: {count: 2}",
        "- tags: ['x','y']",
    ]
    prompt = "Return JSON with all fields:\n" + "\n".join(requested_fields) + "\n"

    parsed = await real_chain_edge.ainvoke(prompt)

    assert isinstance(parsed, EdgeCaseModel)
    assert parsed.opt_none_str == "hello"
    assert parsed.e1 == TitleEnum.TWO
    assert parsed.inner.first == 7 and 0 <= parsed.inner.second <= 1
    assert isinstance(parsed.union_scalar, (int, float))
    assert isinstance(parsed.metadata_loose, dict)
    assert isinstance(parsed.metadata_typed.get("count"), int)
    assert parsed.tags == ["x", "y"]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Additional test cases for edge cases and OpenAI JSON Schema requirements that might be missing from your current test | |
| suite. | |
| """ | |
| import json | |
| from datetime import date, datetime | |
| from decimal import Decimal | |
| from enum import Enum | |
| from typing import Any, Dict, List, Literal, Optional, Union | |
| import pytest | |
| from pydantic import ( | |
| BaseModel, | |
| Field, | |
| computed_field, | |
| confloat, | |
| conint, | |
| constr, | |
| create_model, | |
| field_validator, | |
| model_validator, | |
| ) | |
| from gv.ai.common.llm.llm_chain_base import safe_with_structured_output | |
| from gv.ai.common.llm.llm_client_builder import get_langchain_model_and_token_counter | |
| from .test_json_schema import need_openai | |
# Model names each parametrized integration test is run against.
INTEGRATION_MODELS = [
    "gpt-4.1-nano",
    "gpt-5-nano",
]
@pytest.fixture
def sanitizer():
    """
    Provide a callable that sanitizes a model with this suite's standard options.

    Extra keyword arguments are forwarded to ``sanitize_for_openai_schema``.
    """

    def _sanitize(model, **kwargs):
        # Imported at call time, exactly when the sanitizer is first needed.
        from gv.ai.common.llm.json_schema import sanitize_for_openai_schema

        suite_options = {
            "require_all": False,
            "make_optionals_required_nullable": True,
            "inline_refs": True,
            "disallow_additional_props": True,
        }
        return sanitize_for_openai_schema(model, **suite_options, **kwargs)

    return _sanitize
| # ============== 1. ENUM EDGE CASES ============== | |
class MixedTypeEnum(str, Enum):
    """
    OpenAI requires consistent enum types.
    """

    STRING_VAL = "text"
    # A member of another type would make the enum inconsistent, e.g.:
    # NUMBER_VAL = 123  # Can't mix str and int in same enum!
class IntEnum(int, Enum):
    """
    Integer enums need special handling.
    """

    # Covers a positive, a negative and a zero value.
    POSITIVE = 1
    NEGATIVE = -1
    ZERO = 0
class FloatEnum(float, Enum):
    """
    Float enums are tricky.
    """

    # Members carry float values.
    PI = 3.14159
    E = 2.71828
class EnumEdgeCases(BaseModel):
    # One field per enum flavor: string (required), integer (required), float (optional).
    status: MixedTypeEnum
    score: IntEnum
    ratio: Optional[FloatEnum] = None
def test_enum_type_consistency(sanitizer):
    """
    OpenAI requires all enum values to be same type.

    Checks that the integer enum is typed ``integer`` and the optional float enum ``number``,
    whether ``type`` is a single string or a list such as ``["number", "null"]``.
    """

    def type_names(prop):
        # Normalize "type" to a list first: testing `"integer" in <str>` would be a
        # substring match and could accept an unrelated type name.
        declared = prop.get("type", [])
        return [declared] if isinstance(declared, str) else list(declared)

    schema = sanitizer(EnumEdgeCases)

    # Integer enum should have "type": "integer" or ["integer", ...]
    assert "integer" in type_names(schema["properties"]["score"])

    # Float enum should be "number"
    assert "number" in type_names(schema["properties"]["ratio"])
| # ============== 2. LITERAL TYPES ============== | |
class LiteralModel(BaseModel):
    # Single-value Literal: behaves like a one-member enum.
    version: Literal["v1"]
    # Multi-value Literal.
    mode: Literal["production", "development"]
    # Optional multi-value Literal (a union with None).
    status: Optional[Literal["active", "inactive", "pending"]] = None
def test_literals_become_enums_simple(sanitizer):
    """
    Lenient check of how Literal fields appear in the sanitized schema.

    Accepts whichever representation (``enum``, ``const`` or a bare string type) the
    installed Pydantic version produces for a single-value Literal.
    """
    properties = sanitizer(LiteralModel)["properties"]

    # Single-value Literal: the representation depends on the Pydantic version.
    version_field = properties["version"]
    if "enum" in version_field:
        assert version_field["enum"] == ["v1"]
    elif "const" in version_field:
        # Pydantic might use "const" for single-value Literals
        assert version_field["const"] == "v1"
    else:
        # Neither enum nor const: at least the type must be right.
        assert version_field.get("type") == "string"
        # Print what was produced, for debugging.
        print(f"Version field schema: {version_field}")

    # Multi-value Literal: when an enum is present it must list exactly the two values.
    mode_field = properties["mode"]
    if "enum" in mode_field:
        assert set(mode_field["enum"]) == {"production", "development"}

    # Optional Literal: any enum must cover the three values, and the type must allow null.
    status_field = properties["status"]
    if "enum" in status_field:
        assert set(status_field["enum"]) >= {"active", "inactive", "pending"}
    assert "null" in status_field.get("type", [])
def test_literals_become_enums(sanitizer):
    """
    Strict check that every Literal field ends up as an ``enum`` in the sanitized schema.

    Pydantic uses 'const' for single-value Literals and 'enum' for multi-value ones; the
    sanitizer is expected to turn 'const' into 'enum' for consistency.
    """
    properties = sanitizer(LiteralModel)["properties"]

    # Single literal: "const" from Pydantic should have been converted to "enum".
    version_field = properties["version"]
    assert "enum" in version_field, f"Expected enum, got: {version_field}"
    assert version_field["enum"] == ["v1"]
    assert version_field.get("type") == "string"

    # Multiple literals: Pydantic generates "enum" directly.
    mode_field = properties["mode"]
    assert "enum" in mode_field
    assert set(mode_field["enum"]) == {"production", "development"}
    assert mode_field.get("type") == "string"

    # Optional literals: Pydantic generates anyOf[enum, null]; after sanitization this
    # should be one enum with a nullable type.
    status_field = properties["status"]
    assert "enum" in status_field
    assert set(status_field["enum"]) == {"active", "inactive", "pending"}

    # Nullable because the field is optional: "type" is a list holding both names.
    status_type = status_field.get("type")
    assert isinstance(status_type, list)
    assert "string" in status_field["type"]
    assert "null" in status_field["type"]
| # ============== 3. NESTED OPTIONALS & DEEP STRUCTURES ============== | |
class DeeplyNested(BaseModel):
    # Optional at every level: optional dict -> optional list values -> optional string items.
    # Pathological, but legal in Pydantic.
    level1: Optional[Dict[str, Optional[List[Optional[str]]]]] = None
class RecursiveModel(BaseModel):
    """
    Self-referential models need special care.
    """

    name: str
    # Forward reference to the class itself; resolved by the rebuild below.
    children: Optional[List["RecursiveModel"]] = None


# Resolve the "RecursiveModel" forward reference now that the class exists.
RecursiveModel.model_rebuild()
class TestRecursiveModelCases:
    """
    Test cases for recursive model handling.
    """

    @staticmethod
    def _rejection_message(sanitizer, model):
        """Sanitize ``model``, require a ``ValueError``, and return the error text."""
        with pytest.raises(ValueError) as caught:
            sanitizer(model)
        return str(caught.value)

    def test_recursive_models_are_rejected(self, sanitizer):
        """
        A self-referential model is rejected with a clear, helpful error message.
        """
        message = self._rejection_message(sanitizer, RecursiveModel)

        # The message names the problem, the model, the limitation and a way forward.
        assert "Recursive model detected" in message
        assert "RecursiveModel" in message
        assert "OpenAI's structured outputs do NOT support recursive schemas" in message
        assert "Consider redesigning your model" in message

    def test_recursive_model_error_provides_alternatives(self, sanitizer):
        """
        The rejection error suggests alternative designs.
        """
        message = self._rejection_message(sanitizer, RecursiveModel)

        for suggestion in ("flat structure with IDs", "fixed depth", "serialized nested data"):
            assert suggestion in message

    def test_deeply_nested_but_not_recursive_is_allowed(self, sanitizer):
        """
        Three levels of plain nesting (no cycle) sanitize without error.
        """

        class Level3(BaseModel):
            value: str

        class Level2(BaseModel):
            level3: Level3
            data: str

        class Level1(BaseModel):
            level2: Level2
            name: str

        # No recursion here, so a schema must be produced.
        schema = sanitizer(Level1)

        assert schema["type"] == "object"
        assert "properties" in schema
        assert "level2" in schema["properties"]

        # The nested level is inlined with its own properties.
        second_level = schema["properties"]["level2"]
        assert "properties" in second_level
        assert "level3" in second_level["properties"]

    def test_mutual_recursion_is_detected(self, sanitizer):
        """
        Mutual recursion (A -> B -> A) is rejected from either entry point.
        """

        class NodeA(BaseModel):
            name: str
            b_node: Optional["NodeB"] = None

        class NodeB(BaseModel):
            name: str
            a_node: Optional[NodeA] = None

        # Resolve the forward references.
        NodeA.model_rebuild()
        NodeB.model_rebuild()

        # Starting from either side of the cycle must fail.
        assert "Recursive model detected" in self._rejection_message(sanitizer, NodeA)
        assert "Recursive model detected" in self._rejection_message(sanitizer, NodeB)

    def test_self_referential_list_is_detected(self, sanitizer):
        """
        A model referring to itself through a list is rejected.
        """

        class TreeNode(BaseModel):
            value: int
            children: List["TreeNode"] = []

        TreeNode.model_rebuild()

        message = self._rejection_message(sanitizer, TreeNode)
        assert "Recursive model detected" in message
        assert "TreeNode" in message

    def test_optional_self_reference_is_detected(self, sanitizer):
        """
        An optional self-reference is rejected as well.
        """

        class LinkedListNode(BaseModel):
            data: str
            next: Optional["LinkedListNode"] = None

        LinkedListNode.model_rebuild()

        message = self._rejection_message(sanitizer, LinkedListNode)
        assert "Recursive model detected" in message
        assert "LinkedListNode" in message

    def test_recursion_through_union_is_detected(self, sanitizer):
        """
        A self-reference hidden inside a Union is rejected.
        """

        class UnionNode(BaseModel):
            value: str
            child: Union[str, "UnionNode", None] = None

        UnionNode.model_rebuild()

        message = self._rejection_message(sanitizer, UnionNode)
        assert "Recursive model detected" in message

    def test_non_recursive_complex_model_works(self, sanitizer):
        """
        A model graph that reuses sub-models without forming a cycle sanitizes fine.
        """

        class Address(BaseModel):
            street: str
            city: str

        class Person(BaseModel):
            name: str
            address: Address

        class Company(BaseModel):
            name: str
            employees: List[Person]
            headquarters: Address

        # Address is shared, but nothing refers back to itself.
        schema = sanitizer(Company)

        assert schema["type"] == "object"
        assert "employees" in schema["properties"]
        assert "headquarters" in schema["properties"]

        # The list of persons is an array of inlined object schemas.
        staff = schema["properties"]["employees"]
        assert staff["type"] == "array"
        assert "items" in staff
        assert staff["items"]["type"] == "object"
        assert "name" in staff["items"]["properties"]
        assert "address" in staff["items"]["properties"]
| # ============== 4. DISCRIMINATED UNIONS (Tagged Unions) ============== | |
class Cat(BaseModel):
    # Union member tagged by the literal "cat".
    pet_type: Literal["cat"]
    meows: int
class Dog(BaseModel):
    # Union member tagged by the literal "dog".
    pet_type: Literal["dog"]
    barks: int
class DiscriminatedUnionModel(BaseModel):
    # Tagged union keyed on "pet_type": supported by Pydantic, not by OpenAI.
    pet: Union[Cat, Dog] = Field(..., discriminator="pet_type")
def test_discriminated_unions_collapse(sanitizer):
    """
    A discriminated union is reduced to one plain object schema.
    """
    pet_schema = sanitizer(DiscriminatedUnionModel)["properties"]["pet"]

    # No union keywords may remain.
    assert "anyOf" not in pet_schema
    assert "oneOf" not in pet_schema

    # Whether the first variant is chosen or fields are merged, the result is an object.
    assert pet_schema["type"] == "object"
| # ============== 5. CONSTRAINED TYPES ============== | |
class ConstrainedModel(BaseModel):
    # Every field carries validation constraints, so the sanitizer tests can check
    # that they are removed from the schema.

    # String constraints
    username: constr(min_length=3, max_length=20, pattern=r"^[a-zA-Z0-9_]+$")
    email: constr(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")

    # Numeric constraints
    age: conint(ge=0, le=150)
    percentage: confloat(ge=0.0, le=100.0)

    # Array constraints. Pydantic v2 names the collection bounds min_length/max_length;
    # the v1 spellings min_items/max_items are deprecated.
    tags: List[str] = Field(..., min_length=1, max_length=10)
def test_constraints_moved_to_description(sanitizer):
    """
    Constraint keywords leave the schema; for strings they are mentioned in the description.
    """
    properties = sanitizer(ConstrainedModel)["properties"]

    # String constraints must be removed from the property itself ...
    username = properties["username"]
    assert "pattern" not in username
    assert "minLength" not in username
    assert "maxLength" not in username

    # ... while the description still carries the information.
    description = username.get("description", "").lower()
    assert "min_length" in description or "constraints" in description

    # Numeric bounds are removed as well.
    age = properties["age"]
    assert "minimum" not in age
    assert "maximum" not in age
| # ============== 6. DATETIME AND SPECIAL TYPES ============== | |
class DateTimeModel(BaseModel):
    # Temporal types
    created_at: datetime
    birthday: Optional[date] = None
    # Arbitrary-precision number
    price: Decimal
    # Free-form dict
    metadata: Dict[str, Any]
    # Optional dict with unconstrained keys and values
    raw_json: Optional[Dict[Any, Any]] = None
def test_datetime_handling(sanitizer):
    """
    Datetime, Decimal and free-form dict fields map onto plain JSON Schema types.
    """
    properties = sanitizer(DateTimeModel)["properties"]

    # datetime is represented as a string; a surviving format must be "date-time".
    created = properties["created_at"]
    assert created["type"] == "string"
    assert "format" not in created or created["format"] == "date-time"

    # Decimal is represented as a number.
    price = properties["price"]
    assert price["type"] == "number"

    # A free-form dict stays an object.
    metadata = properties["metadata"]
    assert metadata["type"] == "object"
    # It is meant to stay free-form, so it specifically should NOT end up with
    # additionalProperties: false (not asserted here).
| # ============== 7. EMPTY MODELS & EDGE CASES ============== | |
class EmptyModel(BaseModel):
    """
    Model with no fields.
    """
class AllOptionalModel(BaseModel):
    """
    Every field is optional.
    """

    # One optional field each for string, integer and boolean.
    a: Optional[str] = None
    b: Optional[int] = None
    c: Optional[bool] = None
def test_empty_and_all_optional_models(sanitizer):
    """
    A field-less model and an all-optional model both sanitize to sensible schemas.
    """
    # No fields: an object with empty properties and nothing required.
    empty_schema = sanitizer(EmptyModel)
    assert empty_schema["type"] == "object"
    assert empty_schema["properties"] == {}
    assert empty_schema.get("required", []) == []

    # All optional: every field becomes required and nullable.
    optional_schema = sanitizer(AllOptionalModel)
    field_names = ("a", "b", "c")
    assert set(optional_schema["required"]) == set(field_names)
    for field_name in field_names:
        assert "null" in optional_schema["properties"][field_name]["type"]
| # ============== 8. COMPLEX UNION COLLAPSES ============== | |
class ComplexUnionModel(BaseModel):
    # Union across scalar and container types (not just enums).
    value: Union[str, int, List[str], Dict[str, int]]
    # None appears in the middle of the union rather than at the end.
    weird: Union[str, None, int] = None
    # Two enums mixed with plain scalar types.
    mixed: Union[MixedTypeEnum, IntEnum, str, int] = Field(...)
def test_complex_union_handling(sanitizer):
    """
    A union over heterogeneous types collapses to one schema with a concrete type.
    """
    collapsed = sanitizer(ComplexUnionModel)["properties"]["value"]

    # The union keywords must be gone, whatever collapse strategy is used.
    assert "anyOf" not in collapsed
    assert "oneOf" not in collapsed

    # Some type must have been chosen.
    assert collapsed.get("type") is not None
| # ============== 9. DEFAULT VALUES EDGE CASES ============== | |
class DefaultsModel(BaseModel):
    # Optional fields whose default is a real value rather than None.
    name: Optional[str] = "Anonymous"
    count: Optional[int] = 0
    # Default produced by calling a function.
    created: datetime = Field(default_factory=datetime.now)
    # Mutable default, supplied through a factory.
    tags: List[str] = Field(default_factory=list)
def test_defaults_handling(sanitizer):
    """
    Defaults of every flavor are stripped while the fields stay required.
    """
    schema = sanitizer(DefaultsModel)

    # No property may keep a "default" key.
    for prop_schema in schema["properties"].values():
        assert "default" not in prop_schema

    # OpenAI style: the defaulted optionals are still listed as required.
    required = schema["required"]
    assert "name" in required
    assert "count" in required
| # ============== 10. VALIDATION THAT BREAKS STRUCTURED OUTPUT ============== | |
class ProblematicModel(BaseModel):
    # The validator on confirm_password depends on the password field.
    password: str
    confirm_password: str

    # Pydantic v2 style field validator
    @field_validator("confirm_password")
    @classmethod
    def passwords_match(cls, v, info):
        validated = info.data
        # Nothing to compare against when "password" is absent from the validated data.
        if "password" not in validated:
            return v
        if v == validated["password"]:
            return v
        raise ValueError("passwords do not match")

    # Computed field (Pydantic v2 style)
    @computed_field
    @property
    def password_strength(self) -> str:
        if len(self.password) < 8:
            return "weak"
        return "strong"
| # Alternative: Using model_validator for cross-field validation | |
class ProblematicModelAlt(BaseModel):
    # Same two fields, cross-checked by a model-level validator instead of a field validator.
    password: str
    confirm_password: str

    @model_validator(mode="after")
    def check_passwords_match(self):
        if self.password == self.confirm_password:
            return self
        raise ValueError("passwords do not match")
def test_validators_dont_affect_schema(sanitizer):
    """
    Only declared fields appear in the schema; validators and computed fields do not.
    """
    properties = sanitizer(ProblematicModel)["properties"]

    # Exactly the two declared fields.
    assert set(properties.keys()) == {"password", "confirm_password"}

    # The computed property is absent.
    assert "password_strength" not in properties
| # ============== 11. PERFORMANCE & SCALE ============== | |
def test_large_model_performance(sanitizer):
    """
    Test with a model that has many fields.

    Builds a 100-field model dynamically and checks that sanitizing it finishes in
    under one second with every field listed as required.
    """
    import time

    field_count = 100

    # Create a model with `field_count` optional string fields dynamically.
    fields = {f"field_{i}": (Optional[str], None) for i in range(field_count)}
    LargeModel = create_model("LargeModel", **fields)

    # perf_counter() is monotonic and high-resolution; time.time() follows the wall
    # clock and can jump if the system clock is adjusted mid-measurement.
    start = time.perf_counter()
    schema = sanitizer(LargeModel)
    duration = time.perf_counter() - start

    # Should complete in reasonable time: less than 1 second for 100 fields.
    assert duration < 1.0

    # All fields should be required.
    assert len(schema["required"]) == field_count
| # ============== 12. BROKEN MODELS THAT SHOULD ERROR ============== | |
class UnsatisfiableModel(BaseModel):
    # Deliberately broken: a non-optional field with a None default
    # (the kind of thing the preflight check is meant to catch).
    required_field: str = None
    # Union of incompatible types that can't be resolved.
    impossible: Union[type, object, type(None)]
def test_unsatisfiable_models_raise_errors(sanitizer):
    """
    With ``error_on_preflight=True`` an unsatisfiable model raises a ``ValueError`` saying so.
    """
    with pytest.raises(ValueError) as caught:
        sanitizer(UnsatisfiableModel, error_on_preflight=True)

    message = str(caught.value)
    assert "unsatisfiable" in message.lower()
class LiteralTestModel(BaseModel):
    # One field per Literal flavor inspected by test_literals.
    single_literal: Literal["v1"]
    multi_literal: Literal["production", "development"]
    optional_literal: Optional[Literal["active", "inactive", "pending"]] = None
    literal_with_field: Literal["test"] = Field(description="A literal field")
def test_literals():
    """
    Print the raw Pydantic JSON schema for ``LiteralTestModel`` and how each field is encoded.

    Diagnostic only: nothing is asserted; the output shows whether each Literal field is
    emitted as ``enum``, ``const``, a union, or a bare type.
    """
    # Generate and inspect the schema
    model = LiteralTestModel
    schema = model.model_json_schema()
    print("Raw Pydantic schema for Literal types:")
    print(json.dumps(schema, indent=2))

    # Only the field names are needed, so iterate the mapping's keys directly.
    for field_name in model.model_fields:
        print(f"\n{field_name}:")
        field_schema = schema["properties"].get(field_name, {})

        if "$ref" in field_schema:
            # Follow the reference into "$defs" when the target exists there.
            ref_path = field_schema["$ref"].split("/")[-1]
            if "$defs" in schema and ref_path in schema["$defs"]:
                field_schema = schema["$defs"][ref_path]
                print(f" Referenced schema: {field_schema}")
        else:
            print(f" Direct schema: {field_schema}")

        # Report which encoding the field uses.
        if "enum" in field_schema:
            print(f" Has enum: {field_schema['enum']}")
        elif "const" in field_schema:
            print(f" Has const: {field_schema['const']}")
        elif "anyOf" in field_schema or "allOf" in field_schema:
            print(f" Has union: {field_schema}")
        else:
            print(f" Type only: {field_schema.get('type', 'no type')}")
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_EnumEdgeCases(model_name):
    """Live round-trip for ``EnumEdgeCases`` against each integration model."""
    llm, _ = get_langchain_model_and_token_counter(model_name)
    chain = safe_with_structured_output(llm, EnumEdgeCases, prefer_json_schema=True)

    prompt = (
        "Return JSON with:\n"
        "- status: 'text'\n"
        "- score: -1 (must be 1, -1, or 0)\n"
        "- ratio: 3.14159 (or null)\n"
    )
    result = await chain.ainvoke(prompt)

    assert isinstance(result, EnumEdgeCases)
    assert result.status == MixedTypeEnum.STRING_VAL
    assert result.score in [IntEnum.POSITIVE, IntEnum.NEGATIVE, IntEnum.ZERO]
    assert result.ratio in [FloatEnum.PI, FloatEnum.E, None]
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_LiteralModel(model_name):
    """
    Integration check: structured output for `LiteralModel`.

    Invokes the structured-output chain and verifies each literal-typed field
    holds one of its permitted values.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, LiteralModel, prefer_json_schema=True)
    request = "Return JSON with:\n- version: 'v1'\n- mode: 'production'\n- status: 'active' (or null)\n"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, LiteralModel)
    assert parsed.version == "v1"
    allowed_modes = ("production", "development")
    assert parsed.mode in allowed_modes
    allowed_statuses = ("active", "inactive", "pending", None)
    assert parsed.status in allowed_statuses
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DeeplyNested(model_name):
    """
    Integration check: structured output for `DeeplyNested`.

    Invokes the structured-output chain and verifies `level1` is either a
    dict or None.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, DeeplyNested, prefer_json_schema=True)
    request = "Return JSON with:\n- level1: {'key1': ['value1', null, 'value3'], 'key2': null} (or entirely null)\n"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, DeeplyNested)
    assert parsed.level1 is None or isinstance(parsed.level1, dict)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DiscriminatedUnionModel(model_name):
    """
    Integration check: structured output for `DiscriminatedUnionModel`.

    Invokes the structured-output chain and verifies `pet` is parsed into one
    of the union members (`Cat` or `Dog`).
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, DiscriminatedUnionModel, prefer_json_schema=True)
    request = "Return JSON with:\n- pet: {pet_type: 'cat', meows: 5}\n"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, DiscriminatedUnionModel)
    union_members = (Cat, Dog)
    assert isinstance(parsed.pet, union_members)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ConstrainedModel(model_name):
    """
    Integration check: structured output for `ConstrainedModel`.

    Invokes the structured-output chain and verifies each field's type along
    with the length / range bounds asserted below.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, ConstrainedModel, prefer_json_schema=True)
    request = (
        "Return JSON with:\n"
        "- username: 'john_doe123'\n"
        "- email: 'john@example.com'\n"
        "- age: 25\n"
        "- percentage: 75.5\n"
        "- tags: ['python', 'ai', 'testing']\n"
    )

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, ConstrainedModel)

    # username: string of 3..20 characters
    assert isinstance(parsed.username, str)
    assert 3 <= len(parsed.username) <= 20

    # email: only the type is checked here
    assert isinstance(parsed.email, str)

    # age: integer within 0..150
    assert isinstance(parsed.age, int)
    assert 0 <= parsed.age <= 150

    # percentage: float within 0.0..100.0
    assert isinstance(parsed.percentage, float)
    assert 0.0 <= parsed.percentage <= 100.0

    # tags: list holding 1..10 items
    assert isinstance(parsed.tags, list)
    assert 1 <= len(parsed.tags) <= 10
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DateTimeModel(model_name):
    """
    Integration check: structured output for `DateTimeModel`.

    Invokes the structured-output chain and verifies the parsed fields have
    the expected Python types (datetime, date-or-None, Decimal, dict,
    dict-or-None).
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, DateTimeModel, prefer_json_schema=True)
    request = (
        "Return JSON with:\n"
        "- created_at: '2024-01-15T10:30:00Z'\n"
        "- birthday: '1990-05-20' (or null)\n"
        "- price: 99.99\n"
        "- metadata: {'version': '1.0', 'author': 'test'}\n"
        "- raw_json: {'any': 'data', 'nested': {'key': 'value'}} (or null)\n"
    )

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, DateTimeModel)
    assert isinstance(parsed.created_at, datetime)
    assert parsed.birthday is None or isinstance(parsed.birthday, date)
    assert isinstance(parsed.price, Decimal)
    assert isinstance(parsed.metadata, dict)
    assert parsed.raw_json is None or isinstance(parsed.raw_json, dict)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_EmptyModel(model_name):
    """
    Integration check: structured output for `EmptyModel`.

    Invokes the structured-output chain and verifies the result is an
    `EmptyModel` instance.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, EmptyModel, prefer_json_schema=True)
    request = "Return an empty JSON object: {}"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, EmptyModel)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_AllOptionalModel(model_name):
    """
    Integration check: structured output for `AllOptionalModel`.

    Invokes the structured-output chain and verifies each of `a`, `b`, `c` is
    either None or of its expected type (str, int, bool respectively).
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, AllOptionalModel, prefer_json_schema=True)
    request = "Return JSON with:\n- a: 'text' (or null)\n- b: 42 (or null)\n- c: true (or null)\n"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, AllOptionalModel)
    assert parsed.a is None or isinstance(parsed.a, str)
    assert parsed.b is None or isinstance(parsed.b, int)
    assert parsed.c is None or isinstance(parsed.c, bool)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ComplexUnionModel(model_name):
    """
    Integration check: structured output for `ComplexUnionModel`.

    Invokes the structured-output chain and verifies `value` and `mixed` are
    populated and `weird` is a str, an int, or None.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, ComplexUnionModel, prefer_json_schema=True)
    request = (
        "Return JSON with:\n- value: 'simple string'\n- weird: 42 (or null)\n- mixed: 'text' (from MixedTypeEnum)\n"
    )

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, ComplexUnionModel)
    assert parsed.value is not None
    assert parsed.weird is None or isinstance(parsed.weird, (str, int))
    assert parsed.mixed is not None
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_DefaultsModel(model_name):
    """
    Integration check: structured output for `DefaultsModel`.

    Invokes the structured-output chain and verifies the types of `name`,
    `count`, `created` and `tags` on the parsed result.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, DefaultsModel, prefer_json_schema=True)
    request = (
        "Return JSON with:\n"
        "- name: 'Alice' (or null)\n"
        "- count: 10 (or null)\n"
        "- created: '2024-01-15T12:00:00Z'\n"
        "- tags: ['tag1', 'tag2']\n"
    )

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, DefaultsModel)
    assert parsed.name is None or isinstance(parsed.name, str)
    assert parsed.count is None or isinstance(parsed.count, int)
    assert isinstance(parsed.created, datetime)
    assert isinstance(parsed.tags, list)
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ProblematicModel(model_name):
    """
    Integration check: structured output for `ProblematicModel`.

    Invokes the structured-output chain and verifies both password fields are
    strings holding the same value.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, ProblematicModel, prefer_json_schema=True)
    request = "Return JSON with:\n- password: 'secret123'\n- confirm_password: 'secret123'\n"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, ProblematicModel)
    assert isinstance(parsed.password, str)
    assert isinstance(parsed.confirm_password, str)
    # The prompt asks for identical values in both fields.
    assert parsed.password == parsed.confirm_password
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_ProblematicModelAlt(model_name):
    """
    Integration check: structured output for `ProblematicModelAlt`.

    Invokes the structured-output chain and verifies both password fields are
    strings holding the same value.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, ProblematicModelAlt, prefer_json_schema=True)
    request = "Return JSON with:\n- password: 'mypass456'\n- confirm_password: 'mypass456'\n"

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, ProblematicModelAlt)
    assert isinstance(parsed.password, str)
    assert isinstance(parsed.confirm_password, str)
    # The prompt asks for identical values in both fields.
    assert parsed.password == parsed.confirm_password
@pytest.mark.integration_test
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", INTEGRATION_MODELS)
@pytest.mark.timeout(10)
@need_openai
async def test_real_openai_invocation_LiteralTestModel(model_name):
    """
    Integration check: structured output for `LiteralTestModel`.

    Invokes the structured-output chain and verifies every literal-typed field
    holds one of its permitted values.
    """
    llm, _counter = get_langchain_model_and_token_counter(model_name)
    structured_chain = safe_with_structured_output(llm, LiteralTestModel, prefer_json_schema=True)
    request = (
        "Return JSON with:\n"
        "- single_literal: 'v1'\n"
        "- multi_literal: 'development'\n"
        "- optional_literal: 'inactive' (or null)\n"
        "- literal_with_field: 'test'\n"
    )

    parsed = await structured_chain.ainvoke(request)

    assert isinstance(parsed, LiteralTestModel)
    assert parsed.single_literal == "v1"
    allowed_multi = ("production", "development")
    assert parsed.multi_literal in allowed_multi
    allowed_optional = ("active", "inactive", "pending", None)
    assert parsed.optional_literal in allowed_optional
    assert parsed.literal_with_field == "test"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment