Local Airflow Interactive Testing/Development in Python

Interactive Airflow Debugging

Local Airflow Setup

The following function augments the base Airflow config so that dags and plugins directories can be specified dynamically in code.
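
Under the hood, the helper relies on two standard Airflow override mechanisms: setting values on airflow.configuration.conf in-process, and exporting matching AIRFLOW__&lt;SECTION&gt;__&lt;KEY&gt; environment variables so that airflow CLI commands run from the same shell pick up the same settings. A minimal sketch of just that mechanism (the dags path below is a placeholder):

import os

from airflow.configuration import conf

# In-process override: anything reading `conf` after this point sees the new value.
conf.set("core", "dags_folder", "/tmp/airflow-test/dags")

# Environment override: `airflow` CLI processes spawned from this shell see it too.
os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = conf.get("core", "dags_folder")

The full helper follows: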

import os
import sys

from itertools import chain

from pathlib import Path

from importlib import reload
from importlib.util import find_spec

import sh

from airflow.configuration import conf

from airflow.utils.db import initdb


def setup_airflow_test_env(
        airflow_config=None, k8s_fernet_secret_name=None, set_env_vars=True,
        package_name=None, etl_base_path="./", etl_dags_path="./dags"
):
    """Initialize an Airflow environment using either the environment variable
    `AIRFLOW_HOME` or a package's root directory.

    Examples
    --------

    This setup can be used to run test task instances in an independent,
    external prod/testing environment with code like the following:

    .. code-block:: python

        import sys
        import logging
        from datetime import datetime, timezone

        from utils import setup_airflow_test_env

        setup_airflow_test_env()

        #
        # Create DAG(s) per usual:
        #
        from airflow import DAG
        from airflow.models import TaskInstance
        from airflow.operators.bash_operator import BashOperator

        test_dag = DAG(dag_id='test_dag')

        some_task = BashOperator(dag=test_dag, task_id='test_task', bash_command='env',
                                start_date=datetime.now(tz=timezone.utc))

        #
        # With the test env set up, you can do the following:
        #

        # Optional: force logging to stdout
        some_task.log.handlers += [logging.StreamHandler(stream=sys.stdout)]

        # Manually create a task instance and execute it
        some_ti = TaskInstance(task=some_task, execution_date=datetime.now(tz=timezone.utc))
        some_ti_res = some_task.execute(some_ti.get_template_context())

        # If a connection is needed, we can create one manually:
        from airflow.models import Connection
        from airflow.utils.db import merge_conn

        new_conn = Connection(conn_id='test_connection', conn_type='some-type', host='a-host')
        merge_conn(new_conn)


        # If you want to inspect the Airflow DB from here, a SQLAlchemy session is
        # obtained with the following:
        from airflow import settings
        session = settings.Session()

        first_conn = session.query(Connection).first()

        assert first_conn.conn_id == 'test_connection'

        # Also, with `set_env_vars=True` (the default), the new connection should
        # appear in the output of the shell command `airflow connections -l`
        # (e.g. executed via IPython/Jupyter shell magic `!`).

    Parameters
    ----------
    airflow_config : str, optional
        Location of Airflow config file to load.  If `None`, it uses
        `airflow.configuration.conf.load_test_config`.

    k8s_fernet_secret_name : str, optional
        Name of Kubernetes secret containing the Fernet key for Airflow DB
        encryption.

    set_env_vars : bool, optional (True)
        Set Airflow env vars so that cmdline tools (run in the same
        environment) will use these test settings.

    """

    if "AIRFLOW_HOME" in os.environ:
        etl_base_path = Path(os.environ["AIRFLOW_HOME"])
        etl_dags_path = etl_base_path / "dags"
    elif package_name:
        # NB: a relative name is needed for `package` to be used as the anchor.
        etl_dags_spec = find_spec(".dags", package=package_name)
        etl_dags_path = Path(os.path.dirname(etl_dags_spec.loader.path))
        etl_base_path = etl_dags_path.parent.parent
    else:
        # The string defaults need to be `Path`s for the checks below.
        etl_base_path = Path(etl_base_path)
        etl_dags_path = Path(etl_dags_path)

    if not etl_base_path.is_dir():
        raise RuntimeError("Could not find ETL/Airflow home directory.")

    if airflow_config:
        if not os.path.isfile(airflow_config):
            print(
                "Couldn't find {}; trying relative to ETL "
                "project base path.".format(airflow_config)
            )
            airflow_cfg = etl_base_path / airflow_config
            conf.read(airflow_cfg.as_posix())
        else:
            conf.read(airflow_config)
    else:
        conf.load_test_config()

        # `load_test_config` takes the following three steps:
        # override any custom settings with defaults
        #   conf.read_string(configuration.parameterized_config(configuration.DEFAULT_CONFIG))
        # then read test config
        #   conf.read_string(configuration.parameterized_config(configuration.TEST_CONFIG))
        # then read any "custom" test settings
        #   conf.read(configuration.TEST_CONFIG_FILE)

        initdb()

    conf.set("core", "airflow_home", etl_base_path.as_posix())
    conf.set("core", "dags_folder", etl_dags_path.as_posix())
    conf.set("core", "plugins_folder", (etl_dags_path.parent / "plugins").as_posix())
    conf.set("core", "base_log_folder", (etl_base_path / "logs").as_posix())
    conf.set("core", "logging_level", "DEBUG")
    conf.set("core", "remote_logging", "False")
    # conf.set('core', 'sql_alchemy_conn', )
    # conf.set('core', 'logging_config_class', )
    # conf.set('webserver', 'base_url', '127.0.0.1')

    if k8s_fernet_secret_name:
        try:
            secret_value = sh.kubectl.get.secret(
                k8s_fernet_secret_name,
                "-o",
                "template",
                "--template",
                "{{.data.key}}",
                _piped=True,
            )
            fernet_key = sh.base64(secret_value, "--decode")

            conf.set("core", "fernet_key", fernet_key.next())
        except Exception as e:
            print("Problem setting Fernet key via K8s: {}".format(e))

    conf.set(
        "scheduler",
        "child_process_log_directory",
        (etl_base_path / "logs" / "scheduler").as_posix(),
    )

    if set_env_vars:
        os.environ["AIRFLOW_HOME"] = conf.get("core", "airflow_home")
        os.environ["AIRFLOW__CORE__SQL_ALCHEMY_CONN"] = conf.get(
            "core", "sql_alchemy_conn"
        )
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = conf.get("core", "dags_folder")
        os.environ["AIRFLOW__CORE__PLUGINS_FOLDER"] = conf.get("core", "plugins_folder")
        os.environ["AIRFLOW__CORE__BASE_LOG_FOLDER"] = conf.get(
            "core", "base_log_folder"
        )
        os.environ["AIRFLOW__CORE__FERNET_KEY"] = conf.get("core", "fernet_key")
        os.environ["AIRFLOW__CORE__EXECUTOR"] = conf.get("core", "executor")

    # XXX: When `airflow.settings` is loaded, it loads
    # `airflow.configuration.conf` and sets module-level variables!
    import airflow.settings

    reload(airflow.settings)

    import airflow.plugins_manager

    reload(airflow.plugins_manager)

    # Then check that the plugins were actually found.
    # airflow.plugins_manager.plugins_folder
    # airflow.plugins_manager.plugins
    # airflow.plugins_manager.macros_modules

    # You might need to manually add the plugins (this is what Airflow does).
    for module in chain(
        airflow.plugins_manager.operators_modules,
        airflow.plugins_manager.macros_modules,
        airflow.plugins_manager.sensors_modules,
        airflow.plugins_manager.executors_modules,
        airflow.plugins_manager.hooks_modules,
    ):
        sys.modules[module.__name__] = module
        globals()[module._name] = module

Usage

In the following, we give an example of how the setup function can be used.

from airflow import configuration as conf
from airflow.models import Connection
from airflow.utils.db import merge_conn

from airflow.utils.db import initdb, resetdb, upgradedb

# import os
# from airflow.logging_config import prepare_classpath
# conf.set('core', 'AIRFLOW_HOME', os.environ['AIRFLOW_CORE_AIRFLOW_HOME'])
# prepare_classpath()

# If you have a config file you want to use:
# setup_airflow_test_env(airflow_config='~/projects/citybase/data-science-etl/airflow.cfg',
#                        k8s_fernet_secret_name='fernet-key')

# This will create a SQLite DB locally
setup_airflow_test_env()

# conf.get('core', 'AIRFLOW_HOME')
sql_alchemy_conn = conf.get('core', 'sql_alchemy_conn')

# If you already have an existing local SQLite DB for Airflow, you may want to
# start anew:
if sql_alchemy_conn.startswith('sqlite'):
    resetdb(False)
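
With the environment in place, it's worth a quick sanity check that Airflow resolved the paths and DB connection you expect (these are just the keys the setup function sets above):

# Sanity-check the resolved settings.
print(conf.get('core', 'airflow_home'))
print(conf.get('core', 'dags_folder'))
print(conf.get('core', 'plugins_folder'))
print(sql_alchemy_conn)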

Adding Connections

Most likely, the DAGs/operators you’ll want to test/debug rely on very specific Airflow connections, perhaps even ones that point to real systems. The following shows how connections can be created on the fly for a local Airflow setup.

# XXX: I think this needs to be imported *after* the above is run.
from airflow import settings


session = settings.Session()

an_rds_conn = Connection(conn_id='my_rds',
                         conn_type='postgres',
                         host='hostname.com',
                         schema='blah',
                         port=5432,
                         login='user',
                         password='****')

# Do this again, in case the connections table wasn't created.
# initdb()

merge_conn(an_rds_conn)

We can view all the non-default connections with the following:

# Query for our new connections via sqlalchemy means
[con for con in session.query(Connection).all() if not con.conn_id.endswith('_default')]

Our new connection should be present in the output.
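
As an extra check, the connection can also be looked up the way hooks and operators resolve it at runtime, via BaseHook.get_connection (a short sketch; my_rds is the connection created above):

from airflow.hooks.base_hook import BaseHook

# Resolve the connection by its conn_id, exactly as a hook would.
rds_conn = BaseHook.get_connection('my_rds')

assert rds_conn.host == 'hostname.com'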

Manually Executing a Task from a DAG

Here, we manually create a task instance for a pre-existing DAG and execute it. By executing a task this way, we are able to interactively debug the underlying DAG and its operator(s).

import sys
import logging
import importlib

import airflow

from datetime import datetime, timezone

from airflow.models import TaskInstance

import some_package.dags as some_dag


# Stash a stdout handler on the `airflow.logging_config` module so it can be
# attached to task loggers below.
airflow.logging_config.handlers = [logging.StreamHandler(stream=sys.stdout)]

# import dask
# dask.config.set(scheduler='synchronous')

# In case our DAG has changed (e.g. we're working on it now), reload it.
importlib.reload(some_dag)

# Our DAG produces multiple tasks (e.g. one for every table in some DB); we choose
# one specific task based on its table parameter.
table_op, = [task for task in some_dag.dag.tasks if task.table.name == 'a_table']

# This will cause normal Airflow logging to appear as stdout
table_op.log.handlers += airflow.logging_config.handlers

table_ti = TaskInstance(task=table_op, execution_date=datetime.now(tz=timezone.utc))

table_ti_res = table_op.execute(table_ti.get_template_context())
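
Calling execute directly skips Airflow's own dependency checks and template rendering. If you want to exercise that machinery as well, TaskInstance.run can be used instead (a sketch, similar to what the airflow test CLI command does; test_mode avoids recording state changes in the metadata DB):

table_ti.run(ignore_all_deps=True, test_mode=True)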