@EdwardJRoss
Created March 16, 2020 12:09
Showing different ways to query data from Athena
#!/usr/bin/env python
"""Exploring different ways to fetch data from Athena"""
import codecs
import datetime
import random
import string
from collections.abc import Iterable
from decimal import Decimal
from urllib.parse import urlparse

import fastavro
from boto3.session import Session
from pyathena import connect
################################################################################
## S3 Functions
################################################################################
def get_bucket_key(s3_path):
    """Returns the bucket name and key from an s3_path of the form s3://bucket/key"""
    url = urlparse(s3_path)
    if url.scheme != 's3':
        raise ValueError(f'Unexpected scheme {url.scheme} in {s3_path}; expected s3')
    return url.netloc, url.path.lstrip('/')
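# For example (hypothetical path):
#     get_bucket_key("s3://my-bucket/results/output.csv")
#     -> ("my-bucket", "results/output.csv")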
def s3_stream(s3_location):
    """Returns a stream of data from s3_location"""
    # TODO: Pass Session/client arguments?
    s3_client = Session().client('s3')
    bucket_name, bucket_key = get_bucket_key(s3_location)
    obj = s3_client.get_object(Bucket=bucket_name, Key=bucket_key)
    # TODO: Chunk size?
    return obj["Body"]
# Assumes we have a "sandbox" schema that can be written to
def _temp_table_name(name_length=10, name_prefix="sandbox.temp_"):
    """Generate a random temporary table name with a suffix of fixed length"""
    letters = string.ascii_lowercase
    return name_prefix + "".join(random.choices(letters, k=name_length))
################################################################################
## Athena CSV Parser
################################################################################
def _athena_iso8601_date(s):
    return datetime.datetime.strptime(s, "%Y-%m-%d").date()


def _athena_iso8601_datetime(s):
    return datetime.datetime.strptime(s, "%Y-%m-%d %H:%M:%S.%f")


def _athena_binary(s):
    return codecs.decode(s.replace(" ", ""), "hex")


def _athena_decimal(s):
    return Decimal(s)
_TYPE_MAPPINGS = {
    "boolean": bool,
    "real": float,
    "float": float,
    "double": float,
    "tinyint": int,
    "smallint": int,
    "integer": int,
    "bigint": int,
    "decimal": _athena_decimal,
    "char": str,
    "varchar": str,
    "array": str,  # Complex types to str
    "row": str,  # Complex types to str
    "varbinary": _athena_binary,
    "map": str,  # Complex types to str
    "date": _athena_iso8601_date,
    "timestamp": _athena_iso8601_datetime,
    "unknown": str,
}
def parse_athena_csv(lines, types):
    """Parse a CSV output by Athena with types from metadata.

    The CSV query results from Athena are fully quoted, except for nulls,
    which are unquoted. Neither Python's built-in CSV reader nor Pandas can
    distinguish the two cases, so we roll our own CSV reader.
    """
    rows = _athena_parse_csv(lines)
    type_mappers = [_TYPE_MAPPINGS[dtype] for dtype in types]
    try:
        header = next(rows)
    except StopIteration:
        raise ValueError("Can't parse header line in CSV")
    if len(types) != len(header):
        raise ValueError(f"Header has {len(header)} fields, but got {len(types)} types")
    for row in rows:
        if len(header) != len(row):
            raise ValueError(f"Got {len(row)} fields, expected {len(header)}")
        values = [
            mapper(v) if v is not None else v for mapper, v in zip(type_mappers, row)
        ]
        yield dict(zip(header, values))
_QUOTE = '"'
_ENDLINE = "\n"
_SEP = ","
def _athena_parse_csv(lines):
    """Parse a CSV output by Athena

    Returns nulls for unquoted fields
    """
    in_quote = False
    paired_quote = False
    chomp = False
    history = ""
    last_idx = 0
    ans = []
    for line in lines:
        for idx, char in enumerate(line):
            if not in_quote:
                assert char in (
                    _SEP,
                    _ENDLINE,
                    _QUOTE,
                ), f"Unexpected character {char} outside of field"
                # Read in nulls; unquoted fields
                if char in (_SEP, _ENDLINE) and not chomp:
                    ans.append(None)
                if char == _QUOTE:
                    in_quote = True
                    last_idx = idx + 1
                elif char == _ENDLINE:
                    yield ans
                    ans = []
                chomp = False
            else:
                if char == _QUOTE:
                    if paired_quote:
                        paired_quote = False
                    elif line[idx + 1] == _QUOTE:
                        paired_quote = True
                    else:
                        data = history + line[last_idx:idx]
                        data = data.replace('""', '"')
                        ans.append(data)
                        history = ""
                        in_quote = False
                        chomp = True
                elif char == _ENDLINE:
                    history += line[last_idx : idx + 1]
                    last_idx = 0
    assert not (history or ans), f"Leftover data: {history}, {ans}"
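# For example, nulls are the only unquoted fields in Athena's CSV output,
# so (with illustrative values) the line '"a",,""\n' parses as:
#     list(_athena_parse_csv(['"a",,""\n']))
#     -> [["a", None, ""]]
# distinguishing the null middle field from the quoted empty string.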
################################################################################
## Athena Helper Functions
################################################################################
def create_table_as(cursor, table, sql, parameters, with_data=True, **properties):
    """Create table as the result of sql using parameters.

    Properties are passed into the CTAS query as per the Athena documentation:
    https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html

    If with_data is False a new empty table with the same schema is created.

    For example: create_table_as(cursor, "sandbox.test", "select 1", None, format='parquet')
    """

    def _format(x):
        if isinstance(x, str):
            return f"'{x}'"
        if isinstance(x, Iterable):
            return f"ARRAY[{', '.join([_format(xi) for xi in x])}]"
        return str(x)

    if properties:
        property_str = (
            "WITH ( "
            + ",".join(f"{k} = {_format(v)}" for k, v in properties.items())
            + " )"
        )
    else:
        property_str = ""
    ctas = f"""CREATE TABLE {table} {property_str} AS ( {sql} ) WITH {'' if with_data else 'NO'} DATA"""
    cursor.execute(ctas, parameters)
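# For example (hypothetical table name), the docstring call above would
# execute roughly:
#     CREATE TABLE sandbox.test WITH ( format = 'parquet' ) AS ( select 1 ) WITH  DATA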
def drop_table(cursor, table: str) -> None:
    """Drop table if it exists"""
    cursor.execute(f"DROP TABLE IF EXISTS {table}")
    # Could delete the S3 data here


def location_table(cursor, table: str):
    """Returns the external S3 paths of the data in table"""
    cursor.execute(f'select distinct "$path" from {table}')
    rows = cursor.fetchall()
    paths = [row[0] for row in rows]
    return paths
################################################################################
## Query Athena
################################################################################
def query_direct(cursor, sql, parameters=None):
    """Execute query using cursor and parameters, fetching rows directly"""
    cursor.execute(sql, parameters)
    columns = [desc[0] for desc in cursor.description]
    while True:
        row = cursor.fetchone()
        if row is None:
            break
        yield dict(zip(columns, row))
def query(cursor, sql, parameters=None):
    """Execute query using cursor and parameters, parsing the CSV results from S3"""
    cursor.execute(sql, parameters)
    raw_stream = s3_stream(cursor.output_location)
    # The column types come from the .metadata file on S3
    types = [x[1] for x in cursor.description]
    # Stream the data
    data_stream = codecs.getreader("utf-8")(raw_stream)
    return parse_athena_csv(data_stream, types)
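# Usage sketch (assumes AWS credentials and a pyathena staging directory
# are configured; the table name is hypothetical):
#     cursor = connect().cursor()
#     for record in query(cursor, "select * from sandbox.example"):
#         print(record)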
def query_avro(cursor, sql, parameters=None):
    """Execute query using cursor and parameters via Avro.

    Query must have all columns named and valid types, as for a CTAS statement."""
    table = _temp_table_name()
    try:
        create_table_as(cursor, table, sql, parameters, format="AVRO")
        for s3_location in location_table(cursor, table):
            stream = s3_stream(s3_location)
            # Note: We lose the schema here
            for row in fastavro.reader(stream):
                yield row
    finally:
        drop_table(cursor, table)
def test():
    cursor = connect().cursor()

    print("Test parameter conversion")
    params = {
        "bool": True,
        "int": 4300000000,
        "float": 1e0,
        "date": datetime.date(2018, 1, 1),
        "datetime": datetime.datetime(2008, 9, 15, 3, 4, 5, 324000),
        "str": "a🔥",
        "null": None,
    }
    sql = "select " + ", ".join(f'%({p})s as "{p}"' for p in params)
    q = list(query(cursor, sql, params))
    assert q[0] == params
    q = list(query_direct(cursor, sql, params))
    assert q[0] == params
    # sql_named = sql.replace('%(null)s', 'cast(%(null)s as varchar)')
    # q = list(query_avro(cursor, sql_named, params))
    # # Fails because of https://github.com/laughingman7743/PyAthena/issues/126
    # assert q[0] == params

    print("Test basic types")
    sql = """select TRUE as "bool",
                    4300000000 as "int",
                    1e0 as "float",
                    DECIMAL '0.1' as "decimal",
                    to_utf8('ab') as "binary",
                    'a🔥' as "string",
                    DATE '2018-01-01' as "date",
                    TIMESTAMP '2008-09-15 03:04:05.324' as "timestamp",
                    cast(NULL as varchar) as "null"
    """
    result = {
        'bool': True,
        'int': 4300000000,
        'float': 1e0,
        'decimal': Decimal("0.1"),
        'binary': b"ab",
        'string': "a🔥",
        'date': datetime.date(2018, 1, 1),
        'timestamp': datetime.datetime(2008, 9, 15, 3, 4, 5, 324000),
        'null': None,
    }
    q = list(query_direct(cursor, sql))
    assert q[0] == result
    q = list(query(cursor, sql))
    assert q[0] == result
    q = list(query_avro(cursor, sql))
    # Avro includes timezone information; strip it away for the test
    q[0]['timestamp'] = q[0]['timestamp'].replace(tzinfo=None)
    assert q[0] == result

    print("Test newlines and nulls")
    sql = """SELECT '\n', '\n\n', 'a\nb', '\n\n\n',
                    ',', '\t', '"', '\\', '\\n',
                    '', ' ', 'N/A', 'NULL', '''', NULL"""
    result = ("\n", "\n\n", "a\nb", "\n\n\n", ",", "\t", '"', "\\", "\\n",
              "", " ", "N/A", "NULL", "'", None,)
    q = list(query_direct(cursor, sql))
    assert tuple(q[0].values()) == result
    q = list(query(cursor, sql))
    assert tuple(q[0].values()) == result

    print("Test compound types")
    sql = """select ARRAY[1, 2, 3] as array,
                    CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)) as row,
                    MAP(ARRAY[cast('foo' as varchar), cast('bar' as varchar)],
                        ARRAY[1, 2]) as map"""
    result_str = {
        'array': '[1, 2, 3]',
        'row': '{x=1, y=2.0}',
        'map': '{bar=2, foo=1}',
    }
    result = {
        'array': [1, 2, 3],
        'row': {'x': 1, 'y': 2.0},
        'map': {'bar': 2, 'foo': 1},
    }
    q = list(query_direct(cursor, sql))
    assert q[0] == result_str
    q = list(query(cursor, sql))
    assert q[0] == result_str
    q = list(query_avro(cursor, sql))
    assert q[0] == result


if __name__ == '__main__':
    test()