Skip to content

Instantly share code, notes, and snippets.

@anna-anisienia
Last active November 18, 2021 09:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save anna-anisienia/02b3844a2b28fd0911c61b0b84b8c5a6 to your computer and use it in GitHub Desktop.
Save anna-anisienia/02b3844a2b28fd0911c61b0b84b8c5a6 to your computer and use it in GitHub Desktop.
for Medium article
from airflow.models import DagModel, TaskInstance, DagRun
from airflow.utils.decorators import apply_defaults
from airflow.operators.dagrun_operator import TriggerDagRunOperator
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.db import provide_session
class WaitForCompletion(BaseSensorOperator):
    """
    Waits for a task in a different DAG (or a whole run of that DAG) to succeed.

    Only DAG runs whose ``run_id`` starts with ``'trig__'`` are considered, and
    of those only the most recent one (by execution date) is inspected on each
    poke. The poke returns True only when that state is ``'success'``; any other
    state (including failure) keeps the sensor poking until ``timeout``.

    :param external_dag_id: The dag_id that contains the task you want to wait for
    :type external_dag_id: str
    :param external_task_id: The task_id you want to wait for. When None, the
        state of the most recent matching DAG run itself is checked instead of
        a single task's state.
    :type external_task_id: str or None
    """
    template_fields = ['external_dag_id', 'external_task_id']
    ui_color = '#4287f5'

    @apply_defaults
    def __init__(self, external_dag_id, external_task_id=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.external_dag_id = external_dag_id
        self.external_task_id = external_task_id

    @provide_session
    def poke(self, session=None):
        """Return True once the latest ``trig__`` run of the target succeeded."""
        self.log.info('Poking for %s.%s ... ',
                      self.external_dag_id,
                      self.external_task_id)
        if self.external_task_id is None:
            state = self._latest_triggered_run_state(session)
        else:
            state = self._latest_triggered_task_state(session)
        session.commit()
        return state == 'success'

    def _latest_triggered_task_state(self, session):
        """State of the watched task in the newest ``trig__`` run, or None."""
        ti = TaskInstance
        dr = DagRun
        # The join condition must be built with ``&`` (SQLAlchemy's and_).
        # Python's ``and`` evaluates the first comparison's truthiness and
        # silently drops the second one, which turned this into a join on
        # dag_id alone.
        join_condition = ((dr.dag_id == ti.dag_id) &
                          (dr.execution_date == ti.execution_date))
        return (session.query(ti.state)
                .join(dr, join_condition)
                .filter(ti.dag_id == self.external_dag_id,
                        ti.task_id == self.external_task_id,
                        dr.run_id.startswith('trig__'))
                .order_by(ti.execution_date.desc())
                .limit(1)
                .scalar())

    def _latest_triggered_run_state(self, session):
        """State of the newest ``trig__`` run of the external DAG, or None."""
        dr = DagRun
        latest_run = (session.query(dr)
                      .filter(dr.dag_id == self.external_dag_id,
                              dr.run_id.startswith('trig__'))
                      .order_by(dr.execution_date.desc())
                      .first())
        return latest_run.state if latest_run is not None else None
def trigger_dag(task_id, trigger_dag_id, wait_for_task='finish',
                poke_interval=1, timeout=1800,
                trigger_rule="all_success", on_failure_callback=None):
    """
    Build a (trigger, wait) task pair that runs a child DAG from a parent DAG.

    The two tasks are only created here; they are not linked to each other,
    so the caller should set ``dag_to_trigger >> wait_task`` itself.

    :param task_id: ideally name it the same as the child DAG you want to
        trigger; the sensor's task_id is ``wait_for_<task_id>``
    :param trigger_dag_id: exact ID (name) of the DAG you want to trigger
    :param wait_for_task: task in the child DAG to wait for; defaults to
        ``finish`` - change it if a different task is the last one in the
        child DAG
    :param poke_interval: seconds between sensor pokes (default: every second)
    :param timeout: seconds before the sensor gives up (default: half an hour,
        assuming no child run takes longer - adjust if needed)
    :param trigger_rule: trigger rule for the trigger task only; default
        ``all_success``, change it to ``all_done`` to trigger the child DAG
        regardless of whether upstream tasks succeeded
    :param on_failure_callback: function to call when the trigger task fails
        (it is not passed to the wait task)
    :return: tuple ``(dag_to_trigger, wait_task)``
    """
    trigger_task = TriggerDagRunOperator(
        task_id=task_id,
        trigger_dag_id=trigger_dag_id,
        trigger_rule=trigger_rule,
        on_failure_callback=on_failure_callback,
    )
    sensor_task = WaitForCompletion(
        task_id=f'wait_for_{task_id}',
        external_dag_id=trigger_dag_id,
        external_task_id=wait_for_task,
        poke_interval=poke_interval,
        timeout=timeout,
    )
    return trigger_task, sensor_task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment