Skip to content

Instantly share code, notes, and snippets.

@d1manson
Last active October 20, 2021 23:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d1manson/81c5982b144671783b37b71de12c7be5 to your computer and use it in GitHub Desktop.
Save d1manson/81c5982b144671783b37b71de12c7be5 to your computer and use it in GitHub Desktop.
Prefect control flow helper, similar to `case`. See https://github.com/PrefectHQ/prefect/issues/5071#issuecomment-947846404
from typing import Any, TYPE_CHECKING, Dict
import prefect
from prefect import Task, Flow
from prefect.triggers import all_successful
from prefect.tasks.control_flow.conditional import CompareValue
from prefect.engine import signals
if TYPE_CHECKING:
    # These modules are only referenced by the string annotations in
    # EndIfEqual's trigger ("core.Edge", "state.State"); the guard keeps them
    # out of the runtime import path.
    from prefect.engine import state  # noqa
    from prefect import core  # noqa

# Public API of this module: only the context manager is exported.
__all__ = ("if_equal",)
class EndIfEqual(Task):
    """Closing task of an `if_equal` block.

    It is built with ``skip_on_upstream_skip=False``, so a skipped branch
    inside the block does not automatically skip it. Its custom trigger raises
    SKIP only when the upstream edge keyed ``parent_condition`` (the enclosing
    block's condition task) finished in a skipped state. Otherwise it defers to
    ``all_successful``.
    """

    def __init__(self, name: str):
        super().__init__(
            name,
            skip_on_upstream_skip=False,
            trigger=self._skip_if_parent_skipped,
        )

    @staticmethod
    def _skip_if_parent_skipped(upstream_states: Dict["core.Edge", "state.State"]):
        """Trigger: SKIP when the parent condition was skipped, else all_successful."""
        parent_was_skipped = any(
            edge.key == 'parent_condition' and edge_state.is_skipped()
            for edge, edge_state in upstream_states.items()
        )
        if parent_was_skipped:
            raise signals.SKIP(
                "Parent if-equal block was skipped, so this block should return SKIP.")
        return all_successful(upstream_states)

    def run(self, parent_condition):
        """No-op; `parent_condition` exists only so the parent's condition becomes an upstream edge."""
class if_equal(object):
    """
    A conditional block in a flow definition.

    This is adapted from:
    https://raw.githubusercontent.com/PrefectHQ/prefect/e699fce534f77106e52dda1f0f0b23a3f8bcdf81/src/prefect/tasks/control_flow/case.py

    Used as a context-manager, `if_equal` creates a block of tasks that are only
    run if the result of `task` is equal to `value`.

    Entering the block returns an "end" task (`EndIfEqual`). On exit, these tasks
    are wired upstream of it:
      - the compared `task`;
      - every task outside the block that a task inside the block depends on;
      - every task in the block that no other task in the block depends on.

    Downstream tasks that depend on the end task therefore run whether or not
    the block was skipped, while upstream failures still propagate through it.

    Args:
        - task (Task): The task to use in the comparison
        - value (Any): A constant the result of `task` will be compared with;
            must not itself be a `Task`
        - name (str, optional): prefix used to name the end and condition tasks

    Raises:
        - TypeError: if `value` is a `Task`

    Example:
        An `if_equal` block is similar to Python's if-blocks. It delimits a block
        of tasks that will only be run if the result of `task` is equal to
        `value`:

        ```python
        a = task_a()
        x = task_x()
        with if_equal(x, '42') as conditional:
            b = task_b()
            b.set_upstream(a)
        c = task_c()
        c.set_upstream(conditional)
        ```

        In this example, task c will run after task b, whether or not b is skipped.
        And if a fails, the failure will propagate through to the end.

    The `value` argument can be any non-task object.
    See https://github.com/PrefectHQ/prefect/issues/5071#issuecomment-947846404
    """

    def __init__(self, task: Task, value: Any, name="if_equal"):
        if isinstance(value, Task):
            raise TypeError("`value` cannot be a task")
        self.task = task
        self.value = value
        self._name = name
        # Tasks registered inside this block, and the single flow they belong to.
        self._tasks = set()
        self._flow = None
        # The enclosing block, assigned in __enter__. It is initialised here so
        # that add_task() does not raise AttributeError if it runs before the
        # block is entered.
        self.__parent_case = None

    def add_task(self, task: Task, flow: Flow) -> None:
        """Add a new task under the if_equal statement.

        The task is also forwarded to the enclosing block (if any), so every
        block up the stack knows about it.

        Args:
            - task (Task): the task to add
            - flow (Flow): the flow to use

        Raises:
            - ValueError: if `flow` is not the flow of previously added tasks
        """
        if self._flow is None:
            self._flow = flow
        elif self._flow is not flow:
            raise ValueError(
                "Multiple flows cannot be used with the same if_equal statement"
            )
        self._tasks.add(task)
        # We need to let all the if_equal blocks up the stack know about this task too
        # Warning: this breaks if there are case blocks in the stack
        if self.__parent_case is not None:
            self.__parent_case.add_task(task, flow)

    def __enter__(self):
        """Open the block and return its `EndIfEqual` task.

        The end task and the comparison task are created before this block
        becomes the active ``case`` in the Prefect context. Presumably this is
        so they are not tracked as members of this block; confirm against
        Prefect's task-registration hook.
        """
        parent = prefect.context.get("case")
        # When nested, the end task takes the parent's condition as its
        # `parent_condition` input. EndIfEqual's trigger SKIPs it when that
        # condition was skipped.
        parent_condition = parent._cond if parent is not None else None
        self._end_if_equal = EndIfEqual(name=f"{self._name}:end")(
            parent_condition=parent_condition)
        self._cond = CompareValue(self.value, name=f"{self._name}:if({self.value})").bind(
            value=self.task
        )
        self.__parent_case = parent
        prefect.context.update(case=self)
        return self._end_if_equal

    def __exit__(self, *args):
        """Close the block: restore the parent context and wire up dependencies."""
        # Restore the enclosing block (if any) as the active "case".
        if self.__parent_case is None:
            prefect.context.pop("case", None)
        else:
            prefect.context.update(case=self.__parent_case)

        # This deals with the skip vs fail issue by copying upstream dependencies onto the end-if
        # See https://github.com/PrefectHQ/prefect/issues/5071
        # flow=self._flow matches the calls below; when it is None (empty
        # block), set_upstream falls back to the context flow as before.
        self._end_if_equal.set_upstream(self.task, flow=self._flow)
        for task in self._tasks:
            # Look the edges up once. Both partitions are taken before this
            # iteration adds any new edge.
            upstream = self._flow.upstream_tasks(task)
            upstream_tasks_in_context = upstream.intersection(self._tasks)
            upstream_tasks_not_in_context = upstream.difference(self._tasks)
            downstream_tasks_in_context = self._flow.downstream_tasks(
                task).intersection(self._tasks)

            for u_task in upstream_tasks_not_in_context:
                # Copy external dependencies onto the end-if, so an upstream
                # failure propagates even when this task ends up skipped.
                self._end_if_equal.set_upstream(u_task, flow=self._flow)
            if not downstream_tasks_in_context:
                # Nothing else within the context depends on this, so we connect it up to the end-if
                self._end_if_equal.set_upstream(task, flow=self._flow)
            if not upstream_tasks_in_context:
                # We need the condition to be upstream of this task, because
                # no other task within the context is upstream of it.
                task.set_upstream(self._cond, flow=self._flow)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment