Skip to content

Instantly share code, notes, and snippets.

@raycarter
Last active July 5, 2024 11:44
Show Gist options
  • Save raycarter/75e896d600adec0563545fc58e3795d2 to your computer and use it in GitHub Desktop.
Save raycarter/75e896d600adec0563545fc58e3795d2 to your computer and use it in GitHub Desktop.
Demo of a problem that occurs when combining branching with dynamic task mapping in Apache Airflow 2.9.2.
from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.models.param import Param
from airflow.utils.trigger_rule import TriggerRule
import pendulum
import time
# Task-level defaults handed to the DAG below via ``default_args``: owned by
# "airflow", not dependent on the previous run, and no e-mail on failure.
default_args = dict(
    owner='airflow',
    depends_on_past=False,
    email_on_failure=False,
)
# Trigger-only DAG (no schedule) that reproduces the branch + dynamic task
# mapping interaction. At most one active run and one active task at a time.
dag = DAG(
    'branch_trigger_test_one_condition',
    max_active_runs=1,
    # ``concurrency`` is deprecated since Airflow 2.2; ``max_active_tasks`` is
    # the same limit without the deprecation warning.
    max_active_tasks=1,
    default_args=default_args,
    description='',
    render_template_as_native_obj=True,
    start_date=pendulum.datetime(2024, 7, 3, tz="UTC"),
    # ``schedule_interval`` is deprecated since Airflow 2.4; ``schedule=None``
    # keeps the DAG manual-trigger only.
    schedule=None,
    tags=["bug"],
    params={
        # '1' selects branch_1 in check_condition_1_callable; any other
        # value selects branch_2.
        'condition1': Param(type='string', default='2'),
        # Length of the list branch_2 returns, i.e. how many instances
        # branch_2_dyn_mapping expands into (0 -> empty expansion).
        'mapping_count': Param(type='integer', default=0),
    },
)
# Final join task. NONE_FAILED lets it run when upstream tasks were skipped
# (e.g. by the branch), as long as none of them failed.
end = EmptyOperator(
    dag=dag,
    task_id='end',
    trigger_rule=TriggerRule.NONE_FAILED,
)
def check_condition_1_callable(**kwargs):
    """Choose which branch to follow from the ``condition1`` DAG param.

    Args:
        **kwargs: task context; only ``kwargs['params']['condition1']`` is read.

    Returns:
        ``"branch_1"`` when ``condition1`` equals the string ``'1'``,
        otherwise ``"branch_2"``.

    Raises:
        KeyError: if ``params`` or ``params['condition1']`` is missing.
    """
    wants_first_branch = kwargs['params']['condition1'] == '1'
    return "branch_1" if wants_first_branch else "branch_2"
# Branch point: follows only the task_id returned by
# check_condition_1_callable; the other direct downstream task is skipped.
check_condition_1 = BranchPythonOperator(
    dag=dag,
    python_callable=check_condition_1_callable,
    task_id='check_condition_1',
)
# No-op task on the ``condition1 == '1'`` path.
branch_1 = EmptyOperator(
    trigger_rule=TriggerRule.ALL_SUCCESS,
    task_id='branch_1',
    dag=dag,
)
def branch_2_callable(**kwargs):
    """Build the list that ``branch_2_dyn_mapping`` is expanded over.

    Args:
        **kwargs: task context; only ``kwargs['params']['mapping_count']`` is read.

    Returns:
        ``[0, 1, ..., mapping_count - 1]``; a count of 0 or less gives ``[]``.
    """
    count = kwargs['params']['mapping_count']
    return [*range(count)]
# Runs branch_2_callable; its return value (branch_2.output) is what
# branch_2_dyn_mapping is expanded over.
branch_2 = PythonOperator(
    task_id='branch_2',
    # ``provide_context=True`` removed: since Airflow 2.0 the context is passed
    # to any callable accepting ``**kwargs`` and the flag only raises a
    # DeprecationWarning before being discarded.
    python_callable=branch_2_callable,
    dag=dag,
    trigger_rule=TriggerRule.ALL_SUCCESS,
)
class TestOperator(PythonOperator):
    """PythonOperator that also accepts a ``group_num`` argument.

    In this DAG ``group_num`` is supplied per mapped instance through
    ``.expand(group_num=branch_2.output)``.
    """

    def __init__(self, group_num, **kwargs):
        print(f"TestOperator {group_num}")
        super().__init__(**kwargs)
        # Keep the expanded value so execute() (or a subclass) can use it;
        # assigned after super().__init__() so BaseOperator is fully set up.
        self.group_num = group_num

    def execute(self, context):
        # Return PythonOperator's result: Airflow pushes execute()'s return
        # value to XCom, so returning None here would silently drop whatever
        # the python_callable returned.
        return super().execute(context)
def branch_2_dyn_mapping_callable(**kwargs):
    """Print the keyword arguments (task context) this mapped task received.

    Returns:
        None.
    """
    received = dict(kwargs)
    print(received)
# One TestOperator instance per element of the list branch_2 returns. An
# empty list creates no mapped instances, and Airflow marks the task skipped.
branch_2_dyn_mapping = TestOperator.partial(
    dag=dag,
    task_id='branch_2_dyn_mapping',
    trigger_rule=TriggerRule.ALL_SUCCESS,
    python_callable=branch_2_dyn_mapping_callable,
).expand(
    group_num=branch_2.output,
)
# Runs after the mapped task on the branch_2 path. NONE_FAILED lets it run
# when upstream tasks were skipped rather than failed.
branch_2_task = EmptyOperator(
    trigger_rule=TriggerRule.NONE_FAILED,
    task_id='branch_2_task',
    dag=dag,
)
# Task dependencies:
#   check_condition_1 -> branch_1 -> end
#   check_condition_1 -> branch_2 -> branch_2_dyn_mapping -> branch_2_task -> end
check_condition_1 >> [branch_1, branch_2]
branch_1 >> end
branch_2 >> branch_2_dyn_mapping >> branch_2_task >> end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment