from airflow.exceptions import AirflowException
from airflow.contrib.hooks.aws_hook import AwsHook
from io import StringIO
import time
import logging

TYPES_MAP = {
    "integer": "INT",
    "boolean": "BOOLEAN",
    "bigint": "BIGINT",
    "smallint": "INT",
    "numeric": "DOUBLE"
}

ATHENA_KEYWORDS = ['function', 'location']


class AthenaQueryError(Exception):
    pass

class AthenaHook(AwsHook):
    """
    Interact with AWS Athena, using the boto3 library.

    The main function of this hook is running queries on Athena; it pages
    through the results and returns them to a provided function batch by batch.
    """

    def get_conn(self):
        return self.get_client_type('athena')
    def execute_query(self, query, query_loc, database, result_batch_lambda=None):
        """Wraps the process of running an Athena query into one method.

        1) kicks off query execution
        2) polls to see when query has completed
        3) extracts query result rows a batch at a time from the API
        4) passes each batch to the `result_batch_lambda` function, which should be prepared to process
           multiple sets of rows for a single query
        """
        client = self.get_conn()
        qid = client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={'Database': database},
            ResultConfiguration={'OutputLocation': query_loc}
        )['QueryExecutionId']

        query_running = True
        while query_running:
            q_resp = client.get_query_execution(QueryExecutionId=qid)
            q_stat = q_resp['QueryExecution']['Status']
            if q_stat['State'] in ['QUEUED', 'RUNNING']:
                time.sleep(10)
            else:
                query_running = False

        if q_stat['State'] != 'SUCCEEDED':
            logging.error("QUERY FAILED!")
            logging.error(q_resp)
            raise ValueError("Athena Query Failed")

        response = client.get_query_results(QueryExecutionId=qid)
        if result_batch_lambda:
            while 'NextToken' in response:
                result_batch_lambda(response['ResultSet']['Rows'])
                response = client.get_query_results(QueryExecutionId=qid, NextToken=response['NextToken'])
            result_batch_lambda(response['ResultSet']['Rows'])
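    # For reference: `result_batch_lambda` receives the raw 'Rows' list from the
    # Athena GetQueryResults response, shaped roughly like
    #   [{'Data': [{'VarCharValue': '1'}, {'VarCharValue': 'foo'}]}, ...]
    # and it may be called several times for a single query, once per page of results.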
    def construct_athena_drop_query(self, table_name):
        return "DROP TABLE IF EXISTS %s;" % (table_name,)

    def to_athena_column(self, schema_row):
        # Strip whitespace so schema rows written as "id, integer, NO" map correctly.
        row_parts = [part.strip() for part in schema_row.split(",")]
        name = row_parts[0]
        output_type = "STRING"
        datatype = row_parts[1]
        if datatype in TYPES_MAP:
            output_type = TYPES_MAP[datatype]
        if name in ATHENA_KEYWORDS:
            name = "%s_column" % (name,)
        return "%s %s" % (name, output_type)
    def construct_athena_create_query(self, table_name, schema_body, location_string, partition_by=None):
        query_string = StringIO()
        query_string.write("CREATE EXTERNAL TABLE IF NOT EXISTS %s ( \n" % (table_name,))
        schema_rows = schema_body.split("\n")
        columns = [self.to_athena_column(sr) for sr in schema_rows if sr != ""]
        query_string.write(",\n ".join(columns))
        query_string.write("\n)\n")
        if partition_by is not None:
            query_string.write("PARTITIONED BY (\n")
            query_string.write("%s string\n" % (partition_by,))
            query_string.write(")\n")
        query_string.write("STORED AS PARQUET\n")
        query_string.write("LOCATION '%s'\n" % (location_string,))
        query_string.write("TBLPROPERTIES (\"parquet.compress\"=\"SNAPPY\");")
        return query_string.getvalue()
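    # For illustration, with a hypothetical bucket and schema,
    #   construct_athena_create_query("events", "id, integer, NO\nname, varchar, YES\n",
    #                                 "s3://example-bucket/events/")
    # returns DDL along these lines:
    #   CREATE EXTERNAL TABLE IF NOT EXISTS events (
    #   id INT,
    #    name STRING
    #   )
    #   STORED AS PARQUET
    #   LOCATION 's3://example-bucket/events/'
    #   TBLPROPERTIES ("parquet.compress"="SNAPPY");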
    def execute_athena_ddl(self, athena_database, s3_query_bucket, query):
        athena_client = self.get_conn()
        response = athena_client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={'Database': athena_database},
            ResultConfiguration={
                'OutputLocation': ("s3://%s/ATHENA_QUERY_RESULTS/" % (s3_query_bucket,))
            }
        )
        q_id = response['QueryExecutionId']
        q_finished = False
        while not q_finished:
            query_response = athena_client.get_query_execution(QueryExecutionId=q_id)['QueryExecution']
            query_status = query_response['Status']['State']
            if query_status not in ['QUEUED', 'RUNNING']:
                q_finished = True
                if query_status == 'FAILED':
                    logging.error("FAILED TO EXECUTE DDL: %s" % (query_response['Status']['StateChangeReason'],))
                    raise AthenaQueryError(query_response['Status']['StateChangeReason'])
            else:
                time.sleep(3)
        return True
    def create_table(
            self,
            table_name,
            schema_body,
            location_string,
            athena_database,
            s3_query_bucket,
            partition_by=None):
        """Wraps the process of creating an Athena table over an S3 location (assumes Parquet formatting).

        1) formulates the DROP and CREATE queries
        2) executes the DROP query (and waits for success)
        3) executes the CREATE query (and waits for success)

        This assumes that the AWS connection described by this hook can access both
        the S3 bucket where the data lives and the Athena permissions necessary
        to modify the Hive metastore.

        `schema_body` is expected to be comma-separated tabular text (think CSV) describing
        the Postgres-like schema of the data; the translation to Athena fields
        assumes Postgres data types in the schema definition. The columns are
        "column name", "data type", and "nullable". For example:

        id, integer, NO
        name, varchar, YES

        `partition_by` is an optional parameter for when your Parquet
        files are already partitioned in the prescribed manner:
        https://docs.aws.amazon.com/athena/latest/ug/partitions.html#scenario-1-data-already-partitioned-and-stored-on-s3-in-hive-format
        """
        drop_q = self.construct_athena_drop_query(table_name)
        create_q = self.construct_athena_create_query(
            table_name, schema_body, location_string, partition_by=partition_by)
        self.execute_athena_ddl(athena_database, s3_query_bucket, drop_q)
        self.execute_athena_ddl(athena_database, s3_query_bucket, create_q)
        if partition_by is not None:
            repair_q = "MSCK REPAIR TABLE %s;" % (table_name,)
            self.execute_athena_ddl(athena_database, s3_query_bucket, repair_q)
        logging.info("SUCCESSFULLY CREATED ATHENA TABLE: %s" % (table_name,))