Skip to content

Instantly share code, notes, and snippets.

@dylanwilder
Created August 30, 2022 20:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dylanwilder/4bf86ee1563fbec12a6497f5091eabf2 to your computer and use it in GitHub Desktop.
Save dylanwilder/4bf86ee1563fbec12a6497f5091eabf2 to your computer and use it in GitHub Desktop.
import flytekit
from flytekit import (
kwtypes
)
from flytekit.core.interface import Interface
import pandas as pd
from typing import Annotated, Any
class Assumptions:
    """Registry of assumption schemas and factory for the tasks that fetch them.

    Schemas are added with :meth:`register` (``Assumption[...]`` does this) and
    fetched from inside a workflow with :meth:`get`.
    """

    # Assumption name -> schema type; used as the fetch task's output type.
    _ASSUMPTIONS_REGISTRY = {}
    # Assumption name -> the _AssumptionsTask built for it, so repeated get()
    # calls reuse one task object instead of constructing a new one each time.
    _TASK_CACHE = {}

    class _AssumptionsTask(flytekit.PythonInstanceTask):
        """Task that fetches a single named assumption (currently a stub)."""

        def __init__(self, assumption):
            """Build the fetch task for ``assumption``.

            Raises:
                KeyError: if ``assumption`` has not been registered.
            """
            schema = Assumptions._ASSUMPTIONS_REGISTRY[assumption]
            super().__init__(
                name=f"onemodel.get_assumptions.{assumption.lower().replace(' ', '_')}",
                task_config=None,
                interface=Interface(
                    inputs={"version_id": str, "assumption": str},
                    outputs={"assumption": schema},
                ),
            )
            # Column names declared on the schema, used to shape the stub result.
            self._assumption_columns = self._column_names(schema)

        @staticmethod
        def _column_names(schema):
            """Return the column names declared in an ``Annotated[..., columns]`` schema.

            Returns an empty list when the first metadata item is not a dict
            (``kwtypes`` produces an ordered dict of name -> type).
            """
            metadata = getattr(schema, "__metadata__", ())
            if metadata and isinstance(metadata[0], dict):
                return list(metadata[0])
            return []

        def execute(self, **kwargs) -> Any:
            """Stub fetch: log the request and return an empty DataFrame."""
            print(f"fetching assumption {kwargs['version_id']}, {kwargs['assumption']}")
            # NOTE: ``pd.DataFrame.empty`` is a property object on the class, not
            # a DataFrame. Build a real empty frame that carries the declared
            # column names so readers selecting those columns can find them.
            return pd.DataFrame(columns=self._assumption_columns)

    @classmethod
    def register(cls, name, schema):
        """Register ``schema`` under ``name``, replacing any earlier entry."""
        cls._ASSUMPTIONS_REGISTRY[name] = schema
        # A task built for a previous schema would declare a stale output type.
        cls._TASK_CACHE.pop(name, None)

    @classmethod
    def get(cls, *, version_id, assumption):
        """Invoke the fetch task for ``assumption`` at ``version_id``.

        The task for each assumption name is built once and cached.

        Raises:
            KeyError: if ``assumption`` has not been registered.
        """
        if assumption not in cls._ASSUMPTIONS_REGISTRY:
            raise KeyError(
                f"unknown assumption {assumption!r}; "
                f"registered: {sorted(cls._ASSUMPTIONS_REGISTRY)}"
            )
        task = cls._TASK_CACHE.get(assumption)
        if task is None:
            task = cls._AssumptionsTask(assumption)
            cls._TASK_CACHE[assumption] = task
        return task(version_id=version_id, assumption=assumption)
class Assumption:
    """Schema factory used as ``Assumption[name, columns]``.

    Subscripting builds ``Annotated[flytekit.StructuredDataset, columns]``,
    tags the resulting alias with a ``name`` attribute, registers it with
    ``Assumptions`` under ``name``, and returns it.
    """

    def __class_getitem__(cls, params):
        """Create, name, and register an assumption schema.

        Raises:
            TypeError: if not subscripted with exactly ``[name, columns]``.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if not isinstance(params, tuple) or len(params) != 2:
            raise TypeError("Assumption[...] expects exactly two parameters: Assumption[name, columns]")
        name, columns = params
        schema = Annotated[flytekit.StructuredDataset, columns]
        # typing aliases forward a plain setattr() of a non-dunder attribute to
        # their __origin__. That would write ``name`` onto the shared
        # flytekit.StructuredDataset class, so every schema would report the
        # last registered name. Store it in this alias's own __dict__ instead.
        # NOTE(review): relies on ``columns`` being unhashable (kwtypes returns
        # an ordered dict), so typing does not return a cached alias shared
        # with another schema.
        object.__setattr__(schema, "name", name)
        Assumptions.register(name, schema)
        return schema
# Schema for the "Base Prices" assumption: a structured dataset with an int
# ``age`` column and a str ``name`` column. Subscripting ``Assumption`` also
# registers the schema with ``Assumptions``.
BasePricesSchema = Assumption["Base Prices", kwtypes(age=int, name=str)]
@flytekit.task
def t2(pdf: BasePricesSchema) -> int:
    """Return the number of rows in the base-prices dataset.

    ``pdf`` is declared as ``Annotated[flytekit.StructuredDataset, ...]``, so it
    arrives as a StructuredDataset handle, which has no ``len()``. Open it as a
    pandas DataFrame before counting rows.
    """
    frame = pdf.open(pd.DataFrame).all()
    return len(frame)
@flytekit.workflow
def wf(version_id: str) -> int:
    """Fetch the base-prices assumption for ``version_id`` and count its rows."""
    return t2(
        pdf=Assumptions.get(
            version_id=version_id,
            assumption=BasePricesSchema.name,
        )
    )
# Launch plan for ``wf`` carrying a fixed label and an annotation that names
# the base-prices assumption.
lp = flytekit.LaunchPlan.get_or_create(
    wf,
    name="onemodel.models.annotations.lp",
    labels=flytekit.Labels({"label1": "a"}),
    # Take the value from the schema rather than repeating the literal, so it
    # cannot drift from the name the schema was registered under.
    annotations=flytekit.Annotations({"assumption/bases1234": BasePricesSchema.name}),
)
@flytekit.workflow
def run_dp() -> None:
    """Run the ``lp`` launch plan for version ``"abc"``.

    The workflow declares no outputs, so the launch plan's (int) promise is
    deliberately not returned. flytekit expects ``None`` or a void promise
    from a workflow whose interface has no outputs.
    """
    lp(version_id="abc")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment