Skip to content

Instantly share code, notes, and snippets.

@ykdojo
Created December 4, 2025 19:59
Show Gist options
  • Select an option

  • Save ykdojo/d54ac6695a18a4350e4437c680628ea9 to your computer and use it in GitHub Desktop.

Select an option

Save ykdojo/d54ac6695a18a4350e4437c680628ea9 to your computer and use it in GitHub Desktop.
Colab Pydantic pickle test - WITH fix
"""Colab Pydantic Pickle Test"""
from __future__ import annotations
import io
import types
from typing import Union, get_args, get_origin
from pydantic import BaseModel, create_model
# The fix functions
def clean_pydantic_model(model_cls, _cleaned_cache=None):
    """Rebuild *model_cls* as a fresh pydantic model containing only its fields.

    The copy is produced with ``create_model`` from the original's field
    names, annotations and defaults (including ``default_factory``).  Any
    ``BaseModel`` subclass referenced from a field annotation - directly or
    nested inside generics such as ``list[...]``, ``dict[..., ...]`` or
    ``X | None`` - is rebuilt the same way and substituted into the copy.
    Validators, model config, methods and per-field metadata other than the
    default (alias, description, constraints, ...) are NOT carried over.
    NOTE(review): presumably the goal is a class that no longer drags along
    whatever the original class references, so cloudpickle can serialize it.

    Cyclic references (a model referring to itself, or A -> B -> A) are left
    pointing at the original class instead of recursing forever.

    Args:
        model_cls: The pydantic model class to rebuild.
        _cleaned_cache: Internal memo mapping original classes to rebuilt
            copies, shared across the recursion.  Callers may pass their own
            dict to reuse rebuilt models across calls.

    Returns:
        The rebuilt model class (the cached copy if *model_cls* was already
        rebuilt with this cache).

    Raises:
        Whatever ``create_model`` raises; the in-progress cache entry for
        *model_cls* is removed before re-raising.
    """
    from pydantic import Field

    if _cleaned_cache is None:
        _cleaned_cache = {}
    if model_cls in _cleaned_cache:
        return _cleaned_cache[model_cls]

    # Placeholder while this model is being rebuilt: a cyclic reference back
    # to it resolves to the original class instead of recursing forever.
    _cleaned_cache[model_cls] = model_cls
    try:
        field_definitions = {}
        for field_name, field_info in model_cls.model_fields.items():
            annotation = _swap_models(field_info.annotation, _cleaned_cache)
            if field_info.default_factory is not None:
                # With a factory, ``field_info.default`` is undefined; carry the
                # factory over instead of silently making the field required.
                default = Field(default_factory=field_info.default_factory)
            elif field_info.is_required():
                default = ...
            else:
                default = field_info.default
            field_definitions[field_name] = (annotation, default)
        cleaned = create_model(model_cls.__name__, **field_definitions)
    except BaseException:
        _cleaned_cache.pop(model_cls, None)
        raise

    _cleaned_cache[model_cls] = cleaned
    return cleaned


def _swap_models(annotation, cache):
    """Return *annotation* with every nested BaseModel subclass replaced by its rebuilt copy.

    Recurses through generic aliases (``list[...]``, ``dict[..., ...]``,
    ``Optional[...]``, ``X | Y``, ...) to any depth.  If nothing inside an
    annotation changes, the same object is returned, so constructs that cannot
    be rebuilt by re-subscripting their origin (e.g. ``tuple[()]``) are never
    touched.

    Args:
        annotation: A field annotation as reported by ``FieldInfo.annotation``.
        cache: The original -> rebuilt mapping used by ``clean_pydantic_model``.
    """
    origin = get_origin(annotation)
    if origin is None:
        # Leaf: swap pydantic model classes, but not BaseModel itself (that
        # would replace "any model" with an unrelated empty model).
        if (
            isinstance(annotation, type)
            and issubclass(annotation, BaseModel)
            and annotation is not BaseModel
        ):
            return clean_pydantic_model(annotation, cache)
        return annotation

    args = get_args(annotation)
    new_args = tuple(_swap_models(arg, cache) for arg in args)
    if all(new is old for new, old in zip(new_args, args)):
        return annotation
    if origin is Union or origin is types.UnionType:
        return Union[new_args]
    # new_args is non-empty here: at least one arg changed.
    return origin[new_args if len(new_args) > 1 else new_args[0]]
def colab_safe_dumps(obj, protocol=None, buffer_callback=None):
    """Serialize *obj* with daft's vendored cloudpickle, swapping pydantic model classes for rebuilt copies.

    Every ``BaseModel`` subclass met while pickling (other than ``BaseModel``
    itself) is rebuilt with ``clean_pydantic_model``, serialized on its own
    with plain cloudpickle, and embedded as a ``cloudpickle.loads(<bytes>)``
    call.  The unpickled object therefore holds the rebuilt class, not the
    original.  All other objects go through cloudpickle's normal reducers.

    Args:
        obj: Object to serialize.
        protocol: Pickle protocol; ``None`` uses cloudpickle's default.
        buffer_callback: Optional out-of-band buffer callback, forwarded to
            the pickler.

    Returns:
        The pickled bytes.
    """
    from daft.pickle.cloudpickle import Pickler as CloudPickler
    from daft.pickle.cloudpickle import dumps as cloudpickle_dumps
    from daft.pickle.cloudpickle import loads as cloudpickle_loads

    # Shared for the whole call, so a model reachable from several places
    # is only rebuilt once.
    cleaned_cache = {}

    class ColabSafePickler(CloudPickler):
        def reducer_override(self, obj):
            if isinstance(obj, type) and issubclass(obj, BaseModel) and obj is not BaseModel:
                cleaned = clean_pydantic_model(obj, cleaned_cache)
                cleaned_bytes = cloudpickle_dumps(cleaned, protocol=protocol)
                return (cloudpickle_loads, (cleaned_bytes,))
            # Defer to cloudpickle's own override: it is what pickles dynamic
            # functions/classes (lambdas, __main__ UDFs) by value. Returning
            # NotImplemented directly would skip it and fall back to plain
            # by-reference pickling.
            parent_override = getattr(super(), "reducer_override", None)
            if parent_override is None:
                return NotImplemented
            return parent_override(obj)

    buffer = io.BytesIO()
    ColabSafePickler(buffer, protocol=protocol, buffer_callback=buffer_callback).dump(obj)
    return buffer.getvalue()
# Monkey-patch daft.pickle.dumps to use our fix
import pickle

import daft
import daft.pickle

daft.pickle.dumps = colab_safe_dumps


# Define test model
class Result(BaseModel):
    answer: bool


def _report(label, check):
    """Print *label*, run *check*, then print a ✓/✗ line depending on whether it raised."""
    print(label)
    try:
        check()
    except Exception as e:
        print(f" ✗ FAILED: {e}")
    else:
        print(" ✓ SUCCESS")


def _roundtrip_class():
    """Pickle the Result class and check the restored class still validates input."""
    restored = pickle.loads(daft.pickle.dumps(Result))
    if restored(answer=True).answer is not True:
        raise AssertionError("restored model class lost its 'answer' field")


def _roundtrip_instance():
    """Pickle a Result instance and check its field value survives the round trip."""
    restored = pickle.loads(daft.pickle.dumps(Result(answer=True)))
    if restored.answer is not True:
        raise AssertionError("restored instance lost its 'answer' value")


# Test 1: Direct pickle of model class
_report("Test 1: daft.pickle.dumps(Result) - model class", _roundtrip_class)

# Test 2: Direct pickle of model instance
_report("Test 2: daft.pickle.dumps(Result(answer=True)) - model instance", _roundtrip_instance)

# Test 3: Daft UDF with Pydantic return type
print("Test 3: @daft.func with Pydantic return type")
try:
    @daft.func
    def analyze(text: str) -> Result:
        return Result(answer=text == "hello")

    df = daft.from_pydict({"text": ["hello", "world"]})
    df = df.with_column("analysis", analyze(df["text"]))
    df.show(2)
except Exception as e:
    message = str(e).lower()
    if "pickle" in message or "handle" in message:
        print(f" ✗ FAILED (pickle error): {e}")
    else:
        # Not obviously a serialization error, but the pipeline still did not
        # run, so it must not be reported as a success.
        print(f" ? INCONCLUSIVE (no pickle error, but {type(e).__name__}: {e})")
else:
    print(" ✓ SUCCESS")

print("\nDone!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment