Skip to content

Instantly share code, notes, and snippets.

@detrin
Created August 7, 2021 19:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save detrin/90d8563ec8e74a856552ebb9c29e846a to your computer and use it in GitHub Desktop.
Save detrin/90d8563ec8e74a856552ebb9c29e846a to your computer and use it in GitHub Desktop.
Serve ML model with Flask REST API - 3
import sys
from flask import Flask, request, jsonify
import traceback
import pandas as pd
import numpy as np
import json
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
# The Flask application object that hosts the /predict endpoint.
app = Flask(__name__)
# Module-level store shared by startup code and request handlers;
# predict() looks up the trained model under the "model" key.
cache = dict()
def get_data(
    url="http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data",
):
    """Load the Auto MPG dataset and split it into train/test features and labels.

    Adapted from https://www.tensorflow.org/tutorials/keras/regression.

    Args:
        url: Location of the whitespace-separated ``auto-mpg.data`` file.
            Anything ``pandas.read_csv`` accepts works (URL, local path or
            file-like object). The default downloads the file from the UCI
            archive.

    Returns:
        ``(train_features, test_features, train_labels, test_labels)``.
        Features are DataFrames and labels are the ``MPG`` Series. Every
        column is ``float32``. The split is a deterministic 80/20 random
        sample (``random_state=0``).
    """
    column_names = [
        "MPG",
        "Cylinders",
        "Displacement",
        "Horsepower",
        "Weight",
        "Acceleration",
        "Model Year",
        "Origin",
    ]
    raw_dataset = pd.read_csv(
        url,
        names=column_names,
        na_values="?",  # "?" marks a missing value
        comment="\t",  # everything after a tab on a line is ignored
        sep=" ",
        skipinitialspace=True,
    )
    # Drop incomplete rows and make every column numeric.
    dataset = raw_dataset.dropna().astype(np.float32)
    # Map the numeric origin code to a label. Cast to int first so the dict
    # lookup does not rely on float/int key equivalence.
    dataset["Origin"] = dataset["Origin"].astype(int).map(
        {1: "USA", 2: "Europe", 3: "Japan"}
    )
    # One-hot encode the origin. Request float32 dummies explicitly: pandas
    # >= 2.0 defaults to bool columns (uint8 before that). Mixing those with
    # float32 makes np.array(features) an object array that Keras rejects.
    dataset = pd.get_dummies(
        dataset, columns=["Origin"], prefix="", prefix_sep="", dtype=np.float32
    )
    train_dataset = dataset.sample(frac=0.8, random_state=0)
    test_dataset = dataset.drop(train_dataset.index)
    train_features = train_dataset.copy()
    test_features = test_dataset.copy()
    # pop() removes the label column from the feature frames in place.
    train_labels = train_features.pop("MPG")
    test_labels = test_features.pop("MPG")
    return train_features, test_features, train_labels, test_labels
def get_model(train_features, hidden_units=64, learning_rate=0.001):
    """Build and compile a small fully connected regression model.

    The first layer is a ``Normalization`` layer adapted to ``train_features``.
    Feature standardization therefore happens inside the model.

    Args:
        train_features: Training feature matrix (DataFrame or array-like)
            used to compute the normalization statistics.
        hidden_units: Width of each of the two hidden ReLU layers.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled, untrained ``keras.Sequential`` model with a single output
        unit and mean-absolute-error loss.
    """
    normalizer = preprocessing.Normalization(axis=-1)
    # Force a numeric float32 array. np.array() on a DataFrame that mixes
    # bool and float columns yields dtype=object, which adapt() cannot use.
    normalizer.adapt(np.asarray(train_features, dtype=np.float32))
    model = keras.Sequential(
        [
            normalizer,
            layers.Dense(hidden_units, activation="relu"),
            layers.Dense(hidden_units, activation="relu"),
            layers.Dense(1),
        ]
    )
    model.compile(
        loss="mean_absolute_error",
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    )
    return model
def train_model(epochs=100):
    """Train the Auto MPG regressor and log its error on the held-out split.

    Args:
        epochs: Number of training epochs.

    Returns:
        The trained Keras model.
    """
    train_features, test_features, train_labels, test_labels = get_data()
    model = get_model(train_features)
    model.summary()
    model.fit(
        train_features,
        train_labels,
        validation_split=0.2,
        verbose=2,
        epochs=epochs,
    )
    # fit() never sees the test split. Evaluating on it gives a held-out
    # error estimate in the startup log. The model is compiled with MAE loss
    # and no extra metrics, so evaluate() returns a single scalar.
    test_mae = model.evaluate(test_features, test_labels, verbose=0)
    print(f"Test mean absolute error: {test_mae:.3f}")
    return model
@app.route("/predict", methods=["POST"])
def predict():
    """Return the model's prediction for one feature vector.

    Expects a JSON body ``{"data": [f1, f2, ...]}`` with the features in the
    same column order used for training.

    On success, responds with ``{"prediction": "<value as string>"}``. On
    failure, responds with ``{"trace": "<traceback>"}`` and one of:
        - status 400 when the request is malformed;
        - status 500 when the prediction itself fails.
    """
    try:
        payload = request.get_json(silent=True)  # None for missing/invalid JSON
        if not isinstance(payload, dict) or "data" not in payload:
            raise ValueError('request body must be a JSON object with a "data" key')
        # predict() operates on batches: wrap the single sample into shape (1, n).
        features = np.asarray([payload["data"]], dtype=np.float32)
        if features.ndim != 2:
            raise ValueError('"data" must be a flat list of numbers')
    except (TypeError, ValueError):
        # NOTE(security): tracebacks expose server internals. Acceptable for
        # this demo, but a production service should log them instead.
        return jsonify({"trace": traceback.format_exc()}), 400
    try:
        prediction = cache["model"].predict(features, verbose=0)
    except Exception:
        return jsonify({"trace": traceback.format_exc()}), 500
    return jsonify({"prediction": str(prediction[0][0])})
if __name__ == "__main__":
    # An optional first command-line argument selects the port; default 3000.
    try:
        port = int(sys.argv[1])
    except (IndexError, ValueError):
        port = 3000
    cache["model"] = train_model()
    print("Model loaded")
    # debug=True would also enable Werkzeug's reloader. The reloader re-runs
    # this module in a child process, which would train the model a second
    # time, so keep debug mode but disable the reloader.
    # NOTE(security): the debug-mode interactive debugger allows arbitrary
    # code execution; never expose this server beyond localhost.
    app.run(port=port, debug=True, use_reloader=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment