@r39132
Created March 15, 2016 04:13
Reload_Data
"""
Automation
This script reloads (i.e. re-ingests) data into a database.
"""
# ## Imports
import getopt
import logging
import os
import psycopg2
import sys
from datetime import date, datetime, time, timedelta
from ep_reload_data_utils import *
from subprocess import Popen, PIPE, STDOUT
from airflow.models import Variable
# Algorithm
# If re-ingesting (-i):
#   1. Copy files in S3
#   2. Delete Parquet
#   3. Call run_ingest
# Then:
#   0. Wait for any previous runs to complete
#   1. Delete data from the table
#   2. Purge the DLQ
#   3. Load data
#        a. Clear and backfill
#        b. Generate pre-aggregates
#        c. Verify no messages on the DLQ
#        d. Verify data added to the message table
#   4. Send a "success" email if the run succeeded
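# Example invocation (illustrative only; the dates and fabric below are made up):
#   python reload_data.py -s 2016-03-01 -e 2016-03-07 -f prod -m
# would rebuild models and reload data for March 1-7, 2016 in the prod fabric.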
# Set up logging
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging.getLogger('reload_data')
logger.setLevel(logging.DEBUG)
# Constants
# ALL CONSTANTS HIDDEN
PID_FILE = "/tmp/reload_data.pid"
spinup_separate_cluster_for_reload = str_to_bool(Variable.get('ep_pipeline_emr_reload_spinup').strip())
# Ensure that only one version of this program is running!
def can_run(pid, pidfile):
    if os.path.isfile(pidfile):
        print "%s already exists, exiting" % pidfile
        return False
    else:
        file(pidfile, 'w').write(pid)
        return True
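# can_run acts as a best-effort single-instance lock: if a previous run crashed before
# reaching the cleanup in main(), the stale pidfile must be removed by hand, e.g.:
#   rm /tmp/reload_data.pid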
# Purge DLQ
def purge_DLQ(env_fabric, start_date):
    '''
    Purge DLQ
    '''
    logger.info("purge_dlq in {} for -s {}".format(env_fabric, start_date))
    # Convert dates into strings
    start_date_string = start_date.strftime(DATE_FORMAT)
    # Call purge DLQ
    purge_DLQ_task_name = 'purge_DLQ'
    command = "airflow test ep_reload_data {} {}".format(purge_DLQ_task_name, start_date_string)
    _call_subprocess_simple(command)
    return True
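# For example, with a hypothetical start_date of 2016-03-01 the command above expands to:
#   airflow test ep_reload_data purge_DLQ 2016-03-01
# assuming DATE_FORMAT is the usual '%Y-%m-%d' (the constants are hidden above).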
# Get Env-specific conf values
def get_conf_value_for_env(env_fabric,
                           prod_conf=None,
                           tad_conf=None,
                           stage_conf=None):
    if env_fabric == ENV_FABRIC_PROD_VALUE:
        return prod_conf
    elif env_fabric == ENV_FABRIC_STAGE_VALUE:
        return stage_conf
    else:
        return tad_conf
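# Dispatch helper: returns prod_conf for ENV_FABRIC_PROD_VALUE, stage_conf for
# ENV_FABRIC_STAGE_VALUE, and falls back to tad_conf for anything else.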
# Delete RDA, RIA, and Message records in the table for the target date range
def delete_db_data(env_fabric, start_date, end_date, org_ids):
    '''
    Delete data from the RDA, RIA, and Message tables
    '''
    logger.info("delete_db_data in {} for -s {} -e {} for org_ids {}".format(env_fabric,
                                                                             start_date,
                                                                             end_date,
                                                                             org_ids))
    # Set the connection string based on the environment
    conn_string = get_conf_value_for_env(env_fabric,
                                         prod_conf=PROD_DB_CONN_STRING,
                                         tad_conf=TAD_DB_CONN_STRING,
                                         stage_conf=STAGE_DB_CONN_STRING)
    # Format and compute dates for the time bounds
    start_date_string = start_date.strftime(DATE_FORMAT)
    end_date_for_db_query = end_date + timedelta(days=1)
    end_date_string = end_date_for_db_query.strftime(DATE_FORMAT)
    # Delete data
    delete_db_data_util(conn_string,
                        start_date_string,
                        end_date_string,
                        org_ids,
                        vacuum_analyze=True)
    return True
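# Note that end_date is inclusive from the caller's point of view: one day is added above,
# presumably because the delete query treats its upper bound as exclusive.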
# Verify that message records are present for each date in the range
def verify_data_in_db(env_fabric, start_date, end_date):
    '''
    Verify that data was successfully loaded into the tables by checking that messages were added
    for each date in the range.
    '''
    logger.info("verify_data_in_db in {} for -s {} -e {}".format(env_fabric, start_date, end_date))
    # Set the connection string based on the environment
    conn_string = get_conf_value_for_env(env_fabric,
                                         prod_conf=PROD_DB_CONN_STRING,
                                         tad_conf=TAD_DB_CONN_STRING,
                                         stage_conf=STAGE_DB_CONN_STRING)
    db_conn = psycopg2.connect(conn_string)
    logger.info("----- Successfully Connected to database {}".format(conn_string))
    cursor = db_conn.cursor()
    # Generate the date-bound queries
    start_date_string = start_date.strftime(DATE_FORMAT)
    end_date_for_db_query = end_date + timedelta(days=1)
    end_date_string = end_date_for_db_query.strftime(DATE_FORMAT)
    # Execute the query
    MESSAGE_CHECK_QUERY_BOUND = MESSAGE_CHECK_QUERY % (start_date_string, end_date_string)
    logger.info("----- Executing the following query against the db : {}".format(MESSAGE_CHECK_QUERY_BOUND))
    cursor.execute(MESSAGE_CHECK_QUERY_BOUND)
    result = cursor.fetchone()
    row_count = int(result[0])
    db_conn.commit()  # close the transaction
    # Close the cursor and connection
    cursor.close()
    db_conn.close()
    # Loop through the dates to count the days in the range
    d = start_date
    delta = timedelta(days=1)
    expected_count = 0
    while d <= end_date:
        expected_count = expected_count + 1
        d += delta
    logger.info("{} records written to the DB and expected count is {}".format(row_count, expected_count))
    return row_count == expected_count
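# For example (hypothetical), -s 2016-03-01 -e 2016-03-03 yields expected_count == 3, so the
# check passes only if the count returned by the hidden MESSAGE_CHECK_QUERY equals the number
# of days in the range.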
# Build models
def run_build_models(env_fabric, start_date):
    # Build models : 3 stages : sender model, cdd, domain_rep
    start_date_string = start_date.strftime(DATE_FORMAT)
    # Sender Models
    build_models_spark_job_name = 'build_sender_models_spark_job'
    command = "airflow test ep_reload_data {} {}".format(build_models_spark_job_name, start_date_string)
    _call_subprocess_simple(command)
    # CDD Models
    build_models_spark_job_name = 'build_cdd_models_spark_job'
    command = "airflow test ep_reload_data {} {}".format(build_models_spark_job_name, start_date_string)
    _call_subprocess_simple(command)
    # Domain Rep Models
    build_models_spark_job_name = 'build_dom_rep_models_spark_job'
    command = "airflow test ep_reload_data {} {}".format(build_models_spark_job_name, start_date_string)
    _call_subprocess_simple(command)
    # Validate that models were built!
    metadata_bucket = get_conf_value_for_env(env_fabric,
                                             prod_conf=METADATA_BUCKET_NAME_PROD_VALUE,
                                             tad_conf=METADATA_BUCKET_NAME_TAD_VALUE,
                                             stage_conf=METADATA_BUCKET_NAME_STAGE_VALUE)
    return verify_model_building_successful_util(REGION, metadata_bucket)
# Load data by running aggregation within the specified date range
def run_aggregation(env_fabric, start_date, end_date, disable_end_user_alerting):
    '''
    Load data by running aggregation within the date range
    Algo :
    1. Backfill the aggregate task for the date range
    2. Wait for the queue to empty
    3. Run the pre-aggregation job
    4. Check the DLQ
    5. Check that data was added to the DB
    '''
    # Loop through the dates
    run_agg_task_name = 'aggregate_data_spark_job'
    d = start_date
    delta = timedelta(days=1)
    while d <= end_date:
        start_date_string = d.strftime(DATE_FORMAT)
        command = "airflow test ep_reload_data {} {}".format(run_agg_task_name, start_date_string)
        _call_subprocess_simple(command)
        d += delta
    # Wait for the SQS queue to drain
    run_wait_for_empty_queue_name = 'wait_for_empty_queue_reload'
    command = "airflow test ep_reload_data {} {}".format(run_wait_for_empty_queue_name, start_date_string)
    _call_subprocess_simple(command)
    run_db_message_aggregation(start_date, end_date)
    if not disable_end_user_alerting:
        # HOURLY : Run through all of the relevant hours, issuing a customer alert for each!
        enqueue_alerts_job_name = 'enqueue_alerting_jobs'
        # Use a datetime cursor: adding an hour to a plain date would not advance it
        d = datetime.combine(start_date, time())
        delta = timedelta(hours=1)
        while d.date() <= end_date:
            start_date_string = d.strftime('%Y-%m-%dT%H:%M:%S')
            command = "airflow test ep_telemetry_v2 {} {}".format(enqueue_alerts_job_name, start_date_string)
            _call_subprocess_simple(command)
            d += delta
    # Check the DLQ for messages
    SQS_queue_name = get_conf_value_for_env(env_fabric,
                                            prod_conf=SQS_Q_NAME_PROD_VALUE,
                                            tad_conf=SQS_Q_NAME_TAD_VALUE,
                                            stage_conf=SQS_Q_NAME_STAGE_VALUE)
    any_dlq_messages = check_for_dlq_messages(REGION, SQS_queue_name)
    print "any_dlq_messages = {}".format(any_dlq_messages)
    # Return tuple (ANY_DLQ_ENTRIES, NO_DATA_MISSING_IN_DB_FOR_SOME_DATES)
    return (any_dlq_messages, verify_data_in_db(env_fabric, start_date, end_date))
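# The returned tuple is consumed in main() as (any_dlq_messages, all_data_in_db):
# a successful reload requires the DLQ to be empty AND every day in the range to have data.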
def run_db_message_aggregation(start_date, end_date):
    # Run the pre-aggregation job
    run_pre_agg_job_name = 'aggregate_db_message_job'
    d = start_date
    delta = timedelta(days=1)
    success = True
    while d <= end_date:
        start_date_string = d.strftime(DATE_FORMAT)
        command = "airflow test ep_reload_data {} {}".format(run_pre_agg_job_name, start_date_string)
        return_code = _call_subprocess_simple(command)
        success = return_code == 0 and success
        d += delta
    return success
def clear_parquet(env_fabric, start_date, end_date):
    '''
    Given a time range, this function will delete the parquet files within that range!
    '''
    # Set the spark_master_ssh_access_key and spark_master_ip based on the environment
    spark_master_ssh_access_key = get_conf_value_for_env(env_fabric,
                                                         prod_conf=SPARK_MASTER_SSH_KEY_PROD_VALUE,
                                                         tad_conf=SPARK_MASTER_SSH_KEY_TAD_VALUE,
                                                         stage_conf=SPARK_MASTER_SSH_KEY_STAGE_VALUE)
    spark_master_ip = get_conf_value_for_env(env_fabric,
                                             prod_conf=SPARK_MASTER_IP_PROD_VALUE,
                                             tad_conf=SPARK_MASTER_IP_TAD_VALUE,
                                             stage_conf=SPARK_MASTER_IP_STAGE_VALUE)
    logger.info("Running clear_parquet with start_date=%s, end_date=%s in %s against spark cluster=%s using key=%s"
                % (start_date, end_date, env_fabric, spark_master_ip, spark_master_ssh_access_key))
    # Iterate through a range of days
    d = start_date
    delta = timedelta(days=1)
    while d <= end_date:
        d_as_epoch = d.strftime('%s')
        # Check for the existence of Parquet
        command = "ssh -o 'StrictHostKeyChecking no' -i ~/.ssh/{} root@{} '. /root/.bash_profile; /root/ephemeral-hdfs/bin/hadoop fs -ls /agari-parquet-dev/{}'".format(
            spark_master_ssh_access_key, spark_master_ip, d_as_epoch)
        ret_code = _call_subprocess_simple(command, allow_error=True)
        if ret_code == 0:
            # Remove the Parquet files
            logger.info("Running clear_parquet for d=%s" % (d_as_epoch))
            command = "ssh -o 'StrictHostKeyChecking no' -i ~/.ssh/{} root@{} '. /root/.bash_profile; /root/ephemeral-hdfs/bin/hadoop fs -rmr -skipTrash /agari-parquet-dev/{}'".format(
                spark_master_ssh_access_key, spark_master_ip, d_as_epoch)
            _call_subprocess_simple(command, suppress_logging=True)
        else:
            logger.info("Skipping Parquet removal as there is no parquet for d=%s" % (d_as_epoch))
        d += delta
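# Caveat: d.strftime('%s') relies on a platform-specific (glibc) extension to produce an epoch
# timestamp; it is not portable everywhere, but it matches how the parquet directories under
# /agari-parquet-dev/ are apparently named (the epoch of each day).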
def reingest_collector_files(env_fabric, start_date, end_date):
    # Set the collector_consumer and collector_ingest bucket names based on the environment
    col_ingest_bucket_name = get_conf_value_for_env(env_fabric,
                                                    prod_conf=COL_INGEST_BUCKET_NAME_PROD_VALUE,
                                                    tad_conf=COL_INGEST_BUCKET_NAME_TAD_VALUE,
                                                    stage_conf=COL_INGEST_BUCKET_NAME_STAGE_VALUE)
    col_consumer_bucket_name = get_conf_value_for_env(env_fabric,
                                                      prod_conf=COL_CONSUMER_BUCKET_NAME_PROD_VALUE,
                                                      tad_conf=COL_CONSUMER_BUCKET_NAME_TAD_VALUE,
                                                      stage_conf=COL_CONSUMER_BUCKET_NAME_STAGE_VALUE)
    # Loop through the dates
    d = start_date
    delta = timedelta(days=1)
    while d <= end_date:
        d_string = d.strftime(COLLECTOR_FILE_DATE_FORMAT)
        command = "aws s3 cp --recursive --exclude '*' --include '*{}*.avro' --exclude '*.gz' s3://{}/uploads s3://{}/uploads/".format(
            d_string, col_ingest_bucket_name, col_consumer_bucket_name)
        _call_subprocess_simple(command)
        d += delta
    # Generate Parquet using any files that happen to be in the collector-consumer bucket
    start_date_string = start_date.strftime(DATE_FORMAT)
    run_ingest_task_name = 'generate_new_parquet_files_spark_job'
    command = "airflow test ep_reload_data {} {}".format(run_ingest_task_name, start_date_string)
    _call_subprocess_simple(command)
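# Illustrative shape of the per-day copy command built above (bucket names come from the
# hidden constants):
#   aws s3 cp --recursive --exclude '*' --include '*<date>*.avro' --exclude '*.gz' \
#       s3://<ingest-bucket>/uploads s3://<consumer-bucket>/uploads/
# i.e. only that day's .avro files are copied from the ingest bucket to the consumer bucket.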
# Helper method to handle subprocess failures
def _call_subprocess_complex(command_as_list, stdin_input):
    command_as_string = ' '.join(command_as_list)
    logger.info("Running {}".format(command_as_string))
    p = Popen(command_as_list, stdin=PIPE)
    p_stdout = p.communicate(input=stdin_input)[0]
    logger.info(p_stdout)
    error_message = "FAILURE : %s" % command_as_string
    ret_code = p.returncode
    if ret_code != 0:
        logger.critical(error_message)
        exit(2)
# Helper method to handle subprocess failures
def _call_subprocess_simple(command, suppress_logging=False, allow_error=False, tries=2):
    if not suppress_logging:
        logging_message = "Running : %s" % command
        logger.info(logging_message)
    # Attempt the command up to "tries" times
    attempts = 0
    ret_code = -9
    while attempts < tries and ret_code != 0:
        if not suppress_logging:
            logger.info("%s Attempt at Running : %s" % (attempts, command))
        ret_code = os.system(command)
        attempts = attempts + 1
    # If we allow errors, continue
    error_message = "FAILURE on %s Attempts at Running : %s" % (attempts, command)
    if allow_error:
        return ret_code
    # Otherwise, log a critical error and bail out
    if ret_code != 0:
        logger.critical(error_message)
        exit(2)
    # Always return the ret_code
    return ret_code
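# Usage sketch (the date is hypothetical):
#   _call_subprocess_simple("airflow test ep_reload_data purge_DLQ 2016-03-01")
# retries once on failure and exits the whole script on a second failure, while
# _call_subprocess_simple(cmd, allow_error=True) hands a non-zero return code back to the caller.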
# Define the usage
def usage():
    print "\n USAGE : python reload_data.py -s <start date> -e <end date> -f <environment> [-m] [-q] [-v] [-a] \n\n \
    \t # -s : start_date (inclusive, YYYY-MM-DD, required) \n \
    \t # -e : end_date (inclusive, YYYY-MM-DD, required) \n \
    \t # -f : env or fabric (TAD, STAGE, or PROD, case-insensitive) \n \
    \t # -m : build models (optional) \n \
    \t # -q : disable end-user alerting (optional) \n \
    \t # -v : validate only (optional) \n \
    \t # -a : only aggregate db messages (optional) \n \
    "
def validate_program_args(env_fabric, start_date, end_date):
    error = False
    # Don't allow a future date
    now_utc = datetime.utcnow().date()
    if end_date > now_utc:
        logger.info("An end date=%s in the future (greater than now=%s) is not allowed!" % (end_date, now_utc))
        error = True
    # Start date can't succeed end date
    if start_date > end_date:
        logger.info("Start date=%s cannot be greater than end date=%s" % (start_date, end_date))
        error = True
    # Fabric is required
    if not env_fabric:
        logger.info("env_fabric is required!")
        error = True
    else:
        if env_fabric not in (ENV_FABRIC_TAD_VALUE, ENV_FABRIC_PROD_VALUE, ENV_FABRIC_STAGE_VALUE):
            logger.info("env_fabric=%s does not match %s, %s, or %s!" % (env_fabric,
                                                                         ENV_FABRIC_TAD_VALUE,
                                                                         ENV_FABRIC_PROD_VALUE,
                                                                         ENV_FABRIC_STAGE_VALUE))
            error = True
    # Exit if there was an error
    if error:
        sys.exit(2)
# Define the main entry point
def main(argv):
    '''
    The program's arguments are described in usage().
    * If the optional ingest flag is specified, clear the parquet and copy over collector files.
    * Always: delete targeted records from the DB, clear aggregation in airflow, and rerun in airflow
    '''
    # Set some default values for optional arguments
    env_fabric = None
    build_models = None
    validate_only = None
    disable_end_user_alerting = False
    msg_agg_only = False
    try:
        opts, remaining_args = getopt.getopt(argv, "aqvms:e:f:")
    except getopt.GetoptError:
        usage()
        sys.exit(2)
    for o, a in opts:
        if o == "-s":
            start_date = datetime.strptime(a, DATE_FORMAT).date()
            start_ds = a
        elif o == "-e":
            end_date = datetime.strptime(a, DATE_FORMAT).date()
            end_ds = a
        elif o == "-f":
            env_fabric = a.lower()  # tad, stage, or prod
        elif o == "-m":
            build_models = True
        elif o == "-q":
            disable_end_user_alerting = True
        elif o == "-v":
            validate_only = True
        elif o == "-a":
            msg_agg_only = True
        else:
            assert False, "unhandled option"
    # Validate params
    validate_program_args(env_fabric, start_date, end_date)
    start_date_string = start_date.strftime(DATE_FORMAT)
    # Ensure that only one instance of this program is running
    do_run = False
    try:
        pid = str(os.getpid())
        do_run = can_run(pid, PID_FILE)
        if not do_run:
            # Exit, but don't clean up the pidfile
            exit(2)
        if import_ep_autopause_reload:
            pause_val = _call_subprocess_simple('airflow pause ep_telemetry_v2', tries=1)
            if pause_val != 0:
                print "Unable to pause telemetry dag - Exiting."
                exit(2)
        if spinup_separate_cluster_for_reload:
            run_spinup_new_cluster_name = 'spinup_new_cluster'
            command = "airflow test ep_reload_data {} {}".format(run_spinup_new_cluster_name, start_date_string)
            _call_subprocess_simple(command)
        if validate_only:
            success = True
        else:
            if build_models:
                if not run_build_models(env_fabric, start_date):
                    logger.info("Model Building Failed!!")
                    exit(2)
            # Load Data
            # Wait for the SQS queue to drain from a previous run
            run_wait_for_empty_queue_name = 'wait_for_empty_queue_reload'
            command = "airflow test ep_reload_data {} {}".format(run_wait_for_empty_queue_name, start_date_string)
            _call_subprocess_simple(command)
            if msg_agg_only:
                success = run_db_message_aggregation(start_date, end_date)
                any_dlq_messages = False
                all_data_in_db = True
            else:
                purge_DLQ(env_fabric, start_date)
                delete_db_data(env_fabric, start_date, end_date, ORG_IDS)
                # Call aggregation; determine whether the DLQ has messages and whether all dates have data in the db
                any_dlq_messages, all_data_in_db = run_aggregation(env_fabric, start_date, end_date, disable_end_user_alerting)
                success = not any_dlq_messages and all_data_in_db
        if success:
            # Only show optional args in the success email
            email_opts = [opt for opt in opts if opt[0] not in ('-s', '-e', '-f')]
            generate_successful_email_util(start_ds, end_ds, reload_data=True,
                                           reload_opts=email_opts,
                                           all_data_in_db=True,
                                           any_dlq_messages=False)
            print "\n\n------------------------------------------"
            print " SUCCESS "
            print "------------------------------------------"
            exit(0)
        else:
            # Only show optional args in the failure email
            email_opts = [opt for opt in opts if opt[0] not in ('-s', '-e', '-f')]
            generate_successful_email_util(start_ds, end_ds, reload_data=True,
                                           reload_opts=email_opts,
                                           all_data_in_db=all_data_in_db,
                                           any_dlq_messages=any_dlq_messages)
            print "\n\n------------------------------------------"
            print " FAILURE "
            print "------------------------------------------"
            exit(-1)
    finally:
        if do_run:
            if spinup_separate_cluster_for_reload:
                run_terminate_cluster_name = 'terminate_cluster'
                command = "airflow test ep_reload_data {} {}".format(run_terminate_cluster_name, start_date_string)
                _call_subprocess_simple(command)
            if import_ep_autopause_reload:
                unpause_val = _call_subprocess_simple('airflow unpause ep_telemetry_v2', tries=1)
                if unpause_val != 0:
                    print 'Unpause telemetry dag failed; re-enable it through the Airflow UI'
            print 'Cleaning up pidfile for pid=%s. Goodbye, world!' % (pid)
            os.unlink(PID_FILE)

if __name__ == "__main__":
    main(sys.argv[1:])