Skip to content

Instantly share code, notes, and snippets.

@nave91
Created February 27, 2019 04:27
Show Gist options
  • Save nave91/8c580ce14d5a84684dc5fd0d29204461 to your computer and use it in GitHub Desktop.
Save nave91/8c580ce14d5a84684dc5fd0d29204461 to your computer and use it in GitHub Desktop.
A clean way to record the start and end of a sub-graph of an Airflow DAG.
# MIT License
# Copyright (c) 2019 Bellhops Inc.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
from datetime import datetime
from airflow.hooks.postgres_hook import PostgresHook
from airflow.operators.python_operator import PythonOperator
from airflow import utils as airflow_utils
class Status(object):
    """Track the start/end state of ETL runs in a Postgres status table.

    Each run is identified by ``run_time``.

    * ``update('STARTED')`` inserts a row for a new run.
    * ``update('ENDED')`` closes an existing run.
    * Constructing a ``Status`` marks every earlier run that never reached
      ``ENDED`` as ``FAILED``.

    Schema/table names and all values are interpolated directly into the SQL
    text, so they must come from trusted configuration, never user input.
    """

    STARTED = 'STARTED'
    ENDED = 'ENDED'
    FAILED = 'FAILED'

    def __init__(self, status_conn_id, status_schema_name, status_table_name, run_time):
        """Connect, ensure the status table exists, and fail stale runs.

        :param status_conn_id: Airflow connection id used to build the PostgresHook.
        :param status_schema_name: schema that holds the status table.
        :param status_table_name: name of the status table.
        :param run_time: value identifying this run in the ``run_time`` column.
        """
        self.table_name = status_table_name
        self.conn_id = status_conn_id
        self.schema_name = status_schema_name
        self.run_time = run_time
        self.hook = PostgresHook(postgres_conn_id=status_conn_id)
        self.engine = self.hook.get_sqlalchemy_engine()
        self.initialize()
        self.mark_old_runs_failed()

    def _qualified_table(self):
        """Return the ``schema.table`` name used in every statement."""
        return '{schema_name}.{table_name}'.format(schema_name=self.schema_name,
                                                   table_name=self.table_name)

    def table_exists(self):
        """Return True if the status table exists in the configured schema."""
        return self.engine.has_table(self.table_name, schema=self.schema_name)

    def run_time_exists(self):
        """Return True if at least one row is recorded for ``self.run_time``."""
        sql = "SELECT COUNT(*) FROM {table} where run_time='{run_time}'".format(
            table=self._qualified_table(),
            run_time=self.run_time
        )
        # scalar() reads the single COUNT value and closes the result set;
        # fetchone() on a one-row result does not exhaust it, so it would not
        # be closed eagerly.
        count = self.run_sql(sql).scalar()
        return bool(count)

    def initialize(self):
        """Create the status table unless it already exists."""
        if self.table_exists():
            logging.info("Status table already present")
            return
        sql = ('CREATE TABLE IF NOT EXISTS {table}'
               '(id SERIAL,'
               'status varchar(1000),'
               'run_time TIMESTAMP,'
               'start_time TIMESTAMP,'
               'end_time TIMESTAMP, '
               'UNIQUE (run_time, status))').format(table=self._qualified_table())
        self.run_sql(sql)

    def run_sql(self, sql):
        """Log ``sql``, execute it on the engine and return the result."""
        logging.info("Running %s", sql)
        return self.engine.execute(sql)

    def mark_old_runs_failed(self):
        """Mark unfinished runs older than ``self.run_time`` as FAILED.

        Rows that are already ENDED or FAILED are left untouched. Without this,
        a failed run's ``end_time`` would be bumped to "now" every time a new
        ``Status`` is created, losing when the failure was first recorded.
        """
        logging.info("Marking old runs as failed.")
        # NOTE(review): datetime.now() is naive local time; if run_time is a
        # UTC execution_date, end_time may be in a different zone -- confirm.
        end_time = datetime.now()
        sql = ("UPDATE {table} SET "
               "status = '{failed}', "
               "end_time = '{end_time}' "
               "WHERE run_time < '{run_time}' "
               "and status NOT IN ('{ended}', '{failed}')").format(
            table=self._qualified_table(),
            failed=self.FAILED,
            ended=self.ENDED,
            run_time=self.run_time,
            end_time=end_time
        )
        self.run_sql(sql)

    def update(self, status):
        """Record ``status`` for ``self.run_time``.

        * Row exists and ``status`` is ENDED: set status and end_time.
        * Row exists and any other status: nothing is written (logged only).
        * No row and ``status`` is STARTED: insert a row with start_time.
        * No row and any other status: raise.

        :param status: status string to record.
        :raises ValueError: when a non-STARTED status targets a run with no row.
        """
        if self.run_time_exists():
            if status != self.ENDED:
                logging.info("Run %s already recorded; not writing status %s",
                             self.run_time, status)
                return
            end_time = datetime.now()
            sql = ("UPDATE {table} SET "
                   "status = '{status}', "
                   "end_time = '{end_time}' "
                   "WHERE run_time = '{run_time}' ").format(
                table=self._qualified_table(),
                status=status,
                run_time=self.run_time,
                end_time=end_time
            )
            self.run_sql(sql)
            return

        start_time = datetime.now()
        sql = ("INSERT INTO {table}(status, run_time, start_time) VALUES("
               "'{status}',"
               "'{run_time}',"
               "'{start_time}')").format(
            table=self._qualified_table(),
            status=status,
            run_time=self.run_time,
            start_time=start_time,
        )
        if status != self.STARTED:
            raise ValueError("Error updating status. Trying to update status for non-existent ETL run. {sql}".format(sql=sql))
        self.run_sql(sql)
class StatusOperator(PythonOperator):
    """PythonOperator that records ``current_status`` for the DAG run in a status table.

    The run is identified by the task's ``execution_date``, which is read from
    the Airflow task context.

    :param dag: DAG the task belongs to.
    :param task_id: task id.
    :param status_config: dict with keys ``status_conn_id``,
        ``status_schema_name`` and ``status_table_name``.
    :param current_status: status string passed to ``Status.update``.
    """

    @airflow_utils.apply_defaults
    def __init__(self, dag, task_id, status_config, current_status, **kwargs):
        self.status_conn_id = status_config['status_conn_id']
        self.status_schema_name = status_config['status_schema_name']
        self.status_table_name = status_config['status_table_name']
        self.current_status = current_status
        # status_update reads execution_date from the task context, which the
        # Airflow 1.x PythonOperator only passes when provide_context is True.
        # Default it on; callers may still override it explicitly.
        kwargs.setdefault('provide_context', True)
        super(StatusOperator, self).__init__(
            dag=dag,
            task_id=task_id,
            python_callable=self.status_update,
            op_kwargs={
                'current_status': current_status
            },
            **kwargs
        )

    def status_update(self, current_status, **kwargs):
        """Write ``current_status`` for this run's ``execution_date``.

        :param current_status: status string to record.
        :param kwargs: Airflow task context; must include ``execution_date``.
        :raises ValueError: if the context does not provide ``execution_date``.
        """
        run_time = kwargs.get('execution_date')
        if run_time is None:
            # Otherwise the SQL would be built with run_time='None' and fail
            # in the database with an unhelpful error.
            raise ValueError(
                "StatusOperator requires 'execution_date' from the task context; "
                "was provide_context disabled?")
        status = Status(status_conn_id=self.status_conn_id,
                        status_schema_name=self.status_schema_name,
                        status_table_name=self.status_table_name,
                        run_time=run_time)
        status.update(current_status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment