Created
August 8, 2020 22:05
-
-
Save gerrymanoim/83b08008b2eefd2a20cb4fc38aa25544 to your computer and use it in GitHub Desktop.
All the code for generic callbacks in Airflow.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from collections import namedtuple | |
from functools import lru_cache | |
from pathlib import Path | |
from typing import List | |
import jinja2 | |
import yaml | |
from airflow import configuration | |
from airflow.models import Variable | |
from airflow.operators.slack_operator import SlackAPIPostOperator | |
from airflow.utils.log.logging_mixin import LoggingMixin | |
# Module-wide logger used by the router and notifiers in this file, taken
# from an Airflow LoggingMixin instance.
log = LoggingMixin().log
# One flattened routing rule: when ``condition`` (a Python expression,
# eval()'d by Router) is truthy for a callback context, ``send_to`` is
# notified through the notifier registered under ``notifier_type``.
# The field must be spelled ``notifier_type``: Router.route() reads it by
# that name.
Route = namedtuple("Route", ["condition", "notifier_type", "send_to"])
class Router:
    """
    Dispatch Airflow callbacks to notifiers according to YAML routing rules.

    A Router is built for one callback kind (``failure``, ``success``,
    ``retry`` or ``sla_miss``) and loads the matching file from the
    ``routes/`` directory next to this module. It is meant to be used
    directly as the Airflow callback (see ``__call__``). Each rule's
    conditions are evaluated against an enriched copy of the callback
    context, and every rule whose condition is truthy is sent through the
    notifier registered for its type.
    """

    # Notifier type (a top-level key in a routes file) -> notifier class, or
    # the *name* of a module-level class. Names are resolved lazily by
    # get_notifier(). The notifier classes are defined further down this
    # module, and some are only examples. Referencing them directly here
    # raised NameError while this class body was executing.
    notifiers = {
        "slack": "SlackNotifer",
        "email": "EmailNotifier",
        "page": "PagerDutyNotifier",
        "http": "HTTPNotifier",
    }

    # Callback kind -> routes file loaded for it.
    router_types = {
        "failure": Path(__file__).parent / "routes" / "on_failure.yaml",
        "success": Path(__file__).parent / "routes" / "on_success.yaml",
        "retry": Path(__file__).parent / "routes" / "on_retry.yaml",
        "sla_miss": Path(__file__).parent / "routes" / "on_sla_miss.yaml",
    }

    # Positional argument names of Airflow's sla_miss_callback, in call order.
    # see https://github.com/apache/airflow/blob/7930234726c5e9cb9745cc7944047ac343ab832a/airflow/jobs/scheduler_job.py#L455
    _SLA_MISS_ARG_NAMES = (
        "dag",
        "task_list",
        "blocking_task_list",
        "slas",
        "blocking_tis",
    )

    def __init__(self, router_type: str):
        """
        Build a router for one callback kind.

        Args:
            router_type: one of ``failure``, ``retry``, ``success`` or
                ``sla_miss``; selects which routes file is loaded.

        Raises:
            KeyError: if ``router_type`` has no entry in ``router_types``.
        """
        log.debug("Creating Router for {}".format(router_type))
        routing_file = self.router_types.get(router_type)
        if not routing_file:
            raise KeyError("Couldn't find routing file for type {}".format(router_type))
        # Must be set before make_routes(), which logs it.
        self.router_type = router_type
        self.routes = self.make_routes(routing_file)

    def __call__(self, *args):
        """
        Entry point used as the Airflow callback.

        Airflow has two types of callbacks:
        1. on_success/failure/retry: called with the task context,
           ``fn(context)``.
        2. sla_miss: called with five positional args (via scheduler_job.py).

        Telling them apart by argument count is a bit of a hack, but probably
        better than inspecting the call stack.

        Raises:
            NotImplementedError: for any other number of arguments.
        """
        log.debug("Router {} called with {} args".format(self.router_type, len(args)))
        context = self.enrich_context(self._pack_context(*args))
        if self.routes:
            self.route(context)
        else:
            log.info("No routes defined for dispatcher {}".format(self.router_type))

    @classmethod
    def _pack_context(cls, *args) -> dict:
        """Normalize a callback's positional args into one context dict."""
        if len(args) == 1:
            # Task callback: the single argument already is the context.
            return args[0]
        if len(args) == len(cls._SLA_MISS_ARG_NAMES):
            # SLA-miss callback: give the positional arguments names.
            return dict(zip(cls._SLA_MISS_ARG_NAMES, args))
        raise NotImplementedError(
            "Passed a number of args we do not know how to deal with. "
            f"Passed {len(args)}"
        )

    @property
    def airflow_environment(self) -> str:
        """
        Name of the Airflow deployment, derived from ``[webserver] BASE_URL``.

        This is an example of something you might want in ``enrich_context``.
        It never arrives as part of the task instance context, so it is
        patched in there. The URLs below are placeholders to replace.

        Raises:
            ValueError: if the base URL matches neither known deployment.
        """
        airflow_url = configuration.get("webserver", "BASE_URL")
        if airflow_url == "your-production-url":
            return "Production"
        if airflow_url == "your-staging-url":
            return "Staging"
        raise ValueError("Unknown base_url {}. Abort.".format(airflow_url))

    def enrich_context(self, base_context: dict) -> dict:
        """
        Return a copy of the callback context with extra convenience keys.

        Adds ``airflow_environment``, ``airflow_url``, ``dag_id`` and
        ``task_id`` so routing conditions and templates can use them directly.
        SLA-miss contexts carry no ``conf`` or ``task`` entry. In that case the
        base URL is read from the global Airflow configuration and ``task_id``
        is ``None``.

        Args:
            base_context (dict): Generally the task instance context.
                See https://airflow.apache.org/code.html#macros for a list.
        """
        enriched_context = dict(base_context)
        enriched_context["airflow_environment"] = self.airflow_environment

        conf = base_context.get("conf")
        if conf is None:
            conf = configuration
        enriched_context["airflow_url"] = conf.get("webserver", "BASE_URL")

        dag = base_context.get("dag")
        task = base_context.get("task")
        enriched_context["dag_id"] = dag.dag_id if dag is not None else None
        enriched_context["task_id"] = task.task_id if task is not None else None
        return enriched_context

    def make_routes(self, routing_file: Path) -> "List[Route]":
        """
        Load a routes file and flatten it into a list of Route rules.

        The file is flattened so each rule can be evaluated on its own against
        a DAG context. Layout expected by the flattening below::

            slack:                      # notifier type
              - where: [cond, ...]      # Python expressions, one Route each
                send_to: "#channel"

        Returns an empty list when the file is missing, empty, or holds no
        YAML document.
        """
        if not routing_file.exists() or routing_file.stat().st_size == 0:
            log.info("Cannot work with routing file {}".format(routing_file))
            return []
        with routing_file.open() as f:
            # safe_load: a routes file only needs plain mappings, lists and
            # strings, so there is no reason to allow Python object tags.
            route_dict = yaml.safe_load(f)
        if not route_dict:
            log.info("Cannot work with routing file {}".format(routing_file))
            return []
        routes = [
            Route(condition, notifier_type, rule["send_to"])
            for notifier_type, rules in route_dict.items()
            for rule in rules
            for condition in rule["where"]
        ]
        log.debug(
            "Created {} routes for dispatcher {}".format(len(routes), self.router_type)
        )
        return routes

    def _matching_routes(self, ti_context: dict):
        """Lazily yield the routes whose condition is truthy for ti_context."""
        for route in self.routes:
            # Positional access so plain (condition, notifier_type, send_to)
            # tuples work as well as Route instances.
            condition = route[0]
            # NOTE: conditions are eval()'d Python from the routes file, so
            # that file must be trusted config. With globals=None they can
            # also see this module's globals.
            if eval(condition, None, ti_context):
                yield route

    def route(self, ti_context: dict):
        """
        Send ti_context to every route whose condition holds, in file order.
        """
        log.info("Dispatching message")
        for message in self._matching_routes(ti_context):
            log.info("Sending message: {}".format(message))
            _condition, notifier_type, send_to = message
            notifier = self.get_notifier(notifier_type)
            notifier.notify(send_to, ti_context)

    def get_notifier(self, notifier_type: str):
        """
        Return this router's notifier for notifier_type, creating it once.

        Instances are cached per router, not with lru_cache on the method,
        which would keep every Router alive and share one small cache across
        all of them.

        Raises:
            KeyError: if no notifier class is registered, or a registered
                class name is not defined in this module.
        """
        cache = self.__dict__.setdefault("_notifier_cache", {})
        if notifier_type not in cache:
            notifier_cls = self.notifiers.get(notifier_type)
            if isinstance(notifier_cls, str):
                notifier_cls = globals().get(notifier_cls)
            if not notifier_cls:
                raise KeyError("Could not find a notifier for {}".format(notifier_type))
            cache[notifier_type] = notifier_cls(self.router_type)
        return cache[notifier_type]
class GenericNotifier:
    """
    Base class holding the Jinja templates shared by all notifiers.

    ``ti_id``, ``ti_url``, ``log_url`` and ``graph_url`` are Jinja template
    strings. They are rendered against the callback context and reference
    ``airflow_url``, ``dag``, ``task`` and ``ts``.

    Shamelessly taken from `callback_wrappers` in moneytree
    """

    # Unique TI identifier
    ti_id = "{{ dag.dag_id }}.{{ task.task_id }} [{{ ts }}]"
    ti_url = (
        "{{ airflow_url }}/task?"
        "task_id={{ task.task_id }}"
        "&dag_id={{ dag.dag_id }}"
        "&execution_date={{ ts | urlencode }}"
    )
    log_url = (
        "{{ airflow_url }}/log?"
        "task_id={{ task.task_id }}"
        "&dag_id={{ dag.dag_id }}"
        "&execution_date={{ ts | urlencode }}"
    )
    graph_url = "{{ airflow_url }}/graph?dag_id={{ dag.dag_id }}"
    # Uses magic Slack formatting to create a <pre> section that's also a link!
    # str.format only fills {ti_url}/{ti_id}; the Jinja braces inside the
    # substituted values are left intact for rendering later.
    prefix = "<{ti_url}|`{ti_id}`>".format(ti_url=ti_url, ti_id=ti_id)

    def get_jinja_env(self, context: dict):
        """
        Given a context, return an appropriate Jinja environment.

        Uses the DAG's own template environment when the context has a DAG.
        Otherwise it returns a plain ``jinja2.Environment`` with
        ``cache_size=0``, so templates are recompiled on every use.
        """
        dag = context.get("dag")
        if dag is not None:
            return dag.get_template_env()
        return jinja2.Environment(cache_size=0)

    @staticmethod
    def jinja_template(text, context, jinja_env):
        """
        Return ``text`` rendered as a Jinja template with ``context``.

        This is a staticmethod because it needs no instance state. As a plain
        function without ``self``, the ``self.jinja_template(text, context,
        env)`` calls in subclasses passed one argument too many.
        """
        return jinja_env.from_string(text).render(**context)
class SlackNotifer(GenericNotifier):
    """
    Post callback notifications to a Slack channel.

    The class name keeps its historical spelling because ``Router.notifiers``
    refers to it by that name.
    """

    # Optional avatar URL for posted messages. notify() previously read
    # self.icon_url without it ever being defined, which raised
    # AttributeError. When left as None, icon_url is not passed, so
    # SlackAPIPostOperator presumably falls back to its own default.
    icon_url = None

    def __init__(self, notifier_type: str):
        """
        Args:
            notifier_type: callback kind (``success``, ``failure``, ``retry``
                or ``sla_miss``); selects the message template.
        """
        self.notifier_type = notifier_type
        self.message_tmpl = self.get_message_tmpl(notifier_type)

    def get_message_tmpl(self, notifier_type: str) -> str:
        """
        Return the Jinja message template for a callback kind.

        Raises:
            NotImplementedError: for a kind with no template.
        """
        templates = {
            "success": ":success:" + self.prefix + " is complete.",
            "failure": ":x:" + self.prefix + " has failed!",
            "retry": ":warning:" + self.prefix + " has retried!",
            "sla_miss": ":warning:" + self.prefix + " has missed its sla!",
        }
        try:
            return templates[notifier_type]
        except KeyError:
            raise NotImplementedError(
                "No message template for type {}".format(notifier_type)
            ) from None

    def get_username(self, dag_id: str, airflow_environment: str, **context) -> str:
        """Slack display name ``<dag_id>[<airflow_environment>]``; other keys are ignored."""
        return dag_id + "[{}]".format(airflow_environment)

    @property
    def slack_link(self):
        """Slack-formatted link whose label is the TI identifier (a Jinja template)."""
        return "<{ti_url}|`{ti_id}`>".format(ti_url=self.ti_url, ti_id=self.ti_id)

    @property
    def token(self) -> str:
        """
        Slack API token from the ``slack_token`` Airflow Variable.

        NOTE(review): the ``"foobar"`` default looks like a placeholder; a
        missing Variable probably ought to fail loudly instead. Confirm.
        """
        return Variable.get("slack_token", "foobar")

    def notify(self, send_to: str, context: dict):
        """
        Render the message for ``context`` and post it to channel ``send_to``.

        ``context`` must contain ``dag_id`` and ``airflow_environment``
        (as added by ``Router.enrich_context``), which form the username.
        """
        jinja_env = self.get_jinja_env(context)
        slack_kwargs = {
            "channel": send_to,
            # Called on the class: jinja_template takes no ``self``.
            "text": GenericNotifier.jinja_template(
                self.message_tmpl, context, jinja_env
            ),
            "username": self.get_username(**context),
        }
        if self.icon_url:
            slack_kwargs["icon_url"] = self.icon_url
        log.info("Sending message: '%s'", json.dumps(slack_kwargs))
        SlackAPIPostOperator(
            task_id="tmp_slack", token=self.token, **slack_kwargs
        ).execute()
class EmailNotifier(GenericNotifier):
    """Email callback notifications as an HTML page of Airflow links."""

    def __init__(self, notifier_type: str):
        """
        Args:
            notifier_type: callback kind; it appears in the email subject.
        """
        self.notifier_type = notifier_type
        self.message_tmpl = self.get_message_tmpl()

    def get_message_tmpl(self) -> str:
        """Get a template for the message.

        The task-instance, log and graph URL templates are substituted in
        with str.format. Their Jinja braces survive and are rendered later.

        Returns:
            str -- Message template that jinja will render
        """
        html_template = """
        <html>
          <body>
            Task Instance: <a href="{ti_url}">Link</a> <br />
            Log: <a href="{log_url}">Link</a> <br />
            Graph: <a href="{graph_url}">Link</a> <br />
          </body>
        </html>
        """
        return html_template.format(
            ti_url=self.ti_url, log_url=self.log_url, graph_url=self.graph_url
        )

    def get_subject(self, context: dict) -> str:
        """Subject line built from the environment, TI key and callback kind."""
        return "Airflow {} alert: <{} [{}]>".format(
            context.get("airflow_environment"),
            context.get("task_instance_key_str"),
            self.notifier_type,
        )

    def notify(self, send_to: str, context: dict):
        """Render the HTML body for ``context`` and email it to ``send_to``."""
        # Airflow's own mailer (configured via its [email] settings) replaces
        # the undefined placeholder some_function_that_sends_emails. It is
        # imported locally so the dependency stays with this notifier.
        from airflow.utils.email import send_email

        jinja_env = self.get_jinja_env(context)
        # Called on the class: jinja_template takes no ``self``.
        text = GenericNotifier.jinja_template(self.message_tmpl, context, jinja_env)
        subject = self.get_subject(context)
        send_email(to=send_to, subject=subject, html_content=text)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment