@bjagrelli
Last active July 30, 2023 18:07
Custom Python script implementing a task-triggering solution in Apache Airflow, where downstream tasks execute only if all specified 'extract' tasks completed successfully.
from datetime import datetime
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.exceptions import AirflowFailException

# Define the default arguments for the DAG
default_args = {
    'owner': 'airflow',
    'start_date': datetime(2023, 7, 26),
}

# Create the Airflow DAG
dag = DAG('example_dag', default_args=default_args, schedule_interval=None)

# Define the task functions
def task_success(**kwargs):
    task_instance = kwargs['task_instance']
    print(f"Task {task_instance.task_id} succeeded.")
    return True


def task_failure(**kwargs):
    task_instance = kwargs['task_instance']
    print(f"Task {task_instance.task_id} failed.")
    raise AirflowFailException(f"Task {task_instance.task_id} failed.")
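
# Note: task_failure is defined but not wired into the DAG below; as a quick
# experiment, point an extract task's python_callable at it to simulate a
# failed extract. (With Airflow's default all_success trigger rule, 'check'
# would then not run at all; setting trigger_rule='all_done' on 'check' lets
# it run anyway and report which extract failed.)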


def check_tasks_state(task_prefix, **context):
    """
    Check the states of all tasks whose task IDs share a given prefix within
    the current DAG run.

    Parameters:
        task_prefix (str): The prefix of the task IDs to check.
        **context (dict): The Airflow context passed to the callable. It must
            contain the "dag_run" key referring to the current DAG run.

    Returns:
        bool: True if all tasks with the specified prefix completed
            successfully.

    Raises:
        AirflowFailException: If any task with the specified prefix did not
            complete successfully.
    """
    # Retrieve the 'dag_run' object from the context, which holds information
    # about the current DAG run.
    dag_run = context["dag_run"]
    # Iterate over the task instances of the DAG run and check their states.
    for task_instance in dag_run.get_task_instances():
        # Extract the prefix of the current task's ID.
        _task_prefix = task_instance.task_id.split("_")[0]
        # Only consider tasks whose ID matches the provided 'task_prefix'.
        if task_prefix == _task_prefix:
            # If a matching task did not complete successfully, raise an
            # 'AirflowFailException' to fail this task as well.
            if task_instance.state != "success":
                raise AirflowFailException(f"Task {task_instance.task_id} failed.")
    # All tasks with the specified prefix completed successfully.
    return True
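
# Example (illustrative): called with task_prefix "extract" in this DAG, the
# function inspects extract_1, extract_2 and extract_3; if all three are in
# state "success" it returns True, while any other state raises
# AirflowFailException and fails the 'check' task itself.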


# Define the tasks
with dag:
    start = DummyOperator(task_id='start')

    extract_1 = PythonOperator(
        task_id='extract_1',
        python_callable=task_success,
        provide_context=True
    )

    extract_2 = PythonOperator(
        task_id='extract_2',
        python_callable=task_success,
        provide_context=True
    )

    extract_3 = PythonOperator(
        task_id='extract_3',
        python_callable=task_success,
        provide_context=True
    )

    check = PythonOperator(
        task_id='check',
        python_callable=check_tasks_state,
        op_args=["extract"],
        provide_context=True
    )

    transform = DummyOperator(
        task_id='transform'
    )

    end = DummyOperator(
        task_id='end'
    )

# Define the dependencies
start >> extract_1 >> extract_2 >> extract_3 >> check >> transform >> end
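
# If the three extract tasks are independent of one another, they could also
# run in parallel (a sketch, not part of the original chain; the 'check'
# logic is unchanged):
# start >> [extract_1, extract_2, extract_3] >> check >> transform >> end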