Skip to content

Instantly share code, notes, and snippets.

Created March 3, 2023 19:13
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
What would you like to do?
Pydantic base model type transformer
from typing import Optional, Type, Union

from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from pydantic import BaseModel

from flytekit import FlyteContext
from flytekit import task, workflow
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar
from flytekit.models.types import LiteralType, SimpleType
# ---
class BaseModelTransformer(TypeTransformer[BaseModel]):
    """Flytekit type transformer for pydantic ``BaseModel`` objects.

    A model is serialized into a protobuf ``Struct`` literal that carries both
    the model's JSON schema and its field data, so that ``to_python_value`` can
    verify the receiving task expects the exact same model type.
    """

    _TYPE_INFO = LiteralType(simple=SimpleType.STRUCT)

    def __init__(self):
        """Construct BaseModelTransformer."""
        super().__init__(name="basemodel-transform", t=BaseModel)

    def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
        """Every pydantic model maps to a generic struct literal."""
        return LiteralType(simple=SimpleType.STRUCT)

    def to_literal(
        # BUG FIX: `self` was missing from the original signature.
        self,
        ctx: FlyteContext,
        python_val: BaseModel,
        python_type: Type[BaseModel],
        expected: LiteralType,
    ) -> Literal:
        """Convert a ``BaseModel`` instance into its Flyte ``Literal`` representation."""
        s = Struct()
        # Store the schema alongside the data so the deserializing side can
        # reject a mismatching (e.g. inheriting) model type.
        s.update({"schema": python_val.schema(), "data": python_val.dict()})
        return Literal(scalar=Scalar(generic=s))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]
    ) -> BaseModel:
        """Re-hydrate the pydantic model from the Flyte ``Literal`` value.

        Raises:
            TypeTransformerFailedError: if the stored schema does not match the
                schema of ``expected_python_type``. Inheritance is deliberately
                rejected — same strictness as Flyte's SimpleTransformer.
        """
        base_model = MessageToDict(lv.scalar.generic)
        schema = base_model["schema"]
        data = base_model["data"]
        if (expected_schema := expected_python_type.schema()) != schema:
            # BUG FIX: the original raise statement was missing its closing
            # parenthesis (SyntaxError) and misspelled "expected".
            raise TypeTransformerFailedError(
                f"The schema `{expected_schema}` of the expected python type "
                f"{expected_python_type} is not equal to the received schema `{schema}`."
            )
        # BUG FIX: `parse_obj` is a classmethod; the original called
        # `expected_python_type().parse_obj(data)`, which needlessly instantiates
        # the model first and fails for models with required fields.
        return expected_python_type.parse_obj(data)


# Register the transformer so Flyte's type engine actually uses it for
# BaseModel subclasses (`TypeEngine` was imported but never used otherwise).
TypeEngine.register(BaseModelTransformer())
# ---
class ModelConfig(BaseModel):
    """Example nested pydantic model; every field has a default."""

    a: int = 1  # example integer field
    b: float = 2.0  # example float field
    c: str = "foo"  # example string field
class Config(BaseModel):
    """Top-level config used to exercise the custom BaseModel transformer."""

    # Union of a single ModelConfig or a name -> ModelConfig mapping — this is
    # the shape that makes trouble for Flyte's dataclass transformer.
    # NOTE(review): "model_config" is a reserved attribute in pydantic v2; this
    # code is written against pydantic v1 (parse_obj/schema) — confirm pinning.
    model_config: Optional[Union[dict[str, ModelConfig], ModelConfig]] = ModelConfig()

    def update(self):
        """Placeholder mutation hook for the example."""
        print("Updating config in an imaginary way")
@task
def train(cfg: Config) -> Config:
    """Toy task: log the received config and return it unchanged.

    NOTE(review): the `@task` decorator was restored — `task` is imported at the
    top of the file but was never used, and without it the transformer under
    test would never run on this function's inputs/outputs.
    """
    print(f"Training with config {cfg}")
    return cfg
@workflow
def wf(cfg: Config) -> Config:
    """Workflow that round-trips the config through the `train` task.

    NOTE(review): the `@workflow` decorator was restored — `workflow` is
    imported at the top of the file but was never used otherwise.
    """
    return train(cfg=cfg)
if __name__ == "__main__":
    # Round-trip a config through the workflow and verify it survives the
    # serialize/deserialize cycle unchanged.
    original = Config(model_config={"foo": ModelConfig()})
    round_tripped = wf(cfg=original)
    assert round_tripped == original
Copy link

fg91 commented Mar 3, 2023


If one creates an inheriting class ...

class ChildConfig(Config):
    child_a: int = 1

... and uses this in the type hint in the task (but nowhere else) ...

def train(cfg: ChildConfig) -> Config:

... the `TypeTransformerFailedError` in `to_python_value` is raised because the schemas don't match.

It would work if we didn't compare the schemas; in that case, within the task, `cfg` would still have the `child_a` attribute.

I feel that not allowing inheritance and being strict about the schema makes sense here.

IIUC this would be the same behaviour as in the SimpleTransformer, where the types have to be exact as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment