Skip to content

Instantly share code, notes, and snippets.

@timtan
Created November 10, 2022 01:14
Show Gist options
  • Save timtan/d00eccadeb98f63c06f1a9432842421a to your computer and use it in GitHub Desktop.
TestAble Dag
# copied from https://github.com/apache/airflow/blob/main/airflow/models/dag.py
from __future__ import annotations
import logging
import sys
from datetime import datetime
from typing import cast, Any, TYPE_CHECKING
from airflow import DAG, settings
from airflow.configuration import secrets_backend_list
from airflow.exceptions import AirflowSkipException
from airflow.models import TaskInstance, DagRun
from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.utils.session import provide_session
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from sqlalchemy.orm import Session
NEW_SESSION: settings.SASession = cast(settings.SASession, None)
if TYPE_CHECKING:
from types import ModuleType
from airflow.datasets import Dataset
from airflow.decorators import TaskDecoratorCollection
from airflow.models.dagbag import DagBag
from airflow.models.slamiss import SlaMiss
from airflow.utils.task_group import TaskGroup
log = logging.getLogger(__name__)
class TestableDAG(DAG):
    """DAG subclass that backports ``dag.test`` for local, scheduler-free runs.

    ``dag.test`` was added upstream in Airflow 2.3; this copies it onto a
    subclass so older installations can execute one DagRun end-to-end from
    the command line.
    """

    # test is added in airflow 2.3.
    # engineer suggest me just back port it is ok
    @provide_session
    def test(
        self,
        execution_date: datetime | None = None,
        run_conf: dict[str, Any] | None = None,
        conn_file_path: str | None = None,
        variable_file_path: str | None = None,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Execute one single DagRun for a given DAG and execution date.

        :param execution_date: execution date for the DAG run; defaults to "now"
        :param run_conf: configuration to pass to newly created dagrun
        :param conn_file_path: file path to a connection file in either yaml or json
        :param variable_file_path: file path to a variable file in either yaml or json
        :param session: database connection (optional)
        """

        def add_logger_if_needed(ti: TaskInstance) -> None:
            """
            Add a formatted logger to the taskinstance so all logs are surfaced to the command line instead
            of into a task file. Since this is a local test run, it is much better for the user to see logs
            in the command line, rather than needing to search for a log file.

            :param ti: the taskinstance that will receive a logger
            """
            # NOTE: renamed from ``format`` to avoid shadowing the builtin.
            formatter = logging.Formatter(
                "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
            )
            handler = logging.StreamHandler(sys.stdout)
            handler.level = logging.INFO
            handler.setFormatter(formatter)
            # only add log handler once
            if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
                self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
                ti.log.addHandler(handler)

        if conn_file_path or variable_file_path:
            # Make the file-based backend the highest-priority secrets source
            # for the duration of this test run.
            local_secrets = LocalFilesystemBackend(
                variables_file_path=variable_file_path, connections_file_path=conn_file_path
            )
            secrets_backend_list.insert(0, local_secrets)

        execution_date = execution_date or timezone.utcnow()
        self.log.debug("Clearing existing task instances for execution date %s", execution_date)
        self.clear(
            start_date=execution_date,
            end_date=execution_date,
            dag_run_state=False,  # type: ignore
            session=session,
        )
        self.log.debug("Getting dagrun for dag %s", self.dag_id)
        dr: DagRun = _get_or_create_dagrun(
            dag=self,
            start_date=execution_date,
            execution_date=execution_date,
            run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
            session=session,
            conf=run_conf,
        )

        tasks = self.task_dict
        self.log.debug("starting dagrun")
        # Instead of starting a scheduler, we run the minimal loop possible to check
        # for task readiness and dependency management. This is notably faster
        # than creating a BackfillJob and allows us to surface logs to the user
        while dr.state == State.RUNNING:
            schedulable_tis, _ = dr.update_state(session=session)
            for ti in schedulable_tis:
                add_logger_if_needed(ti)
                ti.task = tasks[ti.task_id]
                _run_task(ti, session=session)

        if conn_file_path or variable_file_path:
            # Remove the local variables we have added to the secrets_backend_list
            secrets_backend_list.pop(0)
def _run_task(ti: TaskInstance, session) -> None:
    """
    Run a single task instance and push its result to XCom for downstream tasks.

    Bypasses much of the machinery used by ``task.run`` so the local run stays
    as fast as possible. Intended only as a helper for the ``dag.test`` method.

    :param ti: TaskInstance to run
    :param session: sqlalchemy session used for the raw task run and flush
    """
    banner = "*****************************************************"
    log.info(banner)
    log.info("Running task %s", ti.task_id)
    try:
        ti._run_raw_task(session=session)
        session.flush()
        log.info("%s ran successfully!", ti.task_id)
    except AirflowSkipException:
        # A skip is a normal outcome here, not a failure — keep going.
        log.info("Task Skipped, continuing")
    log.info(banner)
def _get_or_create_dagrun(
    dag: DAG,
    conf: dict[Any, Any] | None,
    start_date: datetime,
    execution_date: datetime,
    run_id: str,
    session: Session,
) -> DagRun:
    """
    Create a DAGRun, but only after clearing the previous instance of said dagrun to prevent collisions.

    This function is only meant for the `dag.test` function as a helper function.

    :param dag: Dag to be used to find dagrun
    :param conf: configuration to pass to newly created dagrun
    :param start_date: start date of new dagrun, defaults to execution_date
    :param execution_date: execution_date for finding the dagrun
    :param run_id: run_id to pass to new dagrun
    :param session: sqlalchemy session
    :return: the freshly created DagRun
    """
    log.info("dagrun id: %s", dag.dag_id)
    dr: DagRun = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
        .first()
    )
    if dr:
        # Delete the stale run so the new one cannot collide on
        # (dag_id, execution_date).
        session.delete(dr)
        session.commit()
    dr = dag.create_dagrun(
        state=State.RUNNING,
        execution_date=execution_date,
        run_id=run_id,
        start_date=start_date or execution_date,
        session=session,
        conf=conf,  # type: ignore
    )
    # Use lazy %-style args instead of eager string concatenation so the
    # message is only built when INFO logging is actually enabled.
    log.info("created dagrun %s", str(dr))
    return dr
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment