-
-
Save sararob/1fb9fb132a93bdda95f7a71d2afd38ad to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json

import googleapiclient
import numpy as np
from flask import jsonify
from googleapiclient import discovery
def predict_json(project, model, instances, version=None):
    """Request an online prediction for a single instance from a deployed model.

    Uses the ``ml`` ``v1`` API built through ``googleapiclient.discovery``.

    Args:
        project: ID of the GCP project that hosts the model.
        model: Name of the deployed model.
        instances: One input instance. It is wrapped in a one-element list
            to form the request's ``instances`` field.
        version: Optional model version. When ``None``, no
            ``/versions/...`` suffix is added to the resource name.

    Returns:
        The first element of the response's ``predictions`` list.

    Raises:
        RuntimeError: If the API response contains an ``error`` field.
    """
    print(instances)
    ml_service = googleapiclient.discovery.build('ml', 'v1')

    resource_name = 'projects/{}/models/{}'.format(project, model)
    if version is not None:
        resource_name += '/versions/{}'.format(version)

    request_body = {"instances": [instances]}
    result = ml_service.projects().predict(name=resource_name, body=request_body).execute()

    if 'error' not in result:
        return result['predictions'][0]

    print('error in calling Prediction API')
    raise RuntimeError(result['error'])
def scale_data(data):
    """Turn raw ingredient amounts into the feature vector the model expects.

    Eggs (a count) and the teaspoon-measured ingredients (baking powder,
    baking soda, salt, yeast) are multiplied by fixed conversion factors so
    that all ingredients share one unit (presumably cups -- TODO confirm).
    Each converted amount is then divided by the total of the model's
    ingredients, so the model sees proportions rather than absolute amounts.

    Args:
        data: Mapping of ingredient name to numeric amount. It must contain
            every ingredient in ``order`` below. Any other keys are ignored.
            The mapping is not modified.

    Returns:
        A list of 12 floats, one per ingredient in ``order``, summing to 1.

    Raises:
        KeyError: If a required ingredient is missing from ``data``.
        ValueError: If the converted amounts do not sum to a positive total.
    """
    egg_conversion = .2
    tsp_conversion = .02
    # Ingredients not listed here are already in the common unit.
    conversions = {
        'eggs': egg_conversion,
        'baking_powder': tsp_conversion,
        'baking_soda': tsp_conversion,
        'salt': tsp_conversion,
        'yeast': tsp_conversion,
    }

    # My model takes an array of ingredient amounts, this is the order it's expecting
    # TODO: should probably have my model take key-val pairs instead
    order = ['flour', 'sugar', 'starter', 'salt', 'yeast', 'milk', 'water',
             'oil', 'eggs', 'baking_powder', 'baking_soda', 'butter']

    # Build converted copies instead of writing back into ``data``, so the
    # caller's request dict is left untouched.
    amounts = [data[name] * conversions.get(name, 1) for name in order]

    # Normalize by the model's own ingredients only. Extra keys in the request
    # (non-numeric ones included) must not skew or break the proportions.
    total = np.sum(amounts)
    # ``not total > 0`` also rejects NaN. A zero total would otherwise
    # produce nan/inf features.
    if not total > 0:
        raise ValueError(
            'ingredient amounts must sum to a positive total, got {}'.format(total))
    return [amount / total for amount in amounts]
def get_prediction(request):
    """HTTP entry point: classify a recipe's ingredients as bread, cake or cookies.

    An ``OPTIONS`` (CORS preflight) request gets an empty 204 response.
    Any other request flows through these steps:
    - Read ingredient amounts from the JSON body.
    - Normalize them with ``scale_data``.
    - Send them to the deployed model through ``predict_json``.
    - Report the highest-scoring class.

    Args:
        request: The incoming Flask request object.

    Returns:
        A ``(body, status, headers)`` tuple.
        - Success: status 200. The JSON body holds ``bake_confidence`` (the
          top score times 100, rounded, as a string), ``bake_value`` and
          ``bake_emoji``.
        - Missing, non-object or invalid ingredient data: status 400 with an
          ``error`` message.
    """
    if request.method == 'OPTIONS':
        # Allows POST requests from any origin with the Content-Type
        # header and caches the preflight response for 3600s.
        preflight_headers = {
            'Access-Control-Allow-Origin': '*',
            'Access-Control-Allow-Methods': 'POST',
            'Access-Control-Allow-Headers': 'Content-Type',
            'Access-Control-Max-Age': '3600'
        }
        return ('', 204, preflight_headers)

    # Headers for the actual (non-preflight) response. The browser also
    # requires the allow-origin header here, not only on the preflight.
    headers = {'Access-Control-Allow-Origin': '*'}

    # silent=True yields None for a missing or malformed JSON body. That case
    # gets a 400 that still carries the CORS header, instead of an error page.
    req_data = request.get_json(silent=True)
    print(req_data)
    if not isinstance(req_data, dict):
        return (jsonify({"error": "expected a JSON object of ingredient amounts"}),
                400, headers)

    # Scale the data to the same units and percentages. Bad client input
    # (missing ingredient, non-numeric amount, zero total) is a 400, not a 500.
    try:
        scaled = scale_data(req_data)
    except (KeyError, TypeError, ValueError) as err:
        return (jsonify({"error": "invalid ingredient data: {}".format(err)}),
                400, headers)
    print(scaled)

    # Call the deployed model (replace the placeholders with your project and model).
    prediction = predict_json('your-gcp-project', 'model-name', scaled)
    print(prediction)

    # Model output index -> (display value, emoji). The index order matches the
    # label map Bread, Cake, Cookies. Presumably `prediction` holds one score
    # per class in that order -- TODO confirm against the trained model.
    labels = (
        ("bread", "🍞"),
        ("cake", "🧁"),
        ("cookies", "🍪"),
    )
    predicted_ind = np.argmax(prediction)
    bake_val, bake_emoji = labels[predicted_ind]
    confidence = str(round(prediction[predicted_ind] * 100))
    print(confidence)

    return_data = {
        "bake_confidence": confidence,
        "bake_value": bake_val,
        "bake_emoji": bake_emoji,
    }
    return (jsonify(return_data), 200, headers)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.