Created
July 20, 2023 03:47
-
-
Save shrekris-anyscale/1d82c20d4aafd508a25df4d67f19d8fd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Workaround to serialize and deserialize Pydantic 2.x models w/ cloudpickle.""" | |
import pydantic | |
from pydantic import BaseModel | |
print(f"Pydantic version: {pydantic.__version__}") | |
import ray.cloudpickle as cloudpickle | |
# Keep a handle on the reducer that CloudPickler currently exposes so that a
# replacement installed on the class can still delegate to it.
original_reducer = cloudpickle.CloudPickler.reducer_override
def pydantic_reducer(self, obj):
    """Reduce ``obj`` for pickling, special-casing Pydantic 2.x model classes.

    Intended as a drop-in replacement for ``CloudPickler.reducer_override``.
    Objects that are not Pydantic model classes are handed straight to
    ``original_reducer``. For model classes, ``__pydantic_serializer__`` is
    blanked out while the original reducer captures the class state
    (presumably because the serializer object cannot be pickled), and the
    set-state callable of the resulting reduce tuple is wrapped so that a
    fresh serializer is rebuilt when the class is unpickled.

    Args:
        self: The pickler instance the override is invoked on.
        obj: The object about to be pickled.

    Returns:
        ``NotImplemented`` or a ``__reduce__``-style tuple, exactly as
        produced by ``original_reducer``; for model classes reduced to a
        tuple, the sixth item is replaced by the serializer-rebuilding
        wrapper.
    """
    from pydantic._internal._model_construction import ModelMetaclass

    if not isinstance(obj, ModelMetaclass):
        return original_reducer(self, obj)

    # Hide the serializer only for the duration of the reduce call, then put
    # the class back the way it was. Leaving it set to None would break
    # model_dump()/model_dump_json() on the local class after pickling.
    # NOTE(review): this relies on the original reducer snapshotting the class
    # namespace before it returns (cloudpickle builds the state dict eagerly);
    # confirm against the vendored cloudpickle version.
    missing = object()
    saved_serializer = vars(obj).get("__pydantic_serializer__", missing)
    obj.__pydantic_serializer__ = None
    try:
        reducer_result = original_reducer(self, obj)
    finally:
        if saved_serializer is missing:
            del obj.__pydantic_serializer__
        else:
            obj.__pydantic_serializer__ = saved_serializer

    if reducer_result is NotImplemented:
        # The pickler falls back to its default handling for this class, so
        # there is no reduce tuple to patch.
        return reducer_result

    # At this point, reducer_result conforms to the __reduce__ interface's
    # output. See https://docs.python.org/3/library/pickle.html#object.__reduce__
    original_set_state_override = reducer_result[5]

    # Keep this wrapper nested: it closes over the original set-state
    # callable and is itself embedded in the reduce tuple, so it must not
    # capture anything unpicklable (in particular, not the saved serializer).
    def pydantic_set_state_override(obj, state):
        from pydantic._internal._config import ConfigWrapper
        from pydantic._internal._core_utils import inline_schema_defs
        from pydantic_core._pydantic_core import SchemaSerializer

        stateful_obj = original_set_state_override(obj, state)

        # Rebuild the serializer that was blanked out before pickling.
        schema = inline_schema_defs(stateful_obj.__pydantic_core_schema__)
        config = ConfigWrapper(config=stateful_obj.model_config).core_config(stateful_obj)
        stateful_obj.__pydantic_serializer__ = SchemaSerializer(schema=schema, config=config)
        return stateful_obj

    # Swap only the sixth item (the set-state callable) of the reduce tuple.
    return tuple(reducer_result[:5]) + (pydantic_set_state_override,)
# Comment out this line to use the original cloudpickle logic.
cloudpickle.CloudPickler.reducer_override = pydantic_reducer
def run_experiment():
    """Round-trip two function-local Pydantic model classes through cloudpickle.

    The model classes are declared inside this function on purpose: such
    local classes cannot be serialized with vanilla pickle and need
    cloudpickle. Each class is dumped and loaded, an instance is built from
    the loaded class, and the equality check against a reference instance
    plus both dump formats are printed.
    """

    class NestedModel(BaseModel):
        val: int

    class ContainerModel(BaseModel):
        name: str
        nest: NestedModel

    # Reference instances, built from the local classes before any pickling.
    expected_nested = NestedModel(val=5)
    expected_container = ContainerModel(name="hi", nest=expected_nested)

    # Nested model: serialize the class itself, then instantiate the copy.
    restored_nested_cls = cloudpickle.loads(cloudpickle.dumps(NestedModel))
    nested = restored_nested_cls(val=5)
    nested_matches = nested == expected_nested
    print(f"NestedModel serde result: {nested_matches}")
    print(f"NestedModel serialized: {nested.model_dump()}")
    print(f"NestedModel json-serialized: {nested.model_dump_json()}")

    # Container model: same round trip, reusing the rebuilt nested instance.
    restored_container_cls = cloudpickle.loads(cloudpickle.dumps(ContainerModel))
    container = restored_container_cls(name="hi", nest=nested)
    container_matches = container == expected_container
    print(f"ContainerModel serde result: {container_matches}")
    print(f"ContainerModel serialized: {container.model_dump()}")
    print(f"ContainerModel json-serialized: {container.model_dump_json()}")
# Run the demo as soon as the script is executed.
run_experiment()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment