Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shrekris-anyscale/1d82c20d4aafd508a25df4d67f19d8fd to your computer and use it in GitHub Desktop.
Save shrekris-anyscale/1d82c20d4aafd508a25df4d67f19d8fd to your computer and use it in GitHub Desktop.
"""Workaround to serialize and deserialize Pydantic 2.x models w/ cloudpickle."""
import pydantic
from pydantic import BaseModel
# Print the installed Pydantic version first so the experiment's output
# records which release it ran against.
print(f"Pydantic version: {pydantic.__version__}")
# cloudpickle is taken from Ray's namespace (ray.cloudpickle), not from a
# standalone cloudpickle install.
import ray.cloudpickle as cloudpickle
# Keep a handle on the stock CloudPickler.reducer_override so that
# pydantic_reducer can delegate to it once the attribute has been replaced.
original_reducer = cloudpickle.CloudPickler.reducer_override
def pydantic_reducer(self, obj):
    """Reducer override that lets cloudpickle handle Pydantic 2.x model classes.

    Intended to be installed as ``CloudPickler.reducer_override``. Objects
    whose type is not exactly ``ModelMetaclass`` are handed straight to the
    original reducer. For a model class, ``__pydantic_serializer__`` is set to
    ``None`` only while the original reducer builds its reduce tuple, and the
    tuple's state setter is wrapped so the serializer is rebuilt after the
    class state has been applied on load.

    The class being pickled is left as it was found: its serializer attribute
    is put back (or removed again, if the class had none of its own) before
    this function returns, even when the original reducer raises.

    Args:
        self: The ``CloudPickler`` instance performing the pickling.
        obj: The object being reduced.

    Returns:
        For non-model objects, whatever the original reducer returns.
        For model classes, ``NotImplemented`` if the original reducer declines
        the object; otherwise the original reduce tuple with its last element
        replaced by a state setter that also rebuilds the serializer.
    """
    from pydantic._internal._model_construction import ModelMetaclass

    if type(obj) is not ModelMetaclass:
        return original_reducer(self, obj)

    # Remember the serializer stored on this class itself (not an inherited
    # one) so the live class can be restored after the reduce tuple is built.
    missing = object()
    saved_serializer = vars(obj).get("__pydantic_serializer__", missing)

    obj.__pydantic_serializer__ = None
    try:
        reducer_result = original_reducer(self, obj)
    finally:
        # NOTE(review): restoring here assumes the original reducer has
        # already copied the class attributes into its reduce tuple rather
        # than holding a live view of the class dict - confirm against the
        # cloudpickle version in use.
        if saved_serializer is missing:
            del obj.__pydantic_serializer__
        else:
            obj.__pydantic_serializer__ = saved_serializer

    if reducer_result is NotImplemented:
        # The original reducer declined this class, so there is no reduce
        # tuple to patch. Pass the sentinel through instead of asserting that
        # only BaseModel can reach this path.
        return reducer_result

    # At this point, reducer_result conforms to the __reduce__ interface's
    # output. See https://docs.python.org/3/library/pickle.html#object.__reduce__
    original_set_state_override = reducer_result[5]

    # Kept as a nested function with its own local imports; it references only
    # original_set_state_override from the enclosing scope.
    def pydantic_set_state_override(cls, state):
        """Apply the pickled class state, then rebuild ``__pydantic_serializer__``."""
        from pydantic._internal._config import ConfigWrapper
        from pydantic_core._pydantic_core import SchemaSerializer
        from pydantic._internal._core_utils import inline_schema_defs

        stateful_obj = original_set_state_override(cls, state)

        # Recreate the serializer from the restored class's core schema and
        # model_config, replacing the None placeholder carried in the state.
        schema = inline_schema_defs(stateful_obj.__pydantic_core_schema__)
        config = ConfigWrapper(config=stateful_obj.model_config).core_config(stateful_obj)
        stateful_obj.__pydantic_serializer__ = SchemaSerializer(schema=schema, config=config)

        return stateful_obj

    # Same tuple as before, with only the trailing state setter swapped out.
    return (*reducer_result[:-1], pydantic_set_state_override)
# Install the workaround by replacing reducer_override on the CloudPickler
# class itself with pydantic_reducer.
# Comment out this line to use the original cloudpickle logic.
cloudpickle.CloudPickler.reducer_override = pydantic_reducer
def run_experiment():
    """Round-trip two function-local Pydantic models through cloudpickle.

    Each model class is dumped and loaded with cloudpickle, and an instance
    is created from the loaded class. Three lines are printed per model:
    whether the new instance equals one built from the original class, its
    ``model_dump()`` output, and its ``model_dump_json()`` output.
    """

    # The models are declared inside this function on purpose: classes in a
    # local scope cannot be serialized with vanilla pickle, so they need
    # cloudpickle.
    class NestedModel(BaseModel):
        val: int

    class ContainerModel(BaseModel):
        name: str
        nest: NestedModel

    # Reference instances, created before any pickling takes place.
    nested_before = NestedModel(val=5)
    container_before = ContainerModel(name="hi", nest=nested_before)

    # NestedModel: class -> bytes -> class, then instantiate and report.
    nested_cls = cloudpickle.loads(cloudpickle.dumps(NestedModel))
    nested_after = nested_cls(val=5)
    print(f"NestedModel serde result: {nested_after == nested_before}")
    print(f"NestedModel serialized: {nested_after.model_dump()}")
    print(f"NestedModel json-serialized: {nested_after.model_dump_json()}")

    # ContainerModel: same round trip, reusing the round-tripped nested
    # instance as its field value.
    container_cls = cloudpickle.loads(cloudpickle.dumps(ContainerModel))
    container_after = container_cls(name="hi", nest=nested_after)
    print(f"ContainerModel serde result: {container_after == container_before}")
    print(f"ContainerModel serialized: {container_after.model_dump()}")
    print(f"ContainerModel json-serialized: {container_after.model_dump_json()}")
# Run the demonstration at module load; there is no __main__ guard, so this
# also runs when the file is imported.
run_experiment()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment