-
-
Save fg91/5874960833ebd7068925be31724ceb90 to your computer and use it in GitHub Desktop.
Pydantic base model type transformer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flytekit import task, workflow | |
from flytekit.core.type_engine import ( | |
TypeEngine, | |
TypeTransformer, | |
TypeTransformerFailedError, | |
) | |
from pydantic import BaseModel | |
from typing import Optional, Type, Union | |
from flytekit import FlyteContext | |
from flytekit.extend import TypeEngine, TypeTransformer | |
from flytekit.models.literals import Literal, Scalar | |
from flytekit.models.types import LiteralType, SimpleType | |
from google.protobuf.json_format import MessageToDict | |
from google.protobuf.struct_pb2 import Struct | |
# --- | |
class BaseModelTransformer(TypeTransformer[BaseModel]):
    """Flyte type transformer that passes pydantic ``BaseModel`` objects as STRUCT literals.

    Each model is serialized into a protobuf ``Struct`` holding two entries:
    ``"schema"`` (the model's JSON schema) and ``"data"`` (the model's field
    values). On the way back the stored schema is compared with the schema of
    the expected python type, and the value is rejected when they differ.
    """

    _TYPE_INFO = LiteralType(simple=SimpleType.STRUCT)

    def __init__(self):
        """Construct BaseModelTransformer."""
        super().__init__(name="basemodel-transform", t=BaseModel)

    def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
        """Return the Flyte literal type used for every ``BaseModel``: a generic STRUCT."""
        return LiteralType(simple=SimpleType.STRUCT)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: BaseModel,
        python_type: Type[BaseModel],
        expected: LiteralType,
    ) -> Literal:
        """Convert a ``BaseModel`` instance to its Flyte ``Literal`` representation.

        The schema is stored next to the data so that ``to_python_value`` can
        verify the receiving type matches the sending type.
        """
        s = Struct()
        s.update({"schema": python_val.schema(), "data": python_val.dict()})
        return Literal(scalar=Scalar(generic=s))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]
    ) -> BaseModel:
        """Re-hydrate a ``BaseModel`` from a Flyte ``Literal``.

        Raises:
            TypeTransformerFailedError: if the schema stored in the literal is
                not equal to the schema of ``expected_python_type``.
        """
        # NOTE(review): protobuf Struct stores numbers as doubles, so integer
        # fields presumably come back as floats here and rely on pydantic's
        # coercion in ``parse_obj`` - confirm for the model types in use.
        base_model = MessageToDict(lv.scalar.generic)
        schema = base_model["schema"]
        data = base_model["data"]

        if (expected_schema := expected_python_type.schema()) != schema:
            raise TypeTransformerFailedError(
                f"The schema `{expected_schema}` of the expected python type {expected_python_type} is not equal to the received schema `{schema}`."
            )

        # ``parse_obj`` is a classmethod: call it on the class itself. Building
        # a throw-away default instance first (``expected_python_type()``)
        # fails for every model that has a required field.
        return expected_python_type.parse_obj(data)


TypeEngine.register(BaseModelTransformer())
# --- | |
class ModelConfig(BaseModel):
    # Small example model with one defaulted field per primitive kind.
    # Documented with `#` comments on purpose: a class docstring would be
    # added to the pydantic schema that the transformer serializes and compares.
    a: int = 1
    b: float = 2.0
    c: str = "foo"
class Config(BaseModel):
    # Example top-level config whose single field is either one ModelConfig,
    # a mapping of names to ModelConfig objects, or None.
    model_config: Optional[
        Union[dict[str, ModelConfig], ModelConfig]
    ] = ModelConfig()  # <- Something that makes trouble with the data class transformer

    def update(self):
        # Placeholder for real update logic; only reports that it was called.
        print("Updating config in an imaginary way")
@task
def train(cfg: Config) -> Config:
    """Print the received config, call its ``update`` method and return it."""
    print(f"Training with config {cfg}")
    cfg.update()
    return cfg
@workflow
def wf(cfg: Config) -> Config:
    """Pass ``cfg`` through the ``train`` task and return the task's output."""
    return train(cfg=cfg)
if __name__ == "__main__":
    # Run the workflow locally and check that the config comes back equal
    # to the one that was passed in.
    submitted = Config(model_config={"foo": ModelConfig()})
    returned = wf(cfg=submitted)
    assert submitted == returned
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Considerations:
If one creates an inheriting class ...
... and uses this in the type hint in the task (but nowhere else) ...
... the
TypeTransformerFailedError
in line 56 is raised because the schemas don't match. It would work if we didn't compare the schemas, but within the task,
cfg
would have the child_a
attribute. I feel that not allowing inheritance and being strict about the schema makes sense here.
IIUC this would be the same behaviour as in the SimpleTransformer, where the types have to be exact as well.