Skip to content

Instantly share code, notes, and snippets.

@dylanwilder
Last active August 31, 2022 16:22
Show Gist options
  • Save dylanwilder/eb276fe5232acca999d6026e00efbe43 to your computer and use it in GitHub Desktop.
Save dylanwilder/eb276fe5232acca999d6026e00efbe43 to your computer and use it in GitHub Desktop.
import flytekit
from flytekit import (
StructuredDataset,
PythonInstanceTask,
kwtypes,
)
from flytekit.core.interface import Interface
from flytekit.core.base_task import TaskResolverMixin, Task
from flytekit.core.context_manager import SerializationSettings, FlyteContextManager
import pandas as pd
from typing import Any, List, Optional
from typing_extensions import Annotated
class SchemaRegistry:
    """Class-level, two-way registry between assumption names and schemas.

    The reverse mapping is keyed by object identity (``id``), so
    :meth:`get_name` only succeeds for the exact object that was passed to
    :meth:`register`; an equal-but-distinct object is not recognised.
    """

    # name -> schema object, exactly as passed to register().
    _NAME2SCHEMA = dict()
    # id(schema) -> name. The ids stay unique because _NAME2SCHEMA keeps every
    # registered schema alive for as long as it is registered.
    _SCHEMA2NAME = dict()

    @classmethod
    def register(cls, name, schema):
        """Register ``schema`` under ``name``.

        Raises:
            AssertionError: if ``name`` has already been registered.
        """
        # Membership test rather than ``.get(name) is not None`` so that a
        # name whose stored schema happens to be None still counts as taken.
        if name in cls._NAME2SCHEMA:
            raise AssertionError(f"Assumption definition for {name} already exists!")
        cls._NAME2SCHEMA[name] = schema
        cls._SCHEMA2NAME[id(schema)] = name

    @classmethod
    def get_schema(cls, name):
        """Return the schema registered under ``name``.

        Raises:
            KeyError: if no schema was registered under ``name``.
        """
        try:
            return cls._NAME2SCHEMA[name]
        except KeyError:
            raise KeyError(f"No assumption schema registered under name {name!r}") from None

    @classmethod
    def get_name(cls, schema):
        """Return the name that the exact object ``schema`` was registered under.

        Raises:
            KeyError: if this object was never passed to :meth:`register`.
        """
        try:
            return cls._SCHEMA2NAME[id(schema)]
        except KeyError:
            # Report the schema itself; the raw id() integer is meaningless to a reader.
            raise KeyError(f"Schema {schema!r} is not a registered assumption schema") from None
class Assumption:
    """Subscript-only factory for named assumption schemas.

    ``Assumption[name, columns]`` builds ``Annotated[StructuredDataset, columns]``,
    registers that type in ``SchemaRegistry`` under ``name`` and returns it.
    """

    def __class_getitem__(cls, params):
        """Create, register and return the annotated schema type.

        Args:
            params: the ``(name, columns)`` pair given between the brackets.

        Returns:
            ``Annotated[StructuredDataset, columns]``.

        Raises:
            AssertionError: if ``params`` is not a 2-tuple, or if ``name`` is
                already registered (raised by ``SchemaRegistry.register``).
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        # The tuple test also rejects a lone two-character string subscript,
        # which would otherwise satisfy ``len(params) == 2``.
        if not isinstance(params, tuple) or len(params) != 2:
            raise AssertionError(
                f"Assumption[...] expects exactly (name, columns), got {params!r}"
            )
        name, columns = params
        schema = Annotated[StructuredDataset, columns]
        SchemaRegistry.register(name, schema)
        return schema
class AssumptionTaskResolver(TaskResolverMixin):
    """Task resolver that identifies an assumption task by its assumption name."""

    @property
    def location(self) -> str:
        """Dotted path ``<module of this class>.assumptions_resolver``."""
        module_path = self.__module__
        return f"{module_path}.assumptions_resolver"

    @property
    def name(self) -> str:
        """Fixed display name of this resolver."""
        return "AssumptionTaskResolver"

    def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
        """Return the one-element argument list holding the task's assumption name."""
        # ``settings`` is not consulted; the assumption name is the only argument.
        return [t.assumption_name]

    def load_task(self, loader_args: List[str]) -> Task:
        """Return the task for the assumption name stored in ``loader_args[0]``."""
        return AssumptionTask.new(loader_args[0])

    def task_name(self, t: Task) -> Optional[str]:
        """Return the task's own ``name`` attribute."""
        return t.name
# Module-level resolver instance. Its variable name must stay
# ``assumptions_resolver`` because AssumptionTaskResolver.location hard-codes
# that attribute name in the dotted path it returns.
assumptions_resolver = AssumptionTaskResolver()
class AssumptionTask:
    """Factory and per-name cache for tasks that fetch a named assumption."""

    # assumption name -> task instance, so each name is only instantiated once.
    _TASK_CACHE = dict()

    class _Task(PythonInstanceTask):
        """Task whose single output is typed with the schema registered for a name."""

        def __init__(self, assumption_name: str):
            """Build the task for ``assumption_name``.

            The task name is derived from the lower-cased assumption name with
            spaces replaced by underscores; the output type is looked up in
            ``SchemaRegistry``.
            """
            self.assumption_name = assumption_name
            super().__init__(
                name=f"onemodel.get_assumptions.{assumption_name.lower().replace(' ', '_')}",
                task_config=None,
                task_resolver=assumptions_resolver,
                interface=Interface(
                    inputs={"version_id": str, "assumption": str},
                    outputs={"assumption": SchemaRegistry.get_schema(assumption_name)},
                ),
            )

        def execute(self, **kwargs) -> Any:
            """Log the request and return a placeholder empty dataset.

            Expects ``version_id`` and ``assumption`` in ``kwargs``.
            """
            # TODO Fetch and return the assumption
            print(f"fetching assumption {kwargs['version_id']}, {kwargs['assumption']}")
            # ``pd.DataFrame.empty`` is the class-level property object, not a
            # frame; wrap an actual empty DataFrame instance instead.
            return StructuredDataset(pd.DataFrame())

    @classmethod
    def new(cls, assumption_name: str) -> _Task:
        """Return the cached task for ``assumption_name``, creating it on first use."""
        cached = cls._TASK_CACHE.get(assumption_name)
        if cached is not None:
            return cached

        ctx = FlyteContextManager.current_context()
        # Instantiate the task under a compilation state that has an empty
        # prefix and ``assumptions_resolver`` as its resolver.
        # NOTE(review): assumes ctx.compilation_state is set (i.e. we are inside
        # workflow compilation) — confirm for calls made via load_task.
        state = ctx.compilation_state.with_params("", resolver=assumptions_resolver)
        with FlyteContextManager.with_context(ctx.with_compilation_state(state)):
            task = cls._Task(assumption_name)
        cls._TASK_CACHE[assumption_name] = task
        return task
class Assumptions:
    """Entry point for requesting a registered assumption."""

    @classmethod
    def get(cls, *, version_id: str, assumption: Assumption):
        """Invoke the fetch task for ``assumption`` and return its result.

        Args:
            version_id: forwarded to the task as its ``version_id`` input.
            assumption: a schema object previously returned by ``Assumption[...]``;
                it is resolved back to its registered name.
        """
        assumption_name = SchemaRegistry.get_name(assumption)
        fetch_task = AssumptionTask.new(assumption_name)
        task_inputs = {"version_id": version_id, "assumption": assumption_name}
        return fetch_task(**task_inputs)
# Schema type for the "Base Prices" assumption with columns age:int, name:str.
# Subscripting Assumption also registers the type in SchemaRegistry under
# "Base Prices" as an import-time side effect.
BasePricesSchema = Assumption[
"Base Prices",
kwtypes(age=int, name=str)
]
@flytekit.task
def t2(pdf: BasePricesSchema) -> int:
    """Return the number of rows in the base-prices dataset."""
    # ``pdf`` is annotated as a StructuredDataset wrapper, not a DataFrame, so
    # it is materialised as a pandas DataFrame before its length is taken.
    return len(pdf.open(pd.DataFrame).all())
@flytekit.workflow
def wf(version_id: str) -> int:
    """Fetch the base-prices assumption for ``version_id`` and pass it to ``t2``."""
    return t2(
        pdf=Assumptions.get(version_id=version_id, assumption=BasePricesSchema),
    )
# Launch plan for ``wf`` under a fixed name, carrying one label and one
# annotation. Created (or fetched) at import time.
lp = flytekit.LaunchPlan.get_or_create(
wf,
name="onemodel.models.annotations.lp",
labels=flytekit.Labels({"label1": "a"}),
annotations=flytekit.Annotations({"assumption/bases1234": "Base Prices"})
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment