Last active
December 4, 2025 19:56
-
-
Save ykdojo/5fe97bc6514342988cceb54cd68330ed to your computer and use it in GitHub Desktop.
Daft Colab Pydantic pickle test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Colab Pydantic Pickle Test""" | |
| from __future__ import annotations | |
| import io | |
| import types | |
| from typing import Union, get_args, get_origin | |
| from pydantic import BaseModel, create_model | |
| # The fix functions | |
def _is_model_class(candidate):
    """Return True when *candidate* is a class deriving from ``BaseModel``."""
    try:
        return isinstance(candidate, type) and issubclass(candidate, BaseModel)
    except TypeError:
        # Parameterized generics such as ``list[int]`` can pass the
        # ``isinstance(..., type)`` check on some Python versions and then
        # make ``issubclass`` raise; they are never model classes.
        return False


def _iter_model_classes(annotation):
    """Yield every model class found in *annotation*, at any nesting depth."""
    if _is_model_class(annotation):
        yield annotation
        return
    for arg in get_args(annotation):
        yield from _iter_model_classes(arg)


def _rewrite_annotation(annotation, cleaned_cache):
    """Return *annotation* with model classes swapped for their cleaned copies.

    Models that have no entry in *cleaned_cache* are left untouched.
    """
    if _is_model_class(annotation):
        return cleaned_cache.get(annotation, annotation)

    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is None or not args:
        # Plain types, and aliases without arguments (e.g. ``tuple[()]``),
        # have nothing to substitute and cannot be re-subscripted.
        return annotation

    new_args = tuple(_rewrite_annotation(arg, cleaned_cache) for arg in args)
    if origin is Union or origin is types.UnionType:
        # ``X | Y`` cannot be rebuilt by subscripting ``types.UnionType``,
        # so every union is normalized to ``typing.Union``.
        return Union[new_args]
    if all(new is old for new, old in zip(new_args, args)):
        return annotation
    if len(new_args) == 1:
        return origin[new_args[0]]
    return origin[new_args]


def clean_pydantic_model(model_cls, _cleaned_cache=None, _in_progress=None):
    """Rebuild *model_cls* as a fresh model created with ``create_model``.

    The new model has the same name, field names, field annotations and
    defaults (including ``default_factory``). Nothing else from the original
    class is carried over. Models referenced by the fields are cleaned first
    and the annotations are rewritten to point at the cleaned copies.

    Args:
        model_cls: The Pydantic model class to rebuild.
        _cleaned_cache: Internal mapping of original class -> cleaned class,
            shared across the recursion so each model is rebuilt only once.
        _in_progress: Internal set of models currently being rebuilt, used to
            stop mutually-referencing models from recursing forever.

    Returns:
        The cleaned model class (the cached one if *model_cls* was already
        processed with the same cache).
    """
    # Imported here so the block does not depend on the file-level imports.
    from pydantic import Field

    if _cleaned_cache is None:
        _cleaned_cache = {}
    if model_cls in _cleaned_cache:
        return _cleaned_cache[model_cls]
    if _in_progress is None:
        _in_progress = set()

    # Clean referenced models first. A model that is already being rebuilt
    # further up the call stack (including model_cls itself) is skipped, so a
    # reference cycle keeps pointing at the original class instead of looping.
    _in_progress.add(model_cls)
    try:
        for field_info in model_cls.model_fields.values():
            for ref_model in _iter_model_classes(field_info.annotation):
                if ref_model in _cleaned_cache or ref_model in _in_progress:
                    continue
                clean_pydantic_model(ref_model, _cleaned_cache, _in_progress)
    finally:
        _in_progress.discard(model_cls)

    field_definitions = {}
    for field_name, field_info in model_cls.model_fields.items():
        annotation = _rewrite_annotation(field_info.annotation, _cleaned_cache)
        if field_info.is_required():
            default = ...
        elif field_info.default_factory is not None:
            # ``field_info.default`` is unset for factory-backed fields, so
            # the factory has to be forwarded explicitly.
            default = Field(default_factory=field_info.default_factory)
        else:
            default = field_info.default
        field_definitions[field_name] = (annotation, default)

    cleaned = create_model(model_cls.__name__, **field_definitions)
    _cleaned_cache[model_cls] = cleaned
    return cleaned
def colab_safe_dumps(obj):
    """Pickle *obj* with daft's cloudpickle, substituting cleaned model classes.

    Every Pydantic model class met while pickling (other than ``BaseModel``
    itself) is replaced by the result of ``clean_pydantic_model``. The cleaned
    class is pickled on its own and embedded as a ``loads(bytes)`` call. All
    other objects are handled by the parent pickler.

    Args:
        obj: The object to serialize.

    Returns:
        The pickled payload as ``bytes``.
    """
    from daft.pickle.cloudpickle import Pickler as CloudPickler
    from daft.pickle.cloudpickle import dumps as cloudpickle_dumps
    from daft.pickle.cloudpickle import loads as cloudpickle_loads

    class ColabSafePickler(CloudPickler):
        def reducer_override(self, obj):
            try:
                is_model = (
                    isinstance(obj, type)
                    and issubclass(obj, BaseModel)
                    and obj is not BaseModel
                )
            except TypeError:
                # Parameterized generics (e.g. ``list[int]``) can pass the
                # isinstance check on some Python versions and then make
                # issubclass raise; they are not model classes.
                is_model = False

            if is_model:
                cleaned = clean_pydantic_model(obj)
                cleaned_bytes = cloudpickle_dumps(cleaned)
                # Reduce to "call loads(cleaned_bytes)" so unpickling yields
                # the cleaned class rather than the original one.
                return (cloudpickle_loads, (cleaned_bytes,))

            # Defer to the parent pickler's own override for everything else.
            # Returning NotImplemented unconditionally would bypass it.
            # NOTE(review): assumes the vendored Pickler defines
            # reducer_override, as upstream cloudpickle does - the getattr
            # keeps the old behavior if it does not.
            parent_override = getattr(super(), "reducer_override", None)
            if parent_override is None:
                return NotImplemented
            return parent_override(obj)

    buffer = io.BytesIO()
    ColabSafePickler(buffer).dump(obj)
    return buffer.getvalue()
# Monkey-patch daft.pickle.dumps to use our fix. Only code that looks up
# ``daft.pickle.dumps`` at call time sees the replacement; a reference bound
# earlier (e.g. ``from daft.pickle import dumps``) keeps the original function.
import daft.pickle

daft.pickle.dumps = colab_safe_dumps
# Define test model
class Result(BaseModel):
    """Minimal model with a single boolean field, used by the tests below."""

    answer: bool
def _report_dumps(label, obj):
    """Print *label*, run ``daft.pickle.dumps(obj)`` and print the outcome."""
    print(label)
    try:
        # Looked up at call time so the monkey-patched function is the one used.
        daft.pickle.dumps(obj)
    except Exception as e:
        print(f" ✗ FAILED: {e}")
    else:
        print(" ✓ SUCCESS")


# Test 1: Direct pickle of model class
_report_dumps("Test 1: daft.pickle.dumps(Result) - model class", Result)

# Test 2: Direct pickle of model instance
_report_dumps(
    "Test 2: daft.pickle.dumps(Result(answer=True)) - model instance",
    Result(answer=True),
)

# Test 3: Daft UDF with Pydantic return type
print("Test 3: @daft.func with Pydantic return type")
import daft

try:
    @daft.func
    def analyze(text: str) -> Result:
        return Result(answer=text == "hello")

    df = daft.from_pydict({"text": ["hello", "world"]})
    df = df.with_column("analysis", analyze(df["text"]))
    df.show(2)
    print(" ✓ SUCCESS")
except Exception as e:
    # Heuristic: only errors whose text mentions "pickle" or "handle" are
    # counted as serialization failures.
    lowered = str(e).lower()
    if "pickle" in lowered or "handle" in lowered:
        print(f" ✗ FAILED (pickle error): {e}")
    else:
        # Include the message: an unrelated failure is still a failure of the
        # run, and hiding it leaves nothing to diagnose it with.
        print(f" ✓ SUCCESS (serialization worked, other error: {type(e).__name__}: {e})")

print("\nDone!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment