Last active
November 18, 2021 09:16
-
-
Save anna-anisienia/02b3844a2b28fd0911c61b0b84b8c5a6 to your computer and use it in GitHub Desktop.
for Medium article
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from airflow.models import DagModel, TaskInstance, DagRun | |
from airflow.utils.decorators import apply_defaults | |
from airflow.operators.dagrun_operator import TriggerDagRunOperator | |
from airflow.sensors.base_sensor_operator import BaseSensorOperator | |
from airflow.utils.db import provide_session | |
class WaitForCompletion(BaseSensorOperator):
    """
    Waits for a task in the most recent triggered run of a different DAG
    to reach the ``success`` state.

    Only DAG runs whose ``run_id`` starts with ``trig__`` are considered;
    among the matching task instances the one with the latest execution
    date is checked.

    :param external_dag_id: The dag_id that contains the task you want to wait for
    :type external_dag_id: str
    :param external_task_id: The task_id of the task you want to wait for.
        NOTE(review): the query filters on ``task_id == external_task_id``,
        so leaving this as ``None`` is not expected to match any task
        instance -- pass an explicit task_id.
    :type external_task_id: str or None
    """
    template_fields = ['external_dag_id', 'external_task_id']
    ui_color = '#4287f5'

    @apply_defaults
    def __init__(self, external_dag_id, external_task_id=None, *args, **kwargs):
        super(WaitForCompletion, self).__init__(*args, **kwargs)
        self.external_dag_id = external_dag_id
        self.external_task_id = external_task_id

    @provide_session
    def poke(self, context=None, session=None):
        """
        Check whether the awaited task has succeeded.

        :param context: Airflow task context. The sensor base class passes it
            positionally to ``poke``; it is accepted here so that it is not
            mistaken for the database session. It is not otherwise used.
        :param session: database session; supplied by ``provide_session``
            when the caller does not pass one.
        :return: True if the latest matching task instance is in state
            ``success``, otherwise False (including when no row matches).
        :rtype: bool
        """
        ti = TaskInstance
        dr = DagRun
        self.log.info('Poking for %s.%s ... ',
                      self.external_dag_id,
                      self.external_task_id)
        # Use `&` rather than Python's `and`: `and` coerces the first
        # comparison to a bool and would return only one of the two
        # conditions, silently dropping the execution_date match.
        same_run = ((dr.dag_id == ti.dag_id) &
                    (dr.execution_date == ti.execution_date))
        state_of_triggered_dag = (session.query(ti.state)
                                  .join(dr, same_run)
                                  .filter(ti.dag_id == self.external_dag_id,
                                          ti.task_id == self.external_task_id,
                                          dr.run_id.startswith('trig__'))
                                  .order_by(ti.execution_date.desc())
                                  .limit(1)
                                  .scalar())
        session.commit()
        return state_of_triggered_dag == 'success'
def trigger_dag(task_id, trigger_dag_id, wait_for_task='finish',
                poke_interval=1, timeout=1800,
                trigger_rule="all_success", on_failure_callback=None):
    """
    Build the pair of tasks used to trigger a child DAG from a parent DAG
    and wait for it to finish.

    The two tasks are only created and returned; this function does not set
    any dependency between them.

    :param task_id: task_id of the triggering task; ideally the same as the
        name of the child DAG. The waiting task is named ``wait_for_<task_id>``.
    :param trigger_dag_id: exact ID (name) of the DAG to trigger
    :param wait_for_task: task_id inside the child DAG to wait for; defaults
        to ``finish`` - change it if the child DAG ends with a different task
    :param poke_interval: passed to the sensor; defaults to 1 (poke every second)
    :param timeout: passed to the sensor; defaults to 1800 (half an hour,
        on the assumption that no task runs longer) - adjust if needed
    :param trigger_rule: trigger rule of the triggering task; defaults to
        ``all_success``, change it to ``all_done`` to trigger regardless of
        upstream success
    :param on_failure_callback: function to call if the triggering task fails
    :return: tuple ``(dag_to_trigger, wait_task)``
    """
    trigger_settings = {
        'task_id': task_id,
        'trigger_dag_id': trigger_dag_id,
        'trigger_rule': trigger_rule,
        'on_failure_callback': on_failure_callback,
    }
    sensor_settings = {
        'task_id': f'wait_for_{task_id}',
        'external_dag_id': trigger_dag_id,
        'external_task_id': wait_for_task,
        'poke_interval': poke_interval,
        'timeout': timeout,
    }
    # Create the trigger before the sensor, same order as tasks are declared.
    trigger_op = TriggerDagRunOperator(**trigger_settings)
    completion_sensor = WaitForCompletion(**sensor_settings)
    return trigger_op, completion_sensor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment