-
-
Save dylanwilder/4bf86ee1563fbec12a6497f5091eabf2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Annotated, Any

import flytekit
import pandas as pd
from flytekit import (
    kwtypes,
)
from flytekit.core.interface import Interface
class Assumptions:
    """Registry of named assumption schemas and factory for their fetch tasks.

    Schemas are stored by assumption name via :meth:`register`. :meth:`get`
    invokes a task whose single output, ``assumption``, is typed with the
    schema registered under that name.
    """

    # Maps assumption name -> registered schema type.
    _ASSUMPTIONS_REGISTRY = {}
    # Maps assumption name -> task instance, so repeated ``get`` calls reuse
    # one task object instead of constructing a new one under the same name.
    _TASK_CACHE = {}

    class _AssumptionsTask(flytekit.PythonInstanceTask):
        """Task that fetches one named assumption for a given version."""

        def __init__(self, assumption):
            """Build the task for ``assumption``.

            Raises:
                KeyError: if ``assumption`` has not been registered.
            """
            super().__init__(
                name=f"onemodel.get_assumptions.{assumption.lower().replace(' ', '_')}",
                task_config=None,
                interface=Interface(
                    inputs={"version_id": str, "assumption": str},
                    outputs={"assumption": Assumptions._ASSUMPTIONS_REGISTRY[assumption]},
                ),
            )

        def execute(self, **kwargs) -> Any:
            """Log the request and return an empty DataFrame placeholder."""
            print(f"fetching assumption {kwargs['version_id']}, {kwargs['assumption']}")
            # ``pd.DataFrame.empty`` is the property object on the class, not
            # a frame; construct an actual empty DataFrame instead.
            return pd.DataFrame()

    @classmethod
    def register(cls, name, schema):
        """Store ``schema`` under ``name``, replacing any previous entry."""
        cls._ASSUMPTIONS_REGISTRY[name] = schema
        # A task cached for the previous schema would carry a stale interface.
        cls._TASK_CACHE.pop(name, None)

    @classmethod
    def get(cls, *, version_id, assumption):
        """Invoke the fetch task for ``assumption`` and return its result.

        Args:
            version_id: passed to the task as its ``version_id`` input.
            assumption: registered assumption name; also passed to the task.

        Raises:
            KeyError: if ``assumption`` has not been registered.
        """
        task = cls._TASK_CACHE.get(assumption)
        if task is None:
            task = cls._AssumptionsTask(assumption)
            cls._TASK_CACHE[assumption] = task
        return task(version_id=version_id, assumption=assumption)
class Assumption:
    """Subscript helper: ``Assumption[name, columns]`` builds and registers a schema."""

    def __class_getitem__(cls, params):
        """Create the annotated dataset type for an assumption and register it.

        Args:
            params: a 2-tuple ``(name, columns)`` as produced by the subscript
                syntax ``Assumption[name, columns]``.

        Returns:
            ``Annotated[flytekit.StructuredDataset, columns]`` with a ``name``
            attribute set to ``name``.

        Raises:
            TypeError: if ``params`` is not a tuple of exactly two items.
        """
        # Explicit check rather than ``assert``: asserts are stripped under -O,
        # and a bare 2-character string would otherwise pass a length test.
        if not isinstance(params, tuple) or len(params) != 2:
            raise TypeError(
                "Assumption[...] expects exactly two parameters: a name and a column schema"
            )
        name, columns = params
        schema = Annotated[flytekit.StructuredDataset, columns]
        # NOTE(review): typing aliases appear to forward non-dunder attribute
        # writes to the underlying origin class, so this likely sets ``name``
        # on flytekit.StructuredDataset itself and a second Assumption would
        # overwrite it for all schemas -- confirm before registering more than one.
        setattr(schema, "name", name)
        Assumptions.register(name, schema)
        return schema
# Schema for the "Base Prices" assumption: columns ``age`` (int) and ``name`` (str).
# Subscripting ``Assumption`` also registers the schema with ``Assumptions``.
BasePricesSchema = Assumption["Base Prices", kwtypes(age=int, name=str)]
@flytekit.task
def t2(pdf: BasePricesSchema) -> int:
    """Return the number of rows in the given base-prices dataset.

    The parameter is annotated as a structured dataset, so it is first
    materialised as a pandas DataFrame; ``len`` is then taken on that frame
    rather than on the dataset wrapper.
    """
    frame = pdf.open(pd.DataFrame).all()
    return len(frame)
@flytekit.workflow
def wf(version_id: str) -> int:
    """Fetch the base-prices assumption for ``version_id`` and pass it to ``t2``."""
    return t2(
        pdf=Assumptions.get(
            version_id=version_id,
            assumption=BasePricesSchema.name,
        )
    )
# Launch plan wrapping ``wf`` with a fixed name, one label and one annotation.
lp = flytekit.LaunchPlan.get_or_create(
    wf,
    annotations=flytekit.Annotations({"assumption/bases1234": "Base Prices"}),
    labels=flytekit.Labels({"label1": "a"}),
    name="onemodel.models.annotations.lp",
)
@flytekit.workflow
def run_dp() -> None:
    """Run the ``lp`` launch plan with ``version_id="abc"``.

    The workflow declares no outputs, so the launch plan's result is not
    returned; returning it would contradict the ``-> None`` interface.
    """
    lp(version_id="abc")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment