Created
March 16, 2020 12:09
-
-
Save EdwardJRoss/66561eb91049d9838db71403bd07c950 to your computer and use it in GitHub Desktop.
Showing different ways to query data from Athena
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Exploring different ways to fetch data from Athena""" | |
from decimal import Decimal | |
import codecs | |
import datetime | |
import random | |
import string | |
from urllib.parse import urlparse | |
import pytz | |
from boto3.session import Session | |
from pyathena import connect | |
import fastavro | |
################################################################################ | |
## S3 Functions | |
################################################################################ | |
def get_bucket_key(s3_path):
    """Split an ``s3://bucket/key`` path and return ``(bucket, key)``.

    Raises ValueError if the URL scheme is anything other than ``s3``.
    """
    parsed = urlparse(s3_path)
    if parsed.scheme != 's3':
        raise ValueError(f'Unexpected scheme {parsed.scheme} in {s3_path}; expected s3')
    bucket = parsed.netloc
    key = parsed.path.lstrip('/')
    return bucket, key
def s3_stream(s3_location):
    """Return a streaming body for the S3 object at *s3_location*.

    *s3_location* is an ``s3://bucket/key`` URL; the returned object is the
    boto3 ``StreamingBody`` for that key.
    """
    # TODO: Pass Session/client arguments?
    bucket, key = get_bucket_key(s3_location)
    client = Session().client('s3')
    response = client.get_object(Bucket=bucket, Key=key)
    # TODO: Chunk size?
    return response["Body"]
# Assumes have a "sandbox" schema that can be written to | |
def _temp_table_name(name_length=10, name_prefix="sandbox.temp_"): | |
"""Generate a random string of fixed length """ | |
letters = string.ascii_lowercase | |
return name_prefix + "".join(random.choices(letters, k=name_length)) | |
################################################################################ | |
## Athena CSV Parser | |
################################################################################ | |
def _athena_iso8601_date(s): | |
return datetime.datetime.strptime(s, "%Y-%m-%d").date() | |
def _athena_iso8601_datetime(s): | |
return datetime.datetime.strptime(s, "%Y-%m-%d %H:%M:%S.%f") | |
def _athena_binary(s): | |
return codecs.decode(s.replace(" ", ""), "hex") | |
def _athena_decimal(s): | |
return Decimal(s) | |
_TYPE_MAPPINGS = { | |
"boolean": bool, | |
"real": float, | |
"float": float, | |
"double": float, | |
"tinyint": int, | |
"smallint": int, | |
"integer": int, | |
"bigint": int, | |
"decimal": _athena_decimal, | |
"char": str, | |
"varchar": str, | |
"array": str, # Complex types to str | |
"row": str, # Complex types to str | |
"varbinary": _athena_binary, | |
"map": str, # Complex types to str | |
"date": _athena_iso8601_date, | |
"timestamp": _athena_iso8601_datetime, | |
"unknown": str, | |
} | |
def parse_athena_csv(lines, types):
    """Parse a CSV output by Athena with types from metadata.

    The CSV query results from Athena are fully quoted, except for nulls which
    are unquoted. Neither Python's inbuilt CSV reader or Pandas can distinguish
    the two cases so we roll our own CSV reader.

    Yields one dict per data row, mapping header name to converted value.
    """
    mappers = [_TYPE_MAPPINGS[dtype] for dtype in types]
    rows = _athena_parse_csv(lines)
    try:
        header = next(rows)
    except StopIteration:
        raise ValueError("Can't parse header line in CSV")
    if len(header) != len(types):
        raise ValueError(f"Have header {len(header)} fields, but {len(types)} types")
    for row in rows:
        if len(row) != len(header):
            raise ValueError(f"Got {len(row)} fields, expected {len(header)}")
        # Nulls pass through untouched; everything else goes through its mapper
        converted = [
            None if field is None else convert(field)
            for convert, field in zip(mappers, row)
        ]
        yield dict(zip(header, converted))
_QUOTE = '"' | |
_ENDLINE = "\n" | |
_SEP = "," | |
def _athena_parse_csv(lines): | |
"""Parse a CSV output by Athena | |
Returns nulls for unquoted fields | |
""" | |
in_quote = False | |
paired_quote = False | |
chomp = False | |
history = "" | |
last_idx = 0 | |
ans = [] | |
for line in lines: | |
for idx, char in enumerate(line): | |
if not in_quote: | |
assert char in ( | |
_SEP, | |
_ENDLINE, | |
_QUOTE, | |
), f"Unexpected character {char} outside of field" | |
# Read in nulls; unquoted fields | |
if char in (_SEP, _ENDLINE) and not chomp: | |
ans.append(None) | |
if char == _QUOTE: | |
in_quote = True | |
last_idx = idx + 1 | |
elif char == _ENDLINE: | |
yield ans | |
ans = [] | |
chomp = False | |
else: | |
if char == _QUOTE: | |
if paired_quote: | |
paired_quote = False | |
elif line[idx + 1] == _QUOTE: | |
paired_quote = True | |
else: | |
data = history + line[last_idx:idx] | |
data = data.replace('""', '"') | |
ans.append(data) | |
history = "" | |
in_quote = False | |
chomp = True | |
elif char == _ENDLINE: | |
history += line[last_idx : idx + 1] | |
last_idx = 0 | |
assert not (history or ans), f"Leftover data: {history}, {ans}" | |
################################################################################ | |
## Athena Helper Functions | |
################################################################################ | |
def create_table_as(cursor, table, sql, parameters, with_data=True, **properties):
    """Create table as the result of sql using parameters.

    Properties are passed into the CTAS query as per the Athena Documentation.
    https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html
    If with_data is False a new empty table with the same schema is created.
    For example create_table_as(cursor, "sandbox.test", "select 1", None, format='parquet')

    NOTE(review): *table* and property values are interpolated directly into
    the statement text, so they must come from trusted sources.
    """
    # BUGFIX: Iterable was referenced below but never imported, so any
    # list/tuple property value raised NameError instead of rendering as
    # ARRAY[...]. Import locally to keep this block self-contained.
    from collections.abc import Iterable

    def _format(x):
        # Render a property value as an Athena SQL literal:
        # str -> quoted, iterable -> ARRAY[...], anything else -> str(x).
        if isinstance(x, str):
            return f"'{x}'"
        if isinstance(x, Iterable):
            return f"ARRAY[{', '.join([_format(xi) for xi in x])}]"
        return str(x)

    if properties:
        property_str = (
            "WITH ( "
            + ",".join(f"{k} = {_format(v)}" for k, v in properties.items())
            + " )"
        )
    else:
        property_str = ""
    ctas = f"""CREATE TABLE {table} {property_str} AS ( {sql} ) WITH {'' if with_data else 'NO'} DATA"""
    cursor.execute(ctas, parameters)
def drop_table(cursor, table: str) -> None:
    """Drop *table* if it exists; no error when the table is absent."""
    statement = f"DROP TABLE IF EXISTS {table}"
    cursor.execute(statement)
    # Could delete S3 data here
def location_table(cursor, table: str):
    """Return the distinct external S3 paths holding the data of *table*.

    Uses the Athena/Presto hidden "$path" column.
    """
    cursor.execute(f'select distinct "$path" from {table}')
    return [row[0] for row in cursor.fetchall()]
################################################################################ | |
## Query Athena | |
################################################################################ | |
def query_direct(cursor, sql, parameters=None):
    """Execute *sql* on *cursor* and yield each row as a dict keyed by
    the column names from the cursor description (fetched row by row)."""
    cursor.execute(sql, parameters)
    columns = [desc[0] for desc in cursor.description]
    # fetchone returns None when exhausted, which stops the iterator
    for row in iter(cursor.fetchone, None):
        yield dict(zip(columns, row))
def query(cursor, sql, parameters=None):
    """Execute *sql* on *cursor*, then stream and parse the CSV result
    object directly from S3 instead of fetching through the cursor."""
    cursor.execute(sql, parameters)
    # Column types come from the query metadata exposed on the cursor
    # (backed by the .metadata file on S3)
    types = [desc[1] for desc in cursor.description]
    # Decode the raw S3 byte stream as UTF-8 text, lazily
    utf8_reader = codecs.getreader("utf-8")
    text_stream = utf8_reader(s3_stream(cursor.output_location))
    return parse_athena_csv(text_stream, types)
def query_avro(cursor, sql, parameters=None):
    """Execute *sql* via a temporary AVRO CTAS table and yield its rows.

    Query must have all columns named and valid types as for a CTAS statement.
    The temporary table is dropped afterwards.
    """
    table = _temp_table_name()
    try:
        create_table_as(cursor, table, sql, parameters, format="AVRO")
        for s3_location in location_table(cursor, table):
            # Note: We lose the schema here
            yield from fastavro.reader(s3_stream(s3_location))
    finally:
        drop_table(cursor, table)
def test():
    """Integration checks: run live Athena queries and verify the three
    query strategies (direct fetch, CSV streaming, Avro CTAS) agree.

    NOTE(review): requires working AWS credentials for pyathena's connect()
    and, for query_avro, a writable "sandbox" schema — confirm before running.
    """
    cursor = connect().cursor()
    print("Test parameter conversion")
    # One value of each Python type, round-tripped through query parameters
    params = {
        "bool": True,
        "int": 4300000000,
        "float": 1e0,
        "date": datetime.date(2018, 1, 1),
        "datetime": datetime.datetime(2008, 9, 15, 3, 4, 5, 324000),
        "str": "a🔥",
        "null": None,
    }
    # Builds e.g. select %(bool)s as "bool", %(int)s as "int", ...
    sql = "select " + ", ".join(f'%({p})s as "{p}"' for p in params)
    q = list(query(cursor, sql, params))
    assert q[0] == params
    q = list(query_direct(cursor, sql, params))
    assert q[0] == params
    # sql_named = sql.replace('%(null)s', 'cast(%(null)s as varchar)')
    # q = list(query_avro(cursor, sql_named, params))
    # # Fails because of https://github.com/laughingman7743/PyAthena/issues/126
    # assert q[0] == params
    print("Test basic types")
    # One literal of each Athena type, checked against the expected Python value
    sql = """select TRUE as "bool",
    4300000000 as "int",
    1e0 as "float",
    DECIMAL '0.1' as "decimal",
    to_utf8('ab') as "binary",
    'a🔥' as "string",
    DATE '2018-01-01' as "date",
    TIMESTAMP '2008-09-15 03:04:05.324' as "timestamp",
    cast(NULL as varchar) as "null"
    """
    result = {
        'bool': True,
        'int': 4300000000,
        'float': 1e0,
        'decimal': Decimal("0.1"),
        'binary': b"ab",
        'string': "a🔥",
        'date': datetime.date(2018, 1, 1),
        'timestamp': datetime.datetime(2008, 9, 15, 3, 4, 5, 324000),
        'null': None,
    }
    q = list(query_direct(cursor, sql))
    assert q[0] == result
    q = list(query(cursor, sql))
    assert q[0] == result
    q = list(query_avro(cursor, sql))
    # Avro includes timezone information, strip it away for test
    q[0]['timestamp'] = q[0]['timestamp'].replace(tzinfo=None)
    assert q[0] == result
    print("Test newlines and nulls")
    # Tricky strings for the hand-rolled CSV parser: embedded newlines,
    # separators, quotes, backslashes, NULL-lookalike strings, and a real NULL
    sql = """SELECT '\n', '\n\n', 'a\nb', '\n\n\n',
    ',', '\t', '"', '\\', '\\n',
    '', ' ', 'N/A', 'NULL', '''', NULL"""
    result = ("\n", "\n\n", "a\nb", "\n\n\n", ",", "\t", '"', "\\", "\\n",
              "", " ", "N/A", "NULL", "'", None,)
    q = list(query_direct(cursor, sql))
    assert tuple(q[0].values()) == result
    q = list(query(cursor, sql))
    assert tuple(q[0].values()) == result
    print("Test compound types")
    # CSV paths stringify complex types; Avro keeps the structure
    sql = "select ARRAY[1, 2, 3] as array, CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)) as row, MAP(ARRAY[cast('foo' as varchar), cast('bar' as varchar)], ARRAY[1, 2]) as map"
    result_str = {
        'array': '[1, 2, 3]',
        'row': '{x=1, y=2.0}',
        'map': '{bar=2, foo=1}',
    }
    result = {
        'array': [1, 2, 3],
        'row': {'x': 1, 'y': 2.0},
        'map': {'bar': 2, 'foo': 1},
    }
    q = list(query_direct(cursor, sql))
    assert q[0] == result_str
    q = list(query(cursor, sql))
    assert q[0] == result_str
    q = list(query_avro(cursor, sql))
    assert q[0] == result
# Run the live integration checks when executed as a script
if __name__ == '__main__':
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment