Transfer a large CSV file from S3 to RDS Serverless through a Lambda function. The handler streams the object in byte ranges, batch-inserts rows with the RDS Data API, and re-invokes itself with a new offset whenever the invocation is about to time out.
import csv
import json
import os

import boto3
import botocore.response
from pprint import pprint

# Stop processing when fewer than this many milliseconds remain in the
# invocation; the rest of the file is handed off to a fresh invocation.
# Override with the MINIMUM_REMAINING_TIME_MS environment variable.
MINIMUM_REMAINING_TIME_MS = int(os.getenv('MINIMUM_REMAINING_TIME_MS') or 10000)
def type_mapper(mysql_type_str, value_to_map) -> tuple:
    """Map a MySQL column type to the corresponding RDS Data API value type."""
    t = mysql_type_str.upper()
    if 'DOUBLE' in t:
        try:
            return "doubleValue", float(value_to_map)
        except (TypeError, ValueError):
            return "isNull", True
    if 'VARCHAR' in t:
        return "stringValue", str(value_to_map)
    if "INT" in t:
        try:
            return "longValue", int(value_to_map)
        except (TypeError, ValueError):
            return "isNull", True
    raise Exception(f"Unknown mysql type: {mysql_type_str}")
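# Example mappings (hypothetical values):
#   type_mapper('double', '3.14')   -> ("doubleValue", 3.14)
#   type_mapper('varchar(255)', 7)  -> ("stringValue", "7")
#   type_mapper('int(11)', 'n/a')   -> ("isNull", True)   # cast fails, store NULL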
def handler(event, context):
    # get values from event
    bucket_name = event['bucket_name']
    object_key = event['object_key']
    resourceArn = event['resourceArn']
    secretArn = event['secretArn']
    db_name = event['db_name']
    table_name = event['table_name']
    offset = event.get('offset', 0)
    fieldnames = event.get('fieldnames', None)
    fieldtypes = event.get('fieldtypes', None)

    # aws resources
    s3_resource = boto3.resource('s3')
    s3_object = s3_resource.Object(bucket_name=bucket_name, key=object_key)
    rds = boto3.client('rds-data')

    # rds schema
    assert db_name and table_name, "must give db name and table name"
    response = rds.execute_statement(
        secretArn=secretArn,
        database=db_name,
        resourceArn=resourceArn,
        sql=f'DESCRIBE {table_name}'
    )
    assert response['ResponseMetadata']['HTTPStatusCode'] == 200
    assert 'records' in response
    records = response['records']
    schema = [(record[0]['stringValue'], record[1]['stringValue'])
              for record in records]
    # pprint(schema)

    # init csv reader
    bodylines = get_object_bodylines(s3_object, offset)
    to_iter = bodylines.iter_lines()
    if not fieldnames:
        assert offset == 0, "Fieldnames should be set only at offset = 0."
        csv_fieldnames = next(to_iter).split(",")  # consumes the header line
        csv_fieldnames = [name.strip() for name in csv_fieldnames]
        fieldnames = [n for n, t in schema]
        fieldtypes = [t for n, t in schema]

        # quality check
        s1, s2 = set(csv_fieldnames), set(fieldnames)
        if s1 != s2:
            print(f"Warning: file field names do not match existing columns in database:\n\t{s1 ^ s2}")
    csv_reader = csv.DictReader(to_iter, fieldnames=fieldnames)
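    # Example row produced by the DictReader (hypothetical columns):
    #   {'id': '42', 'name': 'foo', 'score': '0.87'}
    # every value arrives as a string and is cast to the matching
    # RDS Data API type by type_mapper in the loop below.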
    def batch_insert(schema, sql_parameters):
        fieldnames = [n for n, t in schema]
        sql = f'''
            INSERT INTO {table_name} ({', '.join(fieldnames)})
            VALUES (:{', :'.join(fieldnames)})
        '''
        # TODO handle failure here with https://stackoverflow.com/questions/58192747/aws-aurora-serverless-communication-link-failure
        response = rds.batch_execute_statement(
            secretArn=secretArn,
            database=db_name,
            resourceArn=resourceArn,
            sql=sql,
            parameterSets=sql_parameters
        )
        return response
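    # For a hypothetical schema [('id', 'int(11)'), ('score', 'double')] the
    # generated statement is:
    #   INSERT INTO analysis (id, score) VALUES (:id, :score)
    # and each entry in sql_parameters supplies the named placeholders, e.g.
    #   [{'name': 'id', 'value': {'longValue': 1}},
    #    {'name': 'score', 'value': {'doubleValue': 0.5}}]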
    i = 0
    sql_parameter_sets = []
    for row in csv_reader:
        ## process and do work
        entry = []
        for column_name, column_type in schema:
            type_value, casted_value = type_mapper(column_type, row[column_name])
            # if type_value != 'stringValue' and casted_value < 0:
            #     print(f"FAILURE @ i={i} with column_name={column_name} column_type={column_type} type_value={type_value}, casted_value={casted_value}")
            entry.append({'name': column_name,
                          'value': {type_value: casted_value}})
        sql_parameter_sets.append(entry)

        # TODO: make it dependent on data size rather than number of rows; if you don't, lambda runs out of memory
        if (i + 1) % 3000 == 0:  # flush a batch every n rows
            print("inserting %d" % len(sql_parameter_sets))
            batch_insert(schema, sql_parameter_sets)
            sql_parameter_sets = []
        elif context.get_remaining_time_in_millis() < MINIMUM_REMAINING_TIME_MS:
            print('Breaking. Timeout soon.')
            print("inserting %d" % len(sql_parameter_sets))
            batch_insert(schema, sql_parameter_sets)
            break
        i += 1
        # if i > 10001:
        #     print("RETURNING ")
        #     return
    else:
        print("reached end of file")
        print("inserting %d" % len(sql_parameter_sets))
        batch_insert(schema, sql_parameter_sets)
        return
    # only reached after the timeout break: hand the remainder of the file
    # to a new invocation that resumes at the updated byte offset
    new_offset = offset + bodylines.offset
    if new_offset < s3_object.content_length:
        new_event = {
            **event,
            "offset": new_offset,
            "fieldnames": fieldnames,
            "fieldtypes": fieldtypes,
        }
        print("Invoking Lambda!")
        print(invoke_lambda(context.function_name, new_event))
    return
def invoke_lambda(function_name, event):
    payload = json.dumps(event).encode('utf-8')
    client = boto3.client('lambda')
    response = client.invoke(
        FunctionName=function_name,
        InvocationType='Event',
        Payload=payload
    )
    return response
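# Note: InvocationType='Event' is an asynchronous invoke; the call returns
# immediately (HTTP 202) and the new invocation resumes at the offset carried
# in the event, so the file is processed as a chain of short runs.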
def get_object_bodylines(s3_object, offset):
    resp = s3_object.get(Range=f'bytes={offset}-')
    body: botocore.response.StreamingBody = resp['Body']
    return BodyLines(body)
class BodyLines:
    '''https://medium.com/swlh/processing-large-s3-files-with-aws-lambda-2c5840ae5c91'''
    def __init__(self, body: botocore.response.StreamingBody, initial_offset=0):
        self.body = body
        self.offset = initial_offset

    def iter_lines(self, chunk_size=1024):
        """Return an iterator to yield lines from the raw stream.

        This is achieved by reading a chunk of bytes (of size chunk_size) at a
        time from the raw stream, and then yielding lines from there.
        """
        pending = b''
        for chunk in self.body.iter_chunks(chunk_size):
            lines = (pending + chunk).splitlines(True)
            for line in lines[:-1]:
                self.offset += len(line)
                yield line.decode('utf-8')
            pending = lines[-1]
        if pending:
            self.offset += len(pending)
            yield pending.decode('utf-8')
{
    "bucket_name": "",
    "object_key": "",
    "offset": 0,
    "fieldnames": "",
    "db_name": "testingDB",
    "table_name": "analysis",
    "secretArn": "",
    "resourceArn": ""
}
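To kick off the transfer, pass an event like the one above to the deployed function with an asynchronous invoke. A minimal sketch, assuming the function is deployed under the placeholder name csv-to-rds-loader and the blank bucket, key, and ARN fields have been filled in with real values:

import json
import boto3

event = {
    "bucket_name": "my-bucket",            # placeholder
    "object_key": "exports/analysis.csv",  # placeholder
    "offset": 0,
    "db_name": "testingDB",
    "table_name": "analysis",
    "secretArn": "<secret ARN>",           # fill in
    "resourceArn": "<cluster ARN>",        # fill in
}

response = boto3.client("lambda").invoke(
    FunctionName="csv-to-rds-loader",      # placeholder function name
    InvocationType="Event",                # async, same as the self re-invocation
    Payload=json.dumps(event).encode("utf-8"),
)
print(response["StatusCode"])              # 202 for an async invoke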