Created October 1, 2019 15:58
-
-
Save smellslikeml/a977445cc192e4748f9ed8df9804562a to your computer and use it in GitHub Desktop.
AWS Lambda function that downloads an image-classification TFLite model and its labels, listens to a Kinesis stream, and prints the predicted label for each image.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function

import base64
import io
import os
import zipfile

import boto3
import numpy as np
from PIL import Image
from tflite_runtime.interpreter import Interpreter
# Deployment configuration read from the environment:
#   bkt -> name of the S3 bucket holding the model archives
#   dir -> key prefix searched for the newest archive (see get_new_model call)
bkt, direc = os.environ['bkt'], os.environ['dir']

# The S3 resource and bucket handle are created at import time, so every
# lambda_handler call made by this process shares them.
s3 = boto3.resource('s3')
bucket = s3.Bucket(bkt)
def get_last_modified(obj):
    """Return ``obj.last_modified`` as integer Unix epoch seconds.

    Used as the sort key when picking the newest model archive. Uses
    ``datetime.timestamp()``, which honours the datetime's tzinfo and works on
    every platform. The previous ``strftime('%s')`` is a glibc-only extension
    that ignores tzinfo and interprets the value as local time.

    Args:
        obj: any object with a ``last_modified`` datetime attribute (here, an
            S3 ObjectSummary; presumably timezone-aware -- confirm).

    Returns:
        The modification time as an ``int`` number of seconds since the epoch.
    """
    return int(obj.last_modified.timestamp())
def get_new_model(prefix='mdls', dest_dir='/tmp/'):
    """Download the newest model archive under ``prefix`` and unpack it.

    Lists the objects in the module-level ``bucket`` whose key starts with
    ``prefix``. It picks the one with the latest ``last_modified`` time,
    downloads it into ``dest_dir`` and extracts it there as a zip archive.

    Args:
        prefix: S3 key prefix to search for model archives.
        dest_dir: local directory to download into and extract to
            (defaults to ``/tmp/``, as before).

    Returns:
        A list of local paths to the files extracted from the archive. The
        archive is expected to contain the model and its label file. Only the
        archive's own members are returned, not every entry in ``dest_dir``.

    Raises:
        IndexError: if no object exists under ``prefix``.
        zipfile.BadZipFile: if the newest object is not a zip archive.
    """
    # Skip zero-byte "folder" placeholder keys, which have no file name.
    candidates = [obj for obj in bucket.objects.filter(Prefix=prefix)
                  if not obj.key.endswith('/')]
    if not candidates:
        raise IndexError('no model archive found under prefix {!r}'.format(prefix))

    # One pass to find the newest object; a full sort is unnecessary.
    newest = max(candidates, key=get_last_modified)
    archive_path = os.path.join(dest_dir, newest.key.split('/')[-1])
    bucket.download_file(newest.key, archive_path)

    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(dest_dir)
        # Directory entries end with '/' in zip member names; return files only.
        members = [name for name in archive.namelist() if not name.endswith('/')]
    return [os.path.join(dest_dir, name) for name in members]
# Download the newest model archive from S3, then pick the label file and
# the TFLite model out of the extracted files.
model_files = get_new_model(prefix=direc)
# Match on the full extension (with the dot) so names that merely end in
# "txt"/"tflite" are not picked up by accident.
label_candidates = [p for p in model_files if p.endswith('.txt')]
model_candidates = [p for p in model_files if p.endswith('.tflite')]
if not label_candidates or not model_candidates:
    raise IndexError(
        'model archive must contain a .txt label file and a .tflite model; '
        'found: {}'.format(model_files))
label_path = label_candidates[0]
mdl_path = model_candidates[0]
# Load class labels: one label per line, where line i names class index i
# (lambda_handler looks the arg-max output index up in label_dict).
# 'utf-8-sig' decodes plain UTF-8 unchanged and also strips a leading BOM,
# which would otherwise be glued onto the first label; it also removes the
# dependence on the platform's locale encoding.
with open(label_path, 'r', encoding='utf-8-sig') as fl:
    labels = [line.rstrip('\n') for line in fl]
label_dict = dict(enumerate(labels))
# Build the TFLite interpreter once at import time, so it is shared by every
# lambda_handler call in this process, and cache its tensor metadata.
interpreter = Interpreter(model_path=mdl_path)
interpreter.allocate_tensors()
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
def lambda_handler(event, context):
    """Classify each image in a batch of Kinesis records and print its label.

    Each record's payload is base64-decoded, opened as an image with PIL,
    converted to 3-channel RGB and resized to the model's input size. It is
    then run through the module-level TFLite ``interpreter``, and the name of
    the highest-scoring class (looked up in ``label_dict``) is printed.

    Args:
        event: Lambda event with a ``Records`` list; each record carries its
            base64 payload under ``record['kinesis']['data']``.
        context: Lambda context object (unused).

    Returns:
        None.
    """
    # Tensor metadata does not change between records, so read it once.
    input_detail = input_details[0]
    # assumes an NHWC input tensor [batch, height, width, channels] -- TODO
    # confirm for the deployed model.
    _, height, width = input_detail['shape'][:3]
    dtype = input_detail['dtype']

    # NOTE(review): any exception here fails the whole invocation (no
    # per-record handling); confirm that is the intended retry behaviour.
    for record in event['Records']:
        # Kinesis data is base64 encoded, so decode it here.
        payload = base64.b64decode(record['kinesis']['data'])
        # convert('RGB') gives exactly 3 channels for greyscale, palette,
        # RGBA and CMYK images alike. Slicing [:, :, :3] crashed on 2-D
        # (greyscale/palette) arrays and took the wrong channels for CMYK.
        image = Image.open(io.BytesIO(payload)).convert('RGB')
        # PIL's resize takes (width, height), whereas the tensor shape lists
        # height before width.
        image = image.resize((int(width), int(height)))
        pixels = np.asarray(image)

        if np.issubdtype(dtype, np.floating):
            # Float models: scale pixels to [0, 1] (float32 / int stays float32).
            data = pixels.astype(dtype) / 255
        else:
            # Integer (quantized) models: dividing would yield float64, which
            # set_tensor rejects, so feed raw pixel values in the model dtype.
            # NOTE(review): raw 0-255 values overflow int8 inputs; apply the
            # input's quantization params if int8 models are deployed.
            data = pixels.astype(dtype)
        data = np.expand_dims(data, axis=0)  # add the batch dimension

        interpreter.set_tensor(input_detail['index'], data)
        interpreter.invoke()
        result = interpreter.get_tensor(output_details[0]['index'])
        print("Prediction: " + str(label_dict[np.argmax(result)]))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment