-
-
Save dylanwilder/eb276fe5232acca999d6026e00efbe43 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import flytekit | |
from flytekit import ( | |
StructuredDataset, | |
PythonInstanceTask, | |
kwtypes, | |
) | |
from flytekit.core.interface import Interface | |
from flytekit.core.base_task import TaskResolverMixin, Task | |
from flytekit.core.context_manager import SerializationSettings, FlyteContextManager | |
import pandas as pd | |
from typing import Any, List, Optional | |
from typing_extensions import Annotated | |
class SchemaRegistry:
    """Two-way, process-wide registry between assumption names and schema types.

    Schemas are looked up by identity (``id()``), not by equality, so the exact
    object passed to ``register`` is the one ``get_name`` recognises.
    """

    # assumption name -> schema type
    _NAME2SCHEMA = dict()
    # id(schema type) -> assumption name; each schema is kept alive by
    # _NAME2SCHEMA, so its id stays valid for the life of the process
    _SCHEMA2NAME = dict()

    @classmethod
    def register(cls, name, schema):
        """Register ``schema`` under ``name``.

        Both lookups are validated before anything is written, so a rejected
        registration leaves the registry unchanged.

        Raises:
            AssertionError: if ``name`` is already registered, or if this exact
                schema object is already registered under a different name
                (otherwise ``get_name`` would silently start returning the
                newer name for both).
        """
        if cls._NAME2SCHEMA.get(name) is not None:
            raise AssertionError(f"Assumption definition for {name} already exists!")
        existing = cls._SCHEMA2NAME.get(id(schema))
        if existing is not None and existing != name:
            raise AssertionError(
                f"Schema object is already registered as assumption {existing!r}; "
                f"cannot also register it as {name!r}"
            )
        cls._NAME2SCHEMA[name] = schema
        cls._SCHEMA2NAME[id(schema)] = name

    @classmethod
    def get_schema(cls, name):
        """Return the schema registered under ``name``.

        Raises:
            KeyError: if no schema is registered under ``name``.
        """
        try:
            return cls._NAME2SCHEMA[name]
        except KeyError:
            raise KeyError(f"No assumption registered under name {name!r}") from None

    @classmethod
    def get_name(cls, schema):
        """Return the name ``schema`` (this exact object) was registered under.

        Raises:
            KeyError: if ``schema`` was never registered.
        """
        try:
            return cls._SCHEMA2NAME[id(schema)]
        except KeyError:
            raise KeyError(f"Schema {schema!r} is not a registered assumption") from None
class Assumption:
    """Factory for named assumption schema types.

    ``Assumption["Name", columns]`` builds ``Annotated[StructuredDataset, columns]``,
    registers it in ``SchemaRegistry`` under ``"Name"``, and returns it. The
    result can therefore be used directly as a type annotation.
    """

    def __class_getitem__(cls, params):
        """Build, register and return the schema type for ``(name, columns)``.

        Raises:
            TypeError: if not subscripted with exactly two items.
            Whatever ``SchemaRegistry.register`` raises (e.g. on a duplicate name).
        """
        # Explicit check rather than ``assert``: asserts vanish under ``-O``. A plain
        # ``len(params) == 2`` would also accept a 2-character string such as
        # ``Assumption["ab"]`` and register the name "a".
        if not isinstance(params, tuple) or len(params) != 2:
            raise TypeError(
                f"Assumption[...] expects exactly two parameters (name, columns), got {params!r}"
            )
        name, columns = params
        schema = Annotated[StructuredDataset, columns]
        SchemaRegistry.register(name, schema)
        return schema
class AssumptionTaskResolver(TaskResolverMixin):
    """Task resolver that identifies assumption tasks by assumption name alone.

    ``loader_args`` serialises a task as ``[assumption_name]``. ``load_task``
    rebuilds (or fetches from cache) the task from that name via
    ``AssumptionTask.new``.
    """

    @property
    def location(self) -> str:
        # Dotted path to the module-level ``assumptions_resolver`` instance; this
        # string must stay in sync with that variable's name.
        return f"{self.__module__}.assumptions_resolver"

    @property
    def name(self) -> str:
        return "AssumptionTaskResolver"

    def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
        """Return the arguments ``load_task`` needs to rebuild ``t``: its assumption name."""
        return [t.assumption_name]

    def load_task(self, loader_args: List[str]) -> Task:
        """Rebuild the task identified by ``loader_args`` (as produced by ``loader_args``)."""
        assumption_name = loader_args[0]
        return AssumptionTask.new(assumption_name)

    def get_all_tasks(self) -> List[Task]:
        """Return every assumption task created so far in this process."""
        return list(AssumptionTask._TASK_CACHE.values())

    def task_name(self, t: Task) -> Optional[str]:
        """Keep the task's own name rather than letting the resolver rename it."""
        return t.name
# Module-level resolver instance. AssumptionTaskResolver.location builds a dotted
# path ending in this exact name, so renaming it would make that path dangle.
assumptions_resolver: AssumptionTaskResolver = AssumptionTaskResolver()
class AssumptionTask:
    """Namespace for the per-assumption fetch task and its process-wide cache."""

    # assumption name -> _Task; guarantees one task object per assumption per process
    _TASK_CACHE = dict()

    class _Task(PythonInstanceTask):
        """Task that fetches a single named assumption as a StructuredDataset.

        Inputs ``version_id`` and ``assumption`` are strings. The single output
        ``assumption`` is typed with the schema registered for the assumption name.
        """

        def __init__(self, assumption_name: str):
            # Assigned before super().__init__, presumably so the resolver
            # (loader_args reads ``t.assumption_name``) can see it during construction.
            self.assumption_name = assumption_name
            super().__init__(
                name=f"onemodel.get_assumptions.{assumption_name.lower().replace(' ', '_')}",
                task_config=None,
                task_resolver=assumptions_resolver,
                interface=Interface(
                    inputs={"version_id": str, "assumption": str},
                    outputs={"assumption": SchemaRegistry.get_schema(assumption_name)},
                ),
            )

        def execute(self, **kwargs) -> Any:
            """Fetch the assumption. Currently a stub that returns an empty dataset."""
            # TODO Fetch and return the assumption
            print(f"fetching assumption {kwargs['version_id']}, {kwargs['assumption']}")
            # Note: ``pd.DataFrame.empty`` is the class's ``empty`` *property object*,
            # not a DataFrame, so a real empty frame is built instead.
            return StructuredDataset(dataframe=self._empty_frame())

        def _empty_frame(self) -> pd.DataFrame:
            """Return an empty DataFrame carrying the registered schema's column names, if any."""
            schema = SchemaRegistry.get_schema(self.assumption_name)
            # Annotated[...] exposes its extra arguments via __metadata__. For schemas
            # built by Assumption[...] the first entry is the column mapping (kwtypes).
            metadata = getattr(schema, "__metadata__", ())
            columns = metadata[0] if metadata else None
            if isinstance(columns, dict):
                # Keeping the declared column names presumably lets consumers that
                # select those columns read the empty result — verify end to end.
                return pd.DataFrame(columns=list(columns))
            return pd.DataFrame()

    @classmethod
    def new(cls, assumption_name: str) -> _Task:
        """Return the cached task for ``assumption_name``, creating it on first use.

        Raises:
            KeyError: if no schema is registered under ``assumption_name``
                (from ``SchemaRegistry.get_schema`` during construction).
        """
        task = cls._TASK_CACHE.get(assumption_name)
        if task is not None:
            return task
        ctx = FlyteContextManager.current_context()
        if ctx.compilation_state is None:
            # Not compiling a workflow (e.g. load_task on a cache miss at execution
            # time): there is no compilation state to derive from.
            task = cls._Task(assumption_name)
        else:
            # While compiling, push a compilation state whose resolver is ours,
            # presumably so it wins over any resolver already on the state.
            state = ctx.compilation_state.with_params("", resolver=assumptions_resolver)
            with FlyteContextManager.with_context(ctx.with_compilation_state(state)):
                task = cls._Task(assumption_name)
        cls._TASK_CACHE[assumption_name] = task
        return task
class Assumptions:
    """Workflow-side entry point for fetching assumptions."""

    @classmethod
    def get(cls, *, version_id: str, assumption: Any):
        """Invoke the fetch task for ``assumption`` and return its output.

        Args:
            version_id: forwarded to the task's ``version_id`` input.
            assumption: a schema type returned by ``Assumption[...]``, not an
                ``Assumption`` instance. It is mapped back to its registered name
                by identity.

        Raises:
            KeyError: if ``assumption`` was not produced by ``Assumption[...]``.
        """
        name = SchemaRegistry.get_name(assumption)
        task = AssumptionTask.new(name)
        return task(version_id=version_id, assumption=name)
# Schema type for the "Base Prices" assumption: columns ``age`` (int) and
# ``name`` (str). Subscripting Assumption also registers it under that name.
BasePricesSchema = Assumption["Base Prices", kwtypes(age=int, name=str)]
@flytekit.task
def t2(pdf: BasePricesSchema) -> int:
    """Return the number of rows in the base-prices dataset.

    ``BasePricesSchema`` annotates a ``StructuredDataset``, so ``pdf`` is a
    dataset handle with no ``len()``. It is materialised into a pandas
    DataFrame before counting.
    """
    frame = pdf.open(pd.DataFrame).all()
    return len(frame)
@flytekit.workflow
def wf(version_id: str) -> int:
    """Fetch the "Base Prices" assumption for ``version_id`` and count its rows."""
    return t2(pdf=Assumptions.get(version_id=version_id, assumption=BasePricesSchema))
# Launch plan for ``wf`` with label ``label1=a`` and annotation
# ``assumption/bases1234 = "Base Prices"``.
lp = flytekit.LaunchPlan.get_or_create(
    wf,
    name="onemodel.models.annotations.lp",
    labels=flytekit.Labels({"label1": "a"}),
    annotations=flytekit.Annotations({"assumption/bases1234": "Base Prices"}),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment