Created August 7, 2021 19:54
Save detrin/90d8563ec8e74a856552ebb9c29e846a to your computer and use it in GitHub Desktop.
Serve ML model with Flask REST API - 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from flask import Flask, request, jsonify | |
import traceback | |
import pandas as pd | |
import numpy as np | |
import json | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
from tensorflow.keras.layers.experimental import preprocessing | |
# The Flask application that exposes the /predict endpoint.
app = Flask(__name__)
# Process-wide holder for shared objects; the script's __main__ section stores
# the trained model here under the "model" key before starting the server.
cache = dict()
AUTO_MPG_URL = (
    "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data"
)


def get_data(url=AUTO_MPG_URL, train_frac=0.8, random_state=0):
    """Load the UCI Auto MPG dataset and split it into train/test sets.

    Adapted from https://www.tensorflow.org/tutorials/keras/regression.

    Args:
        url: Path, URL or file-like object holding space-separated
            ``auto-mpg.data`` records; anything after a tab (the car name)
            is ignored.
        train_frac: Fraction of rows sampled into the training split.
        random_state: Seed for the train/test sampling, for reproducibility.

    Returns:
        ``(train_features, test_features, train_labels, test_labels)``. The
        feature DataFrames have every column as float32, with ``Origin``
        one-hot encoded into ``Europe``/``Japan``/``USA`` columns. The labels
        are the matching ``MPG`` Series.
    """
    column_names = [
        "MPG",
        "Cylinders",
        "Displacement",
        "Horsepower",
        "Weight",
        "Acceleration",
        "Model Year",
        "Origin",
    ]
    raw_dataset = pd.read_csv(
        url,
        names=column_names,
        na_values="?",  # missing values are encoded as "?"
        comment="\t",  # the car name follows a tab; drop it
        sep=" ",
        skipinitialspace=True,
    )
    dataset = raw_dataset.dropna()
    # Map the Origin codes before any float cast, so the integer dict keys
    # match exactly. Use assign() to avoid writing into a dropna() result.
    dataset = dataset.assign(
        Origin=dataset["Origin"].map({1: "USA", 2: "Europe", 3: "Japan"})
    )
    dataset = pd.get_dummies(dataset, columns=["Origin"], prefix="", prefix_sep="")
    # get_dummies yields bool (pandas >= 2.0) or uint8 columns. Cast the whole
    # frame so the features form a homogeneous float32 matrix for Keras.
    dataset = dataset.astype(np.float32)

    train_dataset = dataset.sample(frac=train_frac, random_state=random_state)
    test_dataset = dataset.drop(train_dataset.index)

    train_features = train_dataset.drop(columns="MPG")
    test_features = test_dataset.drop(columns="MPG")
    train_labels = train_dataset["MPG"]
    test_labels = test_dataset["MPG"]
    return train_features, test_features, train_labels, test_labels
def get_model(train_features, hidden_units=(64, 64), learning_rate=0.001):
    """Build and compile a normalised dense regression network.

    Args:
        train_features: Training feature matrix (DataFrame or array-like),
            used to fit the input normalisation statistics.
        hidden_units: Width of each ReLU hidden layer, in order.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``keras.Sequential`` model: normalisation, the hidden ReLU
        layers, and a single linear output, using mean-absolute-error loss.
    """
    # Cast explicitly. A DataFrame that mixes float and bool/uint8 columns
    # (e.g. pandas one-hot dummies) would otherwise become an object array,
    # which the normalisation layer cannot adapt to.
    features = np.asarray(train_features, dtype=np.float32)

    normalizer = preprocessing.Normalization(axis=-1)
    normalizer.adapt(features)

    model = keras.Sequential(
        [normalizer]
        + [layers.Dense(units, activation="relu") for units in hidden_units]
        + [layers.Dense(1)]
    )
    model.compile(
        loss="mean_absolute_error",
        optimizer=tf.keras.optimizers.Adam(learning_rate),
    )
    return model
def train_model(epochs=100, verbose=2):
    """Load the Auto MPG data, train the regression model and return it.

    Args:
        epochs: Number of training epochs passed to ``model.fit``.
        verbose: Keras verbosity level passed to ``model.fit``.

    Returns:
        The trained Keras model.
    """
    train_features, test_features, train_labels, test_labels = get_data()
    model = get_model(train_features)
    model.summary()
    # Keras holds out the last 20% of the training split for validation loss.
    model.fit(
        train_features,
        train_labels,
        validation_split=0.2,
        verbose=verbose,
        epochs=epochs,
    )
    # Report loss on the held-out test split, which was otherwise never used.
    test_mae = model.evaluate(test_features, test_labels, verbose=0)
    print(f"Test MAE: {test_mae:.3f}")
    return model
@app.route("/predict", methods=["POST"])
def predict():
    """Predict MPG for one sample posted as JSON.

    Expects a body like ``{"data": [f1, f2, ...]}``. The values are assumed
    to be the feature columns in the same order used for training; verify
    against the client.

    Returns:
        200 with ``{"prediction": "<float as string>"}`` on success.
        400 with ``{"error": ...}`` for a missing/invalid body or bad "data".
        503 with ``{"error": ...}`` if no model has been loaded into ``cache``.
        500 with ``{"trace": ...}`` for any other failure.
    """
    # silent=True returns None instead of raising on a missing/non-JSON body.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or "data" not in payload:
        return jsonify({"error": 'expected a JSON object with a "data" field'}), 400

    model = cache.get("model")
    if model is None:
        return jsonify({"error": "model is not loaded"}), 503

    try:
        # A single sample becomes an explicit (1, n_features) float32 batch.
        features = np.asarray([payload["data"]], dtype=np.float32)
    except (TypeError, ValueError):
        return jsonify({"error": '"data" must be a flat list of numbers'}), 400
    if features.ndim != 2 or features.shape[1] == 0:
        return jsonify({"error": '"data" must be a flat list of numbers'}), 400

    try:
        prediction = model.predict(features)
    except Exception:
        app.logger.exception("Prediction failed")
        # NOTE(security): returning a traceback exposes server internals to
        # clients; kept only for compatibility with existing callers.
        return jsonify({"trace": traceback.format_exc()}), 500
    return jsonify({"prediction": str(prediction[0][0])})
if __name__ == "__main__": | |
try: | |
port = int(sys.argv[1]) | |
except: | |
port = 3000 | |
cache["model"] = train_model() | |
print("Model loaded") | |
app.run(port=port, debug=True) # , debug=True |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.