Skip to content

Instantly share code, notes, and snippets.

@jbmlaird
Last active January 22, 2019 22:40
Show Gist options
  • Save jbmlaird/7453bf19ff80ffed4adbcd2c4cc3ff56 to your computer and use it in GitHub Desktop.
Pull from Pub/Sub
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
import base64
import datetime as dt
import json
import logging

from airflow import AirflowException, settings
from airflow.models import BaseOperator, DagBag, DagRun, TaskInstance
from airflow.utils.decorators import apply_defaults
from airflow.utils.state import State
from sqlalchemy import or_

# Must be declared like this (as opposed to from pubsub_sensor) as imports always check sys.path
# https://stackoverflow.com/a/46212814/4624156
from pubsub_sensor import PubSubHook
class DagRunOrder(object):
    """Simple container pairing a DagRun's ``run_id`` with its payload."""

    def __init__(self, run_id=None, payload=None):
        self.run_id, self.payload = run_id, payload
class PubSubTrigger(BaseOperator):
    """
    Pull messages from a Pub/Sub subscription and trigger a DAG run for each
    message that reports a successfully completed data transfer.

    A message whose ``state`` is not ``SUCCEEDED`` is acknowledged and
    skipped. A message whose ``dataSourceId`` does not map to a known DAG is
    also acknowledged and skipped (previously such a message raised an
    AirflowException, leaving the whole batch unacknowledged for redelivery).

    :param project: GCP project containing the subscription
    :type project: str
    :param subscription: Pub/Sub subscription to pull messages from
    :type subscription: str
    :param ack_messages: stored on the instance but not currently consulted;
        skipped and triggered messages are always acknowledged
    :type ack_messages: bool
    :param return_immediately: passed through to ``PubSubHook.pull``
    :type return_immediately: bool
    :param max_messages: maximum number of messages pulled per execution
    :type max_messages: int
    :param gcp_conn_id: Airflow connection id for Google Cloud
    :type gcp_conn_id: str
    :param delegate_to: account to impersonate via domain-wide delegation
    :type delegate_to: str
    """

    # dataSourceId values expected in Pub/Sub messages, and the DAG each
    # one should trigger.
    dfp_pubsub_name = "dfp_dt"
    dfp_dag_name = "dfp-import"
    youtube_pubsub_name = "youtube_content_owner"
    youtube_dag_name = "youtube-video-import"

    template_fields = tuple()
    template_ext = tuple()
    ui_color = '#ffefeb'

    @apply_defaults
    def __init__(
            self,
            project,
            subscription,
            ack_messages=False,
            return_immediately=False,
            max_messages=10,
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
            *args, **kwargs):
        super(PubSubTrigger, self).__init__(*args, **kwargs)
        self.project = project
        self.subscription = subscription
        self.ack_messages = ack_messages
        self.return_immediately = return_immediately
        self.max_messages = max_messages
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to

    def decode_pubsub(self, pubsub_message):
        """
        Convert the Pub/Sub message data contents to a JSON object.

        :param pubsub_message: raw Pub/Sub message dict with a
            base64-encoded ``message.data`` field
        :return: JSON object parsed from the data contents
        """
        # base64.b64decode works on both Python 2 and 3; the former
        # str.decode('base64') codec was removed in Python 3.
        b64_decoded_data = base64.b64decode(
            pubsub_message.get('message').get('data'))
        logging.info('pubsub_decoded: {}'.format(b64_decoded_data))
        return json.loads(b64_decoded_data)

    def get_run_status(self, data_json):
        """Return True when the message reports a ``SUCCEEDED`` transfer."""
        state = data_json.get('state')
        logging.info('state: {}'.format(state))
        return state == 'SUCCEEDED'

    def get_target_dag(self, data_json):
        """
        Map the message's ``dataSourceId`` to the dag_id it should trigger.

        :param data_json: decoded Pub/Sub message payload
        :return: target dag_id, or None when the source is unrecognised
        """
        data_source_id = data_json.get('dataSourceId')
        logging.info('data_source_id: {}, type: {}'.format(
            data_source_id, type(data_source_id)))
        if data_source_id == self.youtube_pubsub_name:
            return self.youtube_dag_name
        if data_source_id == self.dfp_pubsub_name:
            return self.dfp_dag_name
        return None

    def get_run_time(self, data_json):
        """
        :param data_json: decoded Pub/Sub message payload; assumed to carry a
            ``runTime`` string starting with a YYYY-MM-DD date — TODO confirm
        :return: runTime extracted from the Pub/Sub message as a datetime
        """
        run_time_string = data_json.get('runTime')[:10]
        return dt.datetime.strptime(run_time_string, "%Y-%m-%d")

    def get_ack_id(self, pubsub_message):
        """
        Fetch the acknowledgement ID from the raw Pub/Sub message.

        :param pubsub_message: raw Pub/Sub message
        :return: acknowledgement ID
        """
        logging.info("AckId: {}".format(pubsub_message.get('ackId')))
        return pubsub_message.get('ackId')

    def acknowledge_run_ack_ids(self, ack_ids, hook):
        """
        Acknowledge all processed messages in a single call.

        :param ack_ids: ack_ids of messages that were skipped or had a
            dag run triggered
        :param hook: PubSubHook used to acknowledge
        """
        if ack_ids:
            hook.acknowledge(self.project, self.subscription, ack_ids)
            logging.info("Acknowledged IDs: {}".format(ack_ids))

    def execute(self, context):
        """
        Pull up to ``max_messages`` messages, trigger the mapped DAG for each
        successful transfer, then acknowledge every processed message.
        """
        hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
                          delegate_to=self.delegate_to)
        messages = hook.pull(
            self.project, self.subscription, self.max_messages,
            self.return_immediately)
        triggered_ack_ids = []
        logging.info('Number of messages from PubSub: {}'.format(len(messages)))
        for message_json in messages:
            data_json = self.decode_pubsub(message_json)
            if not self.get_run_status(data_json):
                # The transfer failed: acknowledge this message and move on.
                logging.info('skipping message: {}'.format(data_json))
                triggered_ack_ids.append(self.get_ack_id(message_json))
                continue
            target_dag = self.get_target_dag(data_json)
            if target_dag is None:
                # Unknown dataSourceId: previously this fell through to
                # trigger_dag(None), which raised and aborted the batch
                # before anything was acknowledged, causing redelivery and
                # duplicate triggers. Acknowledge and skip instead.
                logging.warning(
                    'No DAG mapped for message, skipping: {}'.format(data_json))
                triggered_ack_ids.append(self.get_ack_id(message_json))
                continue
            run_time_datetime = self.get_run_time(data_json)
            self.trigger_dag(dag_id=target_dag,
                             execution_date=run_time_datetime)
            triggered_ack_ids.append(self.get_ack_id(message_json))
        self.acknowledge_run_ack_ids(triggered_ack_ids, hook)

    def trigger_dag(self, dag_id, run_id=None, execution_date=None):
        """
        Create an externally-triggered DagRun for ``dag_id``.

        :param dag_id: DAG to trigger; must be present in the DagBag
        :param run_id: optional run id; defaults to ``pstrig__<utc-iso>``
        :param execution_date: optional execution date; defaults to utcnow
        :raises AirflowException: when ``dag_id`` is not in the DagBag
        :return: the created DagRun
        """
        dagbag = DagBag()
        if dag_id not in dagbag.dags:
            raise AirflowException("Dag id {} not found".format(dag_id))
        dag = dagbag.get_dag(dag_id)
        logging.info("dag_id: {}, dag: {}".format(dag_id, dag.__dict__))
        if execution_date is None:
            logging.info("Creating new execution_date")
            execution_date = dt.datetime.utcnow()
        assert isinstance(execution_date, dt.datetime)
        # Drop sub-second precision so execution dates stay tidy/comparable.
        execution_date = execution_date.replace(microsecond=0)
        if not run_id:
            run_id = "pstrig__{0}".format(dt.datetime.utcnow().isoformat())
            logging.info("No run_id provided. Using: {}".format(run_id))
        trigger = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            external_trigger=True,
        )
        logging.info("{} started".format(dag_id))
        return trigger
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment