Driver DAG Pattern

This pattern is useful when you want to trigger many copies of the same DAG with different parameters.

Driver DAG

This DAG runs on a schedule, fetches the configurations it needs, and then triggers the external DAG once per configuration. A final, optional "wait" step blocks at the end of the submissions to ensure that concurrency is kept at "1" for the parameterized DAG. The pattern uses the dynamic task mapping paradigm to determine how many external DAG runs should be triggered during execution. As a general note, the generate_list task could be replaced with any operator, including ones that use external service calls to determine whether a specific configuration should run in a given execution cycle; a sketch of that idea follows.
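
For example, a minimal sketch of a service-backed replacement for the generate_list callable. The endpoint URL and the response shape here are illustrative assumptions, not part of the gist:

import requests  # assumed to be available in the Airflow environment


def _generate_list_from_service():
    """Hypothetical variant of _generate_list that asks an external
    service which configurations should run this cycle."""
    # Illustrative endpoint; replace with a real service.
    response = requests.get("https://config-service.example/eligible", timeout=30)
    response.raise_for_status()
    # Expected shape: [{"task_id": "...", "payload": {...}}, ...]
    return response.json()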

TriggerDagRun Shim

A thin wrapper around the TriggerDagRunOperator that allows us to take an iterable of dicts and split each dict out into the multiple arguments needed by the base TriggerDagRunOperator.
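
Concretely (values taken from the example list in the driver DAG below), each mapped copy of the shim receives one dict and splits it like this:

extra_args = {"task_id": "task_1", "payload": {"arg1": 1, "arg2": 2}}
# inside the shim this becomes:
#   kwargs["trigger_run_id"] = "task_1-<current timestamp>"
#   kwargs["conf"]           = {"arg1": 1, "arg2": 2}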

Parameterized DAG

This is the DAG that actually does the work. It takes a config dictionary and uses the arguments in that dictionary to change the behavior of the operators it calls. It should not be scheduled, but it does need to be activated (unpaused) in order to function.
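
For a one-off test without the driver, the same DAG can also be triggered by hand with a conf payload via the Airflow 2.x CLI:

airflow dags trigger poc-driven-dag --conf '{"arg1": 1, "arg2": 5}'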

# poc-driver-dag: runs on a schedule and fans out triggers to poc-driven-dag
from collections import defaultdict
import datetime
import time

from airflow import DAG, AirflowException, XComArg
from airflow.models.dagrun import DagRun
from airflow.operators.python import PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

default_args = {
    "owner": "",
    "email": "",
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 0,
}


class TriggerDagRunOperatorShim(TriggerDagRunOperator):
    """Thin wrapper that unpacks one extra_args dict into the trigger_run_id
    and conf arguments expected by TriggerDagRunOperator."""

    def __init__(self, *args, extra_args=None, **kwargs) -> None:
        if not extra_args:
            raise AirflowException(
                "the extra_args keyword argument is required to use this shim."
            )
        kwargs["trigger_run_id"] = (
            f"{extra_args.get('task_id')}-{datetime.datetime.now()}"
        )
        kwargs["conf"] = extra_args.get("payload")
        super().__init__(*args, **kwargs)


def _generate_list():
    """Generate one extra_args dict per external DAG run to trigger."""
    return [
        dict(task_id="task_1", payload=dict(arg1=1, arg2=2)),
        dict(task_id="task_2", payload=dict(arg1=3, arg2=4)),
        dict(task_id="task_3", payload=dict(arg1=5, arg2=6)),
        dict(task_id="task_4", payload=dict(arg1=7, arg2=8)),
    ]


def _await_completion(**kwargs):
    """Poll the triggered runs until none of them are queued or running."""
    task_instance = kwargs.get("ti")
    # TriggerDagRunOperator pushes the run id it created under the
    # "trigger_run_id" XCom key; pulling from the mapped task returns a list.
    trigger_run_ids = task_instance.xcom_pull(
        task_ids="execute_external_dags", key="trigger_run_id"
    )
    while True:
        task_states = defaultdict(list)
        for tri in trigger_run_ids:
            dag_run = DagRun.find(run_id=tri)
            task_states[dag_run[0].state].append(tri)
        if not task_states.get("queued") and not task_states.get("running"):
            print(f"All tasks completed! {task_states}")
            break
        else:
            logging_message = {k: len(v) for k, v in task_states.items()}
            print(f"current queue state : {logging_message}")
            time.sleep(10)
    return


with DAG(
    dag_id="poc-driver-dag",
    start_date=datetime.datetime(2021, 1, 1),
    schedule_interval="0 * * * *",
    max_active_runs=1,
    default_args=default_args,
    catchup=False,
) as dag:
    generate_list = PythonOperator(
        dag=dag, task_id="determine_eligibility", python_callable=_generate_list
    )
    # Dynamic task mapping: one mapped shim instance per dict from generate_list.
    execute_external_dags = TriggerDagRunOperatorShim.partial(
        dag=dag,
        trigger_dag_id="poc-driven-dag",
        wait_for_completion=False,
        task_id="execute_external_dags",
    ).expand(extra_args=XComArg(generate_list))
    await_completion = PythonOperator(
        dag=dag, task_id="await_completion", python_callable=_await_completion
    )

    generate_list >> execute_external_dags >> await_completion

# poc-driven-dag: the parameterized DAG that does the actual work
import datetime
import time

from airflow import DAG
from airflow.operators.python import PythonOperator

default_args = {
    "owner": "",
    "email": "",
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 0,
}


def _sleep(**kwargs):
    """Sleep for the number of seconds passed in through the run's conf."""
    # "arg2" arrives via the payload the driver passed as conf.
    seconds = kwargs["dag_run"].conf.get("arg2")
    time.sleep(seconds)
    print(f"slept for {seconds} seconds")


with DAG(
    dag_id="poc-driven-dag",
    start_date=datetime.datetime(2021, 1, 1),
    schedule_interval=None,  # never scheduled; runs only when triggered
    max_active_runs=1024,
    default_args=default_args,
    catchup=False,
) as dag:
    sleep = PythonOperator(task_id="sleep", dag=dag, python_callable=_sleep)