Created August 5, 2025 09:04
-
-
Save leontrolski/a271809369d437958c5dc58e38e405a0 to your computer and use it in GitHub Desktop.
Structural typechecking — adjust to your domain
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from functools import cache | |
| from types import UnionType | |
| from typing import Annotated, Any, Literal, Union, get_args, get_origin | |
| import pydantic | |
| from pydantic.fields import PydanticUndefined as MISSING # type: ignore[attr-defined] | |
def normalize(t: Any) -> Any:
    """Strip typing wrappers that do not affect subtyping.

    - ``Annotated[T, ...]`` is reduced to ``T``; the metadata is ignored.
    - A bare ``None`` is replaced by ``type(None)``, as PEP 484 specifies
      for annotations, so that it can later be compared with ``issubclass``
      (e.g. against the ``NoneType`` member of ``Optional[int]``).

    Any other value is returned unchanged.
    """
    # typing flattens nested Annotated, but loop anyway to be safe.
    while get_origin(t) is Annotated:
        t = get_args(t)[0]
    if t is None:
        return type(None)
    return t
def _required_fields(model: type) -> dict[str, Any]:
    """Return ``{field name: annotation}`` for the fields of *model* that have no default.

    A field is skipped when it has a ``default`` or a ``default_factory``.
    """
    return {
        name: field.annotation
        for name, field in model.model_fields.items()
        if field.default is MISSING and field.default_factory is None
    }


@cache
def issubtype(a: Any, b: Any) -> bool:
    """Return True if every value of type ``a`` is also a value of type ``b``.

    A deliberately small, structural approximation of static subtyping,
    meant to be adjusted to the caller's domain. Both arguments are first
    passed through ``normalize``. Then:

    - ``Any`` and ``object`` accept every type, and every type is a
      subtype of itself.
    - ``Literal[x, y]`` is treated as ``Union[Literal[x], Literal[y]]``;
      ``Literal[x]`` is a subtype of ``type(x)`` and of ``Literal[x]``.
    - Unions (``Union[...]`` / ``X | Y``): a union is a subtype of ``b``
      when all of its members are; ``a`` is a subtype of a union when it is
      a subtype of any member.
    - Parameterised generics ``G[T, ...]`` need compatible origins, the
      same number of arguments, and every argument compared covariantly.
      A bare class is never a subtype of a parameterised generic.
    - Pydantic models are compared structurally on required fields (fields
      with a default or default factory are ignored): ``a`` must require
      every field that ``b`` requires, with a compatible annotation.
    - Other classes fall back to ``issubclass``; non-class values (e.g.
      TypeVars or ``...``) are only subtypes of themselves, ``Any`` and
      ``object``.

    Results are memoised with ``functools.cache``, so both arguments must
    be hashable. There is no cycle detection: distinct pydantic models
    that refer to each other recursively can recurse without bound.
    """
    a, b = normalize(a), normalize(b)

    # Top types and reflexivity. Checking ``a == b`` first also stops a
    # self-referential model (e.g. ``Node`` with ``children: list[Node]``)
    # from recursing forever when it is compared with itself.
    if b is Any or b is object or a == b:
        return True

    origin_a, args_a = get_origin(a), get_args(a)
    origin_b, args_b = get_origin(b), get_args(b)

    # Convert Literal[x, y] to Union[Literal[x], Literal[y]] so the union
    # rules below apply. ``Union[tuple]`` is the same as ``Union[*tuple]``.
    if origin_a is Literal and len(args_a) > 1:
        return issubtype(Union[tuple(Literal[x] for x in args_a)], b)
    if origin_b is Literal and len(args_b) > 1:
        return issubtype(a, Union[tuple(Literal[y] for y in args_b)])

    # Handle unions.
    if origin_a in {Union, UnionType}:
        return all(issubtype(x, b) for x in args_a)
    if origin_b in {Union, UnionType}:
        return any(issubtype(a, y) for y in args_b)

    # Handle single-value literals.
    if origin_b is Literal:
        return a == b  # type: ignore[no-any-return]
    if origin_a is Literal:
        return issubtype(type(args_a[0]), b)  # type: ignore[arg-type]

    # Handle G[T, U, ...]. Identical arguments are accepted without
    # recursing; this covers non-type arguments such as the ``...`` in
    # ``tuple[int, ...]`` and the (unhashable) parameter list of
    # ``Callable[[int], str]``, which could not go through the cache.
    if origin_a is not None:
        return bool(
            origin_b is not None
            and issubtype(origin_a, origin_b)
            and len(args_a) == len(args_b)
            and all(x == y or issubtype(x, y) for x, y in zip(args_a, args_b))
        )
    if origin_b is not None:
        # A bare class (e.g. ``list`` or a model) against a parameterised
        # generic (e.g. ``list[int]``); ``issubclass`` would raise here.
        return False

    # Everything below uses ``issubclass``, which raises TypeError for
    # non-classes such as TypeVars or ``...``; treat those as unrelated.
    if not (isinstance(a, type) and isinstance(b, type)):
        return False

    # Handle pydantic structurally based on key; keys with a default are ignored.
    if issubclass(a, pydantic.BaseModel):
        if not issubclass(b, pydantic.BaseModel):
            return False
        fields_a, fields_b = _required_fields(a), _required_fields(b)
        # Every key b requires must be required by a as well, otherwise a
        # value of a may lack a key that b needs. Extra keys on a are fine.
        if set(fields_b) - set(fields_a):
            return False
        return all(issubtype(fields_a[k], fields_b[k]) for k in fields_b)

    return issubclass(a, b)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment