Skip to content

Instantly share code, notes, and snippets.

@eliaskanelis
Created November 19, 2025 02:25
Show Gist options
  • Select an option

  • Save eliaskanelis/284472ac655b663572eeaee7ae55f7c1 to your computer and use it in GitHub Desktop.

Select an option

Save eliaskanelis/284472ac655b663572eeaee7ae55f7c1 to your computer and use it in GitHub Desktop.
from __future__ import annotations
from typing import Any, TypeVar, get_type_hints
from datetime import datetime, timezone
from pydantic import BaseModel, Field, EmailStr, field_validator, create_model
import json
from typing import cast
from rich.console import Console
from rich.syntax import Syntax
from rich.pretty import Pretty
from difflib import unified_diff
from rich.traceback import install
from rich.prompt import Prompt
# --------------------------------------------------------------------
# Type variable for "some concrete Pydantic model" used by ModelOps signatures.
_T = TypeVar("_T", bound=BaseModel)


def make_partial(model: type[BaseModel]) -> type[BaseModel]:
    """Create a partial Pydantic model where all fields are optional.

    Every declared field of *model* is re-declared on a new model named
    ``Partial<ModelName>`` with type ``Optional[<field annotation>]`` and a
    default of ``None``. Only the field annotations are copied: field
    constraints (``ge=``, ``min_length=``, ...), default factories and
    validators of *model* are not carried over.

    Args:
        model: The Pydantic model class to derive from.

    Returns:
        A new model class whose fields all default to ``None``.
    """
    from typing import Optional

    # Read the model's declared fields instead of calling get_type_hints().
    # get_type_hints() walks the whole MRO and also returns BaseModel's
    # ClassVar annotations (e.g. ``model_config: ClassVar[ConfigDict]``).
    # Those are not fields, and ``ClassVar[...] | None`` raises TypeError.
    field_defs: dict[str, Any] = {
        name: (Optional[info.annotation], None)
        for name, info in model.model_fields.items()
    }
    return create_model(  # type: ignore[call-overload]
        f"Partial{model.__name__}",
        **field_defs,
    )
class ModelOps:
    """Unified operations on Pydantic models.

    Wraps a single model instance (``self.obj``) and offers conversion,
    introspection, comparison and copy helpers. No method mutates
    ``self.obj``; operations that produce a changed model return a new
    instance.
    """

    def __init__(self, obj: _T):
        self.obj: _T = obj

    def _render(self) -> str:
        """Render ``self.obj`` as syntax-highlighted, 4-space-indented JSON."""
        console = Console(record=True)
        with console.capture() as capture:
            console.print(
                Syntax(
                    self.obj.model_dump_json(indent=4),
                    "json",
                    theme="monokai",
                    line_numbers=False,
                    word_wrap=False,
                )
            )
        return capture.get()

    def __str__(self) -> str:
        return self._render()

    def __repr__(self) -> str:
        # Same highlighted JSON output as __str__.
        return self._render()

    def to_dict(self) -> dict[str, Any]:
        """Convert model to a dictionary via ``model_dump()``."""
        return self.obj.model_dump()

    def to_json(self) -> str:
        """Convert model to a JSON string indented by 4 spaces."""
        return self.obj.model_dump_json(indent=4)

    def fields(self) -> list[str]:
        """Get a list with all the field names of the wrapped model's class."""
        return list(type(self.obj).model_fields.keys())

    def schema(self) -> dict[str, Any]:
        """Return the model class's JSON schema (``model_json_schema()``)."""
        return type(self.obj).model_json_schema()

    def update(self, **kwargs: Any) -> BaseModel:
        """Create a new instance with *kwargs* overriding the current fields.

        The merged data goes through the model constructor, so validators
        and field constraints are applied to the result.
        """
        data = self.to_dict()
        data.update(kwargs)
        return type(self.obj)(**data)

    def diff(self, other: _T) -> dict[str, tuple]:
        """Return fields that differ: ``{field: (self_value, other_value)}``.

        Compares the ``model_dump()`` dicts of both models. A key present on
        only one side is reported with ``None`` for the missing value.
        """
        mine = self.to_dict()
        theirs = ModelOps(other).to_dict()
        return {
            key: (mine.get(key), theirs.get(key))
            for key in set(mine) | set(theirs)
            if mine.get(key) != theirs.get(key)
        }

    def model_diff(self, b: _T) -> BaseModel:
        """Return a partial model holding only the fields where *b* differs.

        The result is an instance of ``make_partial(type(self.obj))``. Fields
        whose values differ carry *b*'s value; all other fields are ``None``.
        Only the differing fields are passed to the constructor, so
        ``result.model_dump(exclude_unset=True)`` yields just the changes.

        Raises:
            TypeError: if *b* is not exactly the same model class as
                ``self.obj``.
        """
        # Compare against the wrapped object itself, never a global.
        model_cls = type(self.obj)
        if type(b) is not model_cls:
            raise TypeError("Models must be the same type")
        partial_cls = make_partial(model_cls)
        changes = {
            name: getattr(b, name)
            for name in model_cls.model_fields
            if getattr(self.obj, name) != getattr(b, name)
        }
        return partial_cls(**changes)

    def udiff(self, b: BaseModel) -> str:
        """Return a unified diff (``a`` -> ``b``) of both models' JSON dumps.

        Both models are dumped with ``model_dump_json(indent=4)``. An empty
        string means the dumps are identical.
        """
        a_lines = self.obj.model_dump_json(indent=4).splitlines(keepends=True)
        b_lines = b.model_dump_json(indent=4).splitlines(keepends=True)
        return "".join(unified_diff(a_lines, b_lines, fromfile="a", tofile="b"))

    def merge(self, other: _T) -> _T:
        """Shallow merge with another instance.

        Builds ``{**self_dump, **other_dump}`` and re-validates it with the
        wrapped model's class. Because ``model_dump()`` includes every field,
        each field that *other* declares overrides the value from
        ``self.obj``.
        """
        merged = {**self.to_dict(), **ModelOps(other).to_dict()}
        return cast(_T, type(self.obj)(**merged))

    @staticmethod
    def model_to_json(obj: BaseModel) -> str:
        """Dump *obj* as a pretty JSON string (4-space indent, sorted keys).

        ``mode="json"`` converts values such as ``datetime`` into
        JSON-compatible primitives. A plain ``model_dump()`` would keep the
        ``datetime`` objects, and ``json.dumps`` cannot serialise those.
        """
        return json.dumps(obj.model_dump(mode="json"), indent=4, sort_keys=True)
# --------------------------------------------------------------------
def is_user_defined_instance(obj) -> bool:
    """Tell whether *obj* is an instance of a non-builtin class.

    Classes themselves are never considered instances here. Any object whose
    type lives outside the ``builtins`` module counts, including stdlib types
    such as ``datetime``.
    """
    if isinstance(obj, type):
        return False
    owning_module = type(obj).__module__
    return owning_module != "builtins"
def is_unified_diff(text: Any) -> bool:
    """Heuristically decide whether *text* is a unified diff.

    Requires a string of at least three lines that starts with ``--- `` and
    ``+++ `` file headers and has at least one ``@@ `` hunk header after them.
    """
    if not isinstance(text, str):
        return False
    rows = text.splitlines()
    if len(rows) < 3:
        return False
    old_header, new_header, *body = rows
    has_file_headers = old_header.startswith("--- ") and new_header.startswith("+++ ")
    return has_file_headers and any(row.startswith("@@ ") for row in body)
class UI:
    """Console helpers built on Rich.

    Output goes through one shared recording ``Console`` (``UI.console``),
    except for the non-builtin-instance fallback of :meth:`print`.
    """

    console = Console(record=True)
    THEME = "monokai"

    @staticmethod
    def _parse_answer(raw: str) -> Any:
        """Decode an answer that looks like a JSON list/object; else return it as-is."""
        text = raw.strip()
        if text.startswith(("[", "{")):
            try:
                return json.loads(text)
            except json.JSONDecodeError:
                pass  # not valid JSON: let Pydantic validate the raw string
        return raw

    @classmethod
    def prompt(cls, model: Any) -> Any:
        """Interactively build an instance of the Pydantic model class *model*.

        Asks for one value per field, labelled ``[ModelName] field``. Answers
        starting with ``[`` or ``{`` are parsed as JSON, so list and dict
        fields can be entered. Other answers are passed as strings and left
        to Pydantic's coercion. For a field with a default, pressing Enter
        keeps that default.

        Raises:
            TypeError: if *model* is not a Pydantic model class.
            pydantic.ValidationError: if the collected values fail validation.
        """
        from rich.markup import escape

        if not (isinstance(model, type) and issubclass(model, BaseModel)):
            raise TypeError(
                f"UI.prompt() expects a Pydantic model class, got {model!r}"
            )

        answers: dict[str, Any] = {}
        for name, info in model.model_fields.items():
            # Escape so the bracketed model name is never read as Rich markup.
            label = escape(f"[{model.__name__}] {name}")
            if info.is_required():
                raw = Prompt.ask(label, console=cls.console)
            else:
                raw = Prompt.ask(
                    label, console=cls.console, default="", show_default=False
                )
                if raw == "":
                    continue  # keep the model's own default for this field
            answers[name] = cls._parse_answer(raw)
        return model.model_validate(answers)

    @classmethod
    def print(cls, data: Any) -> None:
        """Pretty-print *data* to ``UI.console`` according to its kind.

        - dict: highlighted JSON (4-space indent)
        - str: highlighted diff if it looks like a unified diff, else plain text
        - BaseModel instance: highlighted ``model_dump_json(indent=4)``
        - class: ``rich.inspect`` summary (public methods and docs)
        - other non-builtin instance: built-in ``print()``, i.e. its ``__str__``
        - anything else (int, list, None, ...): ``rich.pretty.Pretty``
        """
        if isinstance(data, dict):
            # default=str: this is display-only, so stringify values such as
            # datetime instead of raising TypeError.
            json_str = json.dumps(data, indent=4, default=str)
            cls.console.print(
                Syntax(json_str, "json", theme=cls.THEME, line_numbers=False)
            )
        elif isinstance(data, str):
            if is_unified_diff(data):
                cls.console.print(
                    Syntax(data, "diff", theme=cls.THEME, line_numbers=False)
                )
            else:
                cls.console.print(data)
        elif isinstance(data, BaseModel):
            cls.console.print(
                Syntax(
                    data.model_dump_json(indent=4),
                    "json",
                    theme=cls.THEME,
                    line_numbers=False,
                    word_wrap=False,
                )
            )
        elif isinstance(data, type):
            from rich import inspect

            inspect(
                data,
                methods=True,
                dunder=False,
                all=False,
                private=False,
                docs=True,
                value=False,
                console=cls.console,
            )
        elif is_user_defined_instance(data):
            # Bare ``print`` here is the builtin: a class body is not an
            # enclosing scope for its methods.
            print(data)
        else:
            # Only built-in instances reach this branch. Most of them (int,
            # list, None, ...) have no __dict__, so vars() must not be used.
            cls.console.print(Pretty(data))
# --------------------------------------------------------------------
class User(BaseModel):
    """User."""

    # NOTE: the docstring above is kept verbatim. Pydantic copies it into
    # model_json_schema()["description"], so editing it changes runtime output.
    #
    # Field summary:
    #   id         - integer >= 1
    #   username   - 3..32 characters, must not contain a space
    #   email      - validated e-mail address
    #   tags       - stripped, lower-cased, empty entries removed
    #   created_at - defaults to the current time in UTC
    id: int = Field(..., ge=1)
    username: str = Field(..., min_length=3, max_length=32)
    email: EmailStr
    tags: list[str] = Field(default_factory=list)
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )

    @field_validator("username")
    @classmethod
    def no_weird_usernames(cls, v: str) -> str:
        """Reject any username containing a space character."""
        if " " not in v:
            return v
        raise ValueError("username cannot contain spaces")

    @field_validator("tags")
    @classmethod
    def clean_tags(cls, v: list[str]) -> list[str]:
        """Normalise tags: strip, lower-case, and drop entries left empty."""
        normalised = (raw_tag.strip().lower() for raw_tag in v)
        return [tag for tag in normalised if tag]
# Replace the process-wide exception hook with Rich's traceback renderer:
# uncaught exceptions are shown highlighted, wrapped at 120 columns, and with
# the local variables of each frame.
install(width=120, show_locals=True, theme="monokai")
def divide(a, b):
    """Return ``a / b`` (true division).

    There is no zero check; ``divide(x, 0)`` raises ZeroDivisionError. It
    appears to exist for trying out the Rich traceback handler.
    """
    quotient = a / b
    return quotient
# Uncomment to see the Rich traceback handler render a ZeroDivisionError:
# divide(10, 0)

# Build an instance of the all-optional "PartialUser" model and show both the
# generated class and the instance.
z = make_partial(User)()
UI.print(type(z))
UI.print(z)

a = User(id=1, username="Elias", email="user@yahoo.gr", tags=["go", "stop"])

# Copy `a`, then change two fields. Plain attribute assignment on a Pydantic
# model is not re-validated (validate_assignment is off by default).
b = a.model_copy()
b.id = 2
b.email = "user@google.com"

a_ops = ModelOps(a)
UI.print(a_ops)

# Other ModelOps operations that can be displayed the same way:
#   UI.print(type(a_ops))
#   UI.print(a_ops.fields())
#   UI.print(a_ops.to_json())
#   UI.print(a_ops.schema())
#   UI.print(a_ops.diff(b))
UI.print(a_ops.model_diff(b))
#   UI.print(a_ops.udiff(b))
#   UI.print(a_ops.merge(b))

# Interactive: blocks waiting for console input.
UI.prompt(User)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment