Skip to content

Instantly share code, notes, and snippets.

@fg91
Created March 3, 2023 19:13
Show Gist options
  • Save fg91/5874960833ebd7068925be31724ceb90 to your computer and use it in GitHub Desktop.
Pydantic base model type transformer
import json
from typing import Optional, Type, Union

from flytekit import FlyteContext
from flytekit import task, workflow
from flytekit.core.type_engine import (
    TypeEngine,
    TypeTransformer,
    TypeTransformerFailedError,
)
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar
from flytekit.models.types import LiteralType, SimpleType
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from pydantic import BaseModel
# ---
class BaseModelTransformer(TypeTransformer[BaseModel]):
    """Flyte type transformer for pydantic ``BaseModel`` subclasses (pydantic v1 API).

    A model is stored as a generic protobuf ``Struct`` literal with two keys:

    * ``"schema"`` -- the model's JSON schema as returned by ``.schema()``;
    * ``"data"``   -- the model's field values in JSON-compatible form.

    When converting back, the stored schema must be *equal* to the schema of the
    expected python type; subclasses or otherwise compatible models are rejected.
    """

    # Literal type shared by every BaseModel: an untyped protobuf Struct.
    _TYPE_INFO = LiteralType(simple=SimpleType.STRUCT)

    def __init__(self):
        """Construct BaseModelTransformer."""
        super().__init__(name="basemodel-transform", t=BaseModel)

    def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
        """Return the Flyte literal type for ``t``: always a generic STRUCT."""
        return LiteralType(simple=SimpleType.STRUCT)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: BaseModel,
        python_type: Type[BaseModel],
        expected: LiteralType,
    ) -> Literal:
        """Convert ``python_val`` into a ``Literal`` wrapping a generic ``Struct``.

        The field values are round-tripped through ``python_val.json()`` rather than
        taken from ``python_val.dict()``. pydantic then applies its JSON encoders to
        values that ``Struct`` cannot hold directly (e.g. ``datetime``, ``UUID``,
        ``Decimal``, non-str enums), so ``Struct.update`` does not fail on them.

        NOTE: ``Struct`` stores every number as a double, so integers come back as
        floats (and lose precision beyond 2**53). ``to_python_value`` relies on
        pydantic coercing them back to the annotated field types.
        """
        s = Struct()
        s.update(
            {
                "schema": python_val.schema(),
                "data": json.loads(python_val.json()),
            }
        )
        return Literal(scalar=Scalar(generic=s))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]
    ) -> BaseModel:
        """Re-hydrate an instance of ``expected_python_type`` from a Struct literal.

        Raises:
            TypeTransformerFailedError: if ``lv`` does not hold a generic scalar with
                ``"schema"`` and ``"data"`` keys, or if the stored schema differs from
                ``expected_python_type.schema()``.
        """
        if lv.scalar is None or lv.scalar.generic is None:
            raise TypeTransformerFailedError(
                f"Cannot convert literal {lv} to {expected_python_type}: it does not contain a generic scalar."
            )
        base_model = MessageToDict(lv.scalar.generic)
        try:
            schema = base_model["schema"]
            data = base_model["data"]
        except KeyError as err:
            raise TypeTransformerFailedError(
                f"Cannot convert literal to {expected_python_type}: missing key {err} in the stored struct."
            ) from err

        # Strict comparison: subclasses (extra fields) or reordered models are rejected.
        expected_schema = expected_python_type.schema()
        if expected_schema != schema:
            raise TypeTransformerFailedError(
                f"The schema `{expected_schema}` of the expected python type {expected_python_type} is not equal to the received schema `{schema}`."
            )
        # parse_obj is a classmethod: call it on the class. Instantiating the model
        # first (as in `expected_python_type()`) fails for models with required fields.
        return expected_python_type.parse_obj(data)
# Register the transformer with flytekit's TypeEngine. This is a module-level side
# effect that runs at import time. The demo below relies on it being picked up for
# BaseModel subclasses such as `Config` (presumably via the TypeEngine's type
# lookup; confirm against the flytekit version in use).
TypeEngine.register(BaseModelTransformer())
# ---
# Small nested model. Used inside `Config.model_config`, either directly or as the
# values of a dict.
class ModelConfig(BaseModel):
    a: int = 1
    b: float = 2.0
    c: str = "foo"
# Top-level config that is passed into and returned from the `train` task / `wf` workflow.
class Config(BaseModel):
    # Either a mapping of name -> ModelConfig, a single ModelConfig, or None.
    # This nested Optional[Union[...]] is what makes trouble with the dataclass-based
    # transformer, hence the pydantic-based transformer in this file.
    # NOTE(review): `model_config` is a reserved attribute name in pydantic v2; this
    # code appears to target pydantic v1 (it uses .schema()/.dict()/.parse_obj()).
    model_config: Optional[Union[dict[str, ModelConfig], ModelConfig]] = ModelConfig()

    def update(self):
        """Stand-in for an in-place config update; currently only prints a message."""
        print("Updating config in an imaginary way")
@task
def train(cfg: Config) -> Config:
    # Demo task: log the incoming config, call the placeholder `update` (it only
    # prints), and return the same config object so the caller can compare it
    # with the input.
    print(f"Training with config {cfg}")
    cfg.update()
    return cfg
@workflow
def wf(cfg: Config) -> Config:
    # Single-step workflow: pass `cfg` through the `train` task and return its output.
    return train(cfg=cfg)
if __name__ == "__main__":
    # Local smoke test: run the workflow with a dict-valued `model_config` and check
    # that the config comes back unchanged (presumably exercising
    # BaseModelTransformer during flytekit's local execution).
    in_cfg = Config(model_config={"foo": ModelConfig()})
    out_cfg = wf(cfg=in_cfg)
    # Explicit check instead of a bare `assert`: it is not stripped under
    # `python -O`, and on failure it reports both values.
    if in_cfg != out_cfg:
        raise AssertionError(
            f"Config changed during the workflow round trip: {in_cfg!r} != {out_cfg!r}"
        )
@fg91
Copy link
Author

fg91 commented Mar 3, 2023

Considerations:

If one creates an inheriting class ...

# Example subclass that adds one extra field, so its schema differs from Config's.
class ChildConfig(Config):
    child_a: int = 1

... and uses this in the type hint in the task (but nowhere else) ...

@task
def train(cfg: ChildConfig) -> Config:

... the `TypeTransformerFailedError` in `BaseModelTransformer.to_python_value` is raised because the schemas don't match.

Without the schema comparison this would work, but then, inside the task, `cfg` would have the `child_a` attribute even though the caller passed a plain `Config`.

I feel that not allowing inheritance and being strict about the schema makes sense here.

IIUC, this matches the behaviour of `SimpleTransformer`, where the types have to match exactly as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment