Last active
July 30, 2023 18:07
-
-
Save bjagrelli/0fc8d79f7e9924f1ed9ae4b25012e529 to your computer and use it in GitHub Desktop.
Custom Python script implementing a task triggering solution in Apache Airflow, where downstream tasks execute only if all specified 'extract' tasks complete successfully.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
from airflow import DAG | |
from airflow.operators.dummy_operator import DummyOperator | |
from airflow.operators.python_operator import PythonOperator | |
from airflow.exceptions import AirflowFailException | |
# Arguments inherited by every task in this DAG.
default_args = {
    'owner': 'airflow',
    'start_date': datetime(2023, 7, 26),
}

# Manually-triggered example DAG (no schedule).
dag = DAG(
    'example_dag',
    default_args=default_args,
    schedule_interval=None,
)
# Define the task functions | |
def task_success(**kwargs):
    """Log that the current task succeeded and return True.

    Expects the Airflow task context in ``kwargs``; reads the
    ``task_instance`` entry to report which task ran.
    """
    task_instance = kwargs['task_instance']
    print(f"Task {task_instance.task_id} succeeded.")
    return True
def task_failure(**kwargs):
    """Log that the current task failed, then raise to fail it hard.

    Expects the Airflow task context in ``kwargs``; reads the
    ``task_instance`` entry to report which task ran.

    Raises:
        AirflowFailException: always, so the task fails without retries.
    """
    task_instance = kwargs['task_instance']
    print(f"Task {task_instance.task_id} failed.")
    raise AirflowFailException(f"Task {task_instance.task_id} failed.")
def check_tasks_state(task_prefix, **context):
    """Verify that every task in the current DAG run matching ``task_prefix`` succeeded.

    A task matches when the segment of its task ID before the first
    underscore equals ``task_prefix`` (e.g. ``"extract_1"`` matches the
    prefix ``"extract"``).

    Parameters:
        task_prefix (str): Prefix identifying the group of tasks to check.
        **context (dict): Airflow task context; must contain the "dag_run"
            key referring to the current DAG run instance.

    Returns:
        bool: True if all matching tasks finished in the "success" state.

    Raises:
        AirflowFailException: If any matching task is not in the "success"
            state (failed, upstream_failed, skipped, still running, ...).
    """
    # The DagRun exposes the state of every task instance in this run.
    dag_run = context["dag_run"]
    for task_instance in dag_run.get_task_instances():
        # Skip tasks outside the requested group.
        if task_instance.task_id.split("_")[0] != task_prefix:
            continue
        # Anything other than a clean success fails this check task too.
        if task_instance.state != "success":
            raise AirflowFailException(f"Task {task_instance.task_id} failed.")
    # All matching tasks completed successfully.
    return True
# Define the tasks | |
with dag:
    start = DummyOperator(task_id='start')

    # Three mock "extract" tasks; each simply reports success.
    extract_tasks = [
        PythonOperator(
            task_id=f'extract_{n}',
            python_callable=task_success,
            provide_context=True,
        )
        for n in (1, 2, 3)
    ]

    # Gate task: fails unless every "extract" task finished successfully.
    check = PythonOperator(
        task_id='check',
        python_callable=check_tasks_state,
        op_args=["extract"],
        provide_context=True,
    )

    transform = DummyOperator(task_id='transform')
    end = DummyOperator(task_id='end')

    # Wire the linear pipeline:
    # start -> extract_1 -> extract_2 -> extract_3 -> check -> transform -> end
    pipeline = [start, *extract_tasks, check, transform, end]
    for upstream, downstream in zip(pipeline, pipeline[1:]):
        upstream >> downstream
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment