Transfer a large CSV file from S3 to RDS Serverless through a Lambda function.
import csv
import json
import os

import boto3
import botocore.response
from pprint import pprint

# Stop processing and re-invoke when less than this much Lambda time remains.
MINIMUM_REMAINING_TIME_MS = int(os.getenv('MINIMUM_REMAINING_TIME_MS') or 10000)


def type_mapper(mysql_type_str, value_to_map) -> tuple:
    """Map a MySQL column type to the corresponding RDS Data API value type."""
    t = mysql_type_str.upper()
    if 'DOUBLE' in t:
        try:
            return "doubleValue", float(value_to_map)
        except (TypeError, ValueError):
            return "isNull", True
    if 'VARCHAR' in t:
        return "stringValue", str(value_to_map)
    if 'INT' in t:
        try:
            return "longValue", int(value_to_map)
        except (TypeError, ValueError):
            return "isNull", True
    raise Exception(f"Unknown mysql type: {mysql_type_str}")
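
# Example mappings (CSV values arrive as strings; a value that cannot be cast
# is stored as SQL NULL):
#   type_mapper("double", "3.14")       -> ("doubleValue", 3.14)
#   type_mapper("varchar(255)", "abc")  -> ("stringValue", "abc")
#   type_mapper("int(11)", "")          -> ("isNull", True)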


def handler(event, context):
    """Stream the CSV from S3 starting at `offset`, batch-insert rows into
    `table_name` via the RDS Data API, and re-invoke this Lambda with a new
    offset when time runs low."""
    # get values from event
    bucket_name = event['bucket_name']
    object_key = event['object_key']
    resourceArn = event['resourceArn']
    secretArn = event['secretArn']
    db_name = event['db_name']
    table_name = event['table_name']
    offset = event.get('offset', 0)
    fieldnames = event.get('fieldnames', None)
    fieldtypes = event.get('fieldtypes', None)

    # aws resources
    s3_resource = boto3.resource('s3')
    s3_object = s3_resource.Object(bucket_name=bucket_name, key=object_key)
    rds = boto3.client('rds-data')

    # rds schema
    assert db_name and table_name, "must give db name and table name"
    response = rds.execute_statement(
        secretArn=secretArn,
        database=db_name,
        resourceArn=resourceArn,
        sql=f'DESCRIBE {table_name}'
    )
    assert response['ResponseMetadata']['HTTPStatusCode'] == 200
    assert 'records' in response
    records = response['records']
    # each DESCRIBE record starts with (column_name, column_type, ...)
    schema = [(record[0]['stringValue'], record[1]['stringValue'])
              for record in records]
    # pprint(schema)

    # init csv reader
    bodylines = get_object_bodylines(s3_object, offset)
    to_iter = bodylines.iter_lines()
    if not fieldnames:
        assert offset == 0, "Fieldnames should be set only at offset 0."
        csv_fieldnames = next(to_iter).split(",")  # next() consumes the header line, so DictReader starts at the data
        csv_fieldnames = [name.strip() for name in csv_fieldnames]
        fieldnames = [n for n, t in schema]
        fieldtypes = [t for n, t in schema]
        # quality check
        s1, s2 = set(csv_fieldnames), set(fieldnames)
        if s1 != s2:
            print(f"Warning: file field names do not match existing columns in database:\n\t{s1 ^ s2}")
    csv_reader = csv.DictReader(to_iter, fieldnames=fieldnames)
    def batch_insert(schema, sql_parameters):
        fieldnames = [n for n, t in schema]
        sql = f'''
            INSERT INTO {table_name} ({', '.join(fieldnames)})
            VALUES (:{', :'.join(fieldnames)})
        '''
        # TODO handle failure here with https://stackoverflow.com/questions/58192747/aws-aurora-serverless-communication-link-failure
        response = rds.batch_execute_statement(
            secretArn=secretArn,
            database=db_name,
            resourceArn=resourceArn,
            sql=sql,
            parameterSets=sql_parameters
        )
        return response
    i = 0
    sql_parameter_sets = []
    for row in csv_reader:
        # map each CSV value to an RDS Data API parameter for this row
        entry = []
        for column_name, column_type in schema:
            type_value, casted_value = type_mapper(column_type, row[column_name])
            # if type_value != 'stringValue' and casted_value < 0:
            #     print(f"FAILURE @ i={i} with column_name={column_name} column_type={column_type} type_value={type_value}, casted_value={casted_value}")
            entry.append({'name': column_name,
                          'value': {type_value: casted_value}})
        sql_parameter_sets.append(entry)

        # TODO: make it dependent on data size rather than number of rows; if you don't, lambda runs out of memory
        if (i + 1) % 3000 == 0:  # save every n rows
            print("inserting %d" % len(sql_parameter_sets))
            batch_insert(schema, sql_parameter_sets)
            sql_parameter_sets = []
        elif context.get_remaining_time_in_millis() < MINIMUM_REMAINING_TIME_MS:
            print('Breaking. Timeout soon.')
            print("inserting %d" % len(sql_parameter_sets))
            batch_insert(schema, sql_parameter_sets)
            break
        i += 1
        # if i > 10001:
        #     print("RETURNING ")
        #     return
    else:
        # loop finished without break: whole file processed, flush and stop
        print("reached end of file")
        print("inserting %d" % len(sql_parameter_sets))
        batch_insert(schema, sql_parameter_sets)
        return
    # did not reach end of file before the timeout: hand the remaining bytes
    # off to a fresh invocation of this same function
    new_offset = offset + bodylines.offset
    if new_offset < s3_object.content_length:
        new_event = {
            **event,
            "offset": new_offset,
            "fieldnames": fieldnames,
            "fieldtypes": fieldtypes,
        }
        print("Invoking Lambda!")
        print(invoke_lambda(context.function_name, new_event))
    return


def invoke_lambda(function_name, event):
    """Invoke a Lambda function asynchronously ('Event' invocation type)."""
    payload = json.dumps(event).encode('utf-8')
    client = boto3.client('lambda')
    response = client.invoke(
        FunctionName=function_name,
        InvocationType='Event',
        Payload=payload
    )
    return response


def get_object_bodylines(s3_object, offset):
    """Open the S3 object starting at byte `offset` and wrap it in BodyLines."""
    resp = s3_object.get(Range=f'bytes={offset}-')
    body: botocore.response.StreamingBody = resp['Body']
    return BodyLines(body)


class BodyLines:
    """Line iterator over an S3 StreamingBody that tracks its byte offset.

    Adapted from https://medium.com/swlh/processing-large-s3-files-with-aws-lambda-2c5840ae5c91
    """

    def __init__(self, body: botocore.response.StreamingBody, initial_offset=0):
        self.body = body
        self.offset = initial_offset

    def iter_lines(self, chunk_size=1024):
        """Return an iterator that yields lines from the raw stream.

        This is achieved by reading a chunk of bytes (of size chunk_size) at a
        time from the raw stream, and then yielding lines from there.
        """
        pending = b''
        for chunk in self.body.iter_chunks(chunk_size):
            lines = (pending + chunk).splitlines(True)
            for line in lines[:-1]:
                self.offset += len(line)
                yield line.decode('utf-8')
            pending = lines[-1]
        if pending:
            self.offset += len(pending)
            yield pending.decode('utf-8')
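
Example event for the handler above; the empty values are placeholders for your own bucket, object key, Secrets Manager ARN (secretArn), and Aurora cluster ARN (resourceArn):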
{
"bucket_name": "",
"object_key": "",
"offset": 0,
"fieldnames": "",
"db_name": "testingDB",
"table_name": "analysis",
"secretArn": "",
"resourceArn": ""
}
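
To kick off a transfer, send an event like the one above to the deployed function asynchronously; the handler then re-invokes itself with an updated offset until the whole file has been inserted. A minimal sketch, assuming the function is deployed under the hypothetical name csv-to-rds-loader and the event is saved locally as event.json:

import json

import boto3

lambda_client = boto3.client('lambda')

with open('event.json') as f:
    event = json.load(f)

response = lambda_client.invoke(
    FunctionName='csv-to-rds-loader',   # hypothetical function name
    InvocationType='Event',             # asynchronous, matching invoke_lambda above
    Payload=json.dumps(event).encode('utf-8'),
)
print(response['StatusCode'])  # 202 for asynchronous invocations

Because the invocation is asynchronous, this call returns immediately while the transfer continues in the background; progress shows up in the function's CloudWatch logs.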