Skip to content

Instantly share code, notes, and snippets.

@gerrymanoim
Created August 8, 2020 22:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gerrymanoim/83b08008b2eefd2a20cb4fc38aa25544 to your computer and use it in GitHub Desktop.
Save gerrymanoim/83b08008b2eefd2a20cb4fc38aa25544 to your computer and use it in GitHub Desktop.
All the code for generic callbacks in Airflow.
import json
from collections import namedtuple
from functools import lru_cache
from pathlib import Path
from typing import List
import jinja2
import yaml
from airflow import configuration
from airflow.models import Variable
from airflow.operators.slack_operator import SlackAPIPostOperator
from airflow.utils.log.logging_mixin import LoggingMixin
# Module-level logger taken from an Airflow LoggingMixin instance; used by the
# router and notifier classes in this module.
log = LoggingMixin().log
class Route(namedtuple("Route", ["condition", "notifier_type", "send_to"])):
    """
    One flattened routing rule.

    Fields:
        condition: Python expression evaluated against the callback context;
            the route fires when it is truthy.
        notifier_type: key of the notifier that should send the message.
        send_to: destination handed to that notifier's ``notify``.
    """

    __slots__ = ()

    @property
    def notifer_type(self):
        """Deprecated misspelled alias of ``notifier_type``, kept for compatibility."""
        return self.notifier_type
class Router:
    """
    Route Airflow callback notifications according to YAML rules files.

    A Router is built for one callback type (see ``router_types``) and is
    invoked as the callback itself. When called, it enriches the context,
    evaluates every route's ``condition`` against it, and hands each matching
    route to the notifier registered for that route's type.
    """

    # Some example notifiers you may have. Values are the *names* of notifier
    # classes and are resolved lazily in `get_notifier` (a class object is also
    # accepted). The notifier classes are defined after this class, so
    # referencing them directly in the class body raised NameError on import.
    notifiers = {
        "slack": "SlackNotifer",
        "email": "EmailNotifier",
        "page": "PagerDutyNotifier",
        "http": "HTTPNotifier",
    }

    # One rules file per callback type, in a `routes/` directory next to this module.
    router_types = {
        "failure": Path(__file__).parent / "routes" / "on_failure.yaml",
        "success": Path(__file__).parent / "routes" / "on_success.yaml",
        "retry": Path(__file__).parent / "routes" / "on_retry.yaml",
        "sla_miss": Path(__file__).parent / "routes" / "on_sla_miss.yaml",
    }

    # Names for the positional args passed to an SLA-miss callback, in order.
    # see https://github.com/apache/airflow/blob/7930234726c5e9cb9745cc7944047ac343ab832a/airflow/jobs/scheduler_job.py#L455
    _sla_miss_arg_names = (
        "dag",
        "task_list",
        "blocking_task_list",
        "slas",
        "blocking_tis",
    )

    def __init__(self, router_type: str):
        """
        Args:
            router_type: one of ``failure``, ``retry``, ``success`` or
                ``sla_miss``. It controls which rules file this router loads.

        Raises:
            KeyError: if ``router_type`` has no entry in ``router_types``.
        """
        log.debug("Creating Router for {}".format(router_type))
        routing_file = self.router_types.get(router_type)
        if not routing_file:
            raise KeyError("Couldn't find routing file for type {}".format(router_type))
        # Must be set before make_routes(), which logs self.router_type.
        self.router_type = router_type
        self.routes = self.make_routes(routing_file)

    def __call__(self, *args):
        """
        Entry point used as the Airflow callback.

        Airflow has two types of callbacks:
        1. on_success/failure/retry: called with the context, `fn(context)`.
        2. sla_miss: called with multiple positional args (via scheduler_job.py).
        Dispatching on the argument count is a bit of a hack, but probably
        better than inspecting the call stack.

        Raises:
            NotImplementedError: for any other number of arguments.
        """
        log.info("__call__")
        if len(args) == 1:
            # A single argument is the task-instance context.
            context = args[0]
        elif len(args) == len(self._sla_miss_arg_names):
            # Pack the SLA-miss positional args into a context-like dict.
            context = dict(zip(self._sla_miss_arg_names, args))
        else:
            raise NotImplementedError(
                "Passed a number of args we do not know how to deal with. "
                f"Passed {len(args)}"
            )
        context = self.enrich_context(context)
        if self.routes:
            self.route(context)
        else:
            log.info("No routes defined for dispatcher {}".format(self.router_type))

    @property
    def airflow_environment(self) -> str:
        """
        Name of the Airflow environment, derived from ``webserver.BASE_URL``.

        An example of something you might want in `enrich_context`: it is
        never part of the task instance context, so we patch it in.

        Raises:
            ValueError: if the base URL is not one of the known environments.
        """
        airflow_url = configuration.get("webserver", "BASE_URL")
        if airflow_url == "your-production-url":
            return "Production"
        elif airflow_url == "your-staging-url":
            return "Staging"
        else:
            raise ValueError("Unknown base_url {}. Abort.".format(airflow_url))

    def enrich_context(self, base_context: dict) -> dict:
        """
        Return a copy of ``base_context`` with extra keys for routing rules.

        Adds ``airflow_environment``, ``airflow_url``, ``dag_id`` and
        ``task_id``. ``dag_id``/``task_id`` are None when the context has no
        ``dag``/``task``; for example, the SLA-miss context built in __call__
        has no ``task``.

        Args:
            base_context (dict): Generally the task instance context.
                See https://airflow.apache.org/code.html#macros for a list.
        """
        enriched_context = dict(base_context)
        enriched_context["airflow_environment"] = self.airflow_environment
        # SLA-miss contexts carry no "conf"; fall back to the global config.
        conf = base_context.get("conf")
        if conf is None:
            conf = configuration
        enriched_context["airflow_url"] = conf.get("webserver", "BASE_URL")
        dag = base_context.get("dag")
        enriched_context["dag_id"] = dag.dag_id if dag is not None else None
        task = base_context.get("task")
        enriched_context["task_id"] = task.task_id if task is not None else None
        return enriched_context

    def make_routes(self, routing_file: Path) -> "List[Route]":
        """
        Load ``routing_file`` and flatten it into a list of ``Route``.

        The file maps a notifier type to a list of entries, each with a
        ``where`` list of conditions and a ``send_to`` destination. One Route
        is produced per condition, which makes evaluation in a dag context
        simpler. A missing, empty or comment-only file yields ``[]``.
        """
        if not routing_file.exists() or routing_file.stat().st_size == 0:
            log.info("Cannot work with routing file {}".format(routing_file))
            return []
        with routing_file.open(encoding="utf-8") as f:
            # The rules file is plain data; safe_load refuses arbitrary
            # Python object tags.
            route_dict = yaml.safe_load(f) or {}
        routes = [
            Route(condition, notifier_type, entry["send_to"])
            for notifier_type, entries in route_dict.items()
            for entry in entries
            for condition in entry["where"]
        ]
        log.debug(
            "Created {} routes for dispatcher {}".format(len(routes), self.router_type)
        )
        return routes

    def route(self, ti_context: dict):
        """
        Send ``ti_context`` to every route whose condition is truthy.

        Routes are checked in order; an exception raised by a condition or a
        notifier propagates and stops further dispatching.
        """
        log.info("Dispatching message")
        for message in self.routes:
            # Positional unpacking avoids depending on the Route field names.
            condition, notifier_type, send_to = message
            # NOTE(review): conditions come from the routes YAML and run via
            # eval(); with globals=None they also see this module's globals.
            # Only load rules files from trusted locations.
            if not eval(condition, None, ti_context):
                continue
            log.info("Sending message: {}".format(message))
            notifier = self.get_notifier(notifier_type)
            notifier.notify(send_to, ti_context)

    def get_notifier(self, notifier_type: str):
        """
        Return the (cached) notifier instance for ``notifier_type``.

        Raises:
            KeyError: if the type is unknown or its class cannot be found.
        """
        # Per-instance cache rather than @lru_cache on the method, which keys
        # on `self` and keeps every Router alive for the cache's lifetime.
        cache = self.__dict__.setdefault("_notifier_cache", {})
        if notifier_type not in cache:
            notifier = self.notifiers.get(notifier_type)
            if isinstance(notifier, str):
                notifier = globals().get(notifier)
            if not notifier:
                raise KeyError("Could not find a notifier for {}".format(notifier_type))
            cache[notifier_type] = notifier(self.router_type)
        return cache[notifier_type]
class GenericNotifier:
    """
    Base class for notifiers: shared Jinja templates and a rendering helper.

    Shamelessly taken from `callback_wrappers` in moneytree.

    The class attributes below are Jinja template strings rendered against the
    (enriched) callback context. That context is expected to provide
    ``airflow_url``, ``dag``, ``task`` and ``ts``.
    """

    # Unique TI identifier
    ti_id = "{{ dag.dag_id }}.{{ task.task_id }} [{{ ts }}]"
    # Airflow web UI links for the task instance, its log, and the DAG graph.
    ti_url = (
        "{{ airflow_url }}/task?"
        "task_id={{ task.task_id }}"
        "&dag_id={{ dag.dag_id }}"
        "&execution_date={{ ts | urlencode }}"
    )
    log_url = (
        "{{ airflow_url }}/log?"
        "task_id={{ task.task_id }}"
        "&dag_id={{ dag.dag_id }}"
        "&execution_date={{ ts | urlencode }}"
    )
    graph_url = "{{ airflow_url }}/graph?dag_id={{ dag.dag_id }}"
    # Slack link markup (<url|label>) whose label is wrapped in backticks so
    # the link text renders as code.
    prefix = "<{ti_url}|`{ti_id}`>".format(ti_url=ti_url, ti_id=ti_id)

    def get_jinja_env(self, context: dict):
        """
        Return the Jinja environment to render templates for ``context``.

        Uses ``dag.get_template_env()`` when the context has a truthy ``dag``,
        otherwise a plain ``jinja2.Environment`` with template caching off.
        """
        dag = context.get("dag")
        return dag.get_template_env() if dag else jinja2.Environment(cache_size=0)

    # staticmethod: needs no instance state, and subclasses call it as
    # self.jinja_template(text, context, env). As a plain function without
    # `self`, that call received four arguments and raised TypeError.
    @staticmethod
    def jinja_template(text: str, context: dict, jinja_env) -> str:
        """
        Render ``text`` as a Jinja template using ``context`` as variables.

        Args:
            text: template source.
            context: variables passed to ``render``.
            jinja_env: environment providing ``from_string`` (see get_jinja_env).

        Returns:
            The rendered string.
        """
        return jinja_env.from_string(text).render(**context)
class SlackNotifer(GenericNotifier):
    """
    Post a templated message about a task instance to a Slack channel.

    ``notifier_type`` is the router type (``success``, ``failure``, ``retry``
    or ``sla_miss``) and selects the message template.
    """

    # Optional avatar URL for the posting bot. When None (the default),
    # ``icon_url`` is left out of the Slack call entirely.
    icon_url = None

    # Emoji prefix and trailing text of the message, per router type.
    _message_parts = {
        "success": (":success:", " is complete."),
        "failure": (":x:", " has failed!"),
        "retry": (":warning:", " has retried!"),
        "sla_miss": (":warning:", " has missed its sla!"),
    }

    def __init__(self, notifier_type: str):
        self.notifier_type = notifier_type
        self.message_tmpl = self.get_message_tmpl(notifier_type)

    def get_message_tmpl(self, notifier_type: str) -> str:
        """
        Return the Jinja message template for ``notifier_type``.

        Raises:
            NotImplementedError: if there is no template for the type.
        """
        parts = self._message_parts.get(notifier_type)
        if parts is None:
            raise NotImplementedError(
                "No message template for type {}".format(notifier_type)
            )
        emoji, suffix = parts
        return emoji + self.prefix + suffix

    def get_username(self, dag_id: str, airflow_environment: str, **context) -> str:
        """Bot display name, e.g. ``my_dag[Production]``."""
        return dag_id + "[{}]".format(airflow_environment)

    @property
    def slack_link(self):
        """Slack link to the task instance (same markup as ``prefix``)."""
        return "<{ti_url}|`{ti_id}`>".format(ti_url=self.ti_url, ti_id=self.ti_id)

    @property
    def token(self) -> str:
        """Slack API token read from the ``slack_token`` Airflow Variable."""
        # NOTE(review): falls back to the placeholder "foobar" when the
        # Variable is unset — TODO confirm that is intended.
        return Variable.get("slack_token", "foobar")

    def notify(self, send_to: str, context: dict):
        """Render the message for ``context`` and post it to channel ``send_to``."""
        jinja_env = self.get_jinja_env(context)
        slack_kwargs = {
            "channel": send_to,
            "text": self.jinja_template(self.message_tmpl, context, jinja_env),
            "username": self.get_username(**context),
        }
        # Only pass an icon when one is configured on the class or instance.
        if self.icon_url:
            slack_kwargs["icon_url"] = self.icon_url
        log.info("Sending message: '%s'", json.dumps(slack_kwargs))
        SlackAPIPostOperator(
            task_id="tmp_slack", token=self.token, **slack_kwargs
        ).execute()
class EmailNotifier(GenericNotifier):
    """
    Email an HTML message linking to the task instance, its log, and the DAG graph.

    ``notifier_type`` is the router type; it appears in the email subject.
    """

    def __init__(self, notifier_type: str):
        self.notifier_type = notifier_type
        self.message_tmpl = self.get_message_tmpl()

    def get_message_tmpl(self) -> str:
        """Get a template for the message.

        Returns:
            str -- HTML message template that Jinja will render with the
            callback context.
        """
        html_template = """
<html>
<body>
Task Instance: <a href="{ti_url}">Link</a> <br />
Log: <a href="{log_url}">Link</a> <br />
Graph: <a href="{graph_url}">Link</a> <br />
</body>
</html>
"""
        # str.format only fills the placeholders; the Jinja markup inside the
        # URL templates is substituted verbatim and rendered later.
        return html_template.format(
            ti_url=self.ti_url, log_url=self.log_url, graph_url=self.graph_url
        )

    def get_subject(self, context: dict) -> str:
        """Subject line, e.g. ``Airflow Production alert: <key [failure]>``."""
        return "Airflow {} alert: <{} [{}]>".format(
            context.get("airflow_environment"),
            context.get("task_instance_key_str"),
            self.notifier_type,
        )

    def notify(self, send_to: str, context: dict):
        """Render the email for ``context`` and send it to ``send_to``."""
        # Airflow's own email helper; the previous call target
        # (`some_function_that_sends_emails`) was an undefined placeholder.
        from airflow.utils.email import send_email

        jinja_env = self.get_jinja_env(context)
        text = self.jinja_template(self.message_tmpl, context, jinja_env)
        subject = self.get_subject(context)
        send_email(to=send_to, subject=subject, html_content=text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment