@nwatab
Last active August 6, 2021 06:17
Image classification API built with Keras and Flask

Simple image classification API

Usage

$ python app.py

Example

$ curl -X POST -F file=@cat.jpg 'http://localhost:5000/predict'
{
  "predictions": [
    {
      "label": "tiger_cat", 
      "probability": 0.4184308648109436
    }, 
    {
      "label": "Egyptian_cat", 
      "probability": 0.3616541028022766
    }, 
    {
      "label": "tabby", 
      "probability": 0.1262882500886917
    }, 
    {
      "label": "lynx", 
      "probability": 0.02171700820326805
    }, 
    {
      "label": "cup", 
      "probability": 0.00845610722899437
    }
  ], 
  "success": true
}
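
The response can also be consumed from Python. The sketch below covers only the response side: it parses a JSON body shaped like the one above and picks the most probable label. The helper name `top_prediction` is hypothetical (it is not part of the gist), and the sample body is abbreviated from the example output.

```python
import json


def top_prediction(body):
    """Return (label, probability) for the most probable class, or None on failure."""
    data = json.loads(body)
    # The API sets "success" to false when no prediction was made.
    if not data.get("success") or not data.get("predictions"):
        return None
    best = max(data["predictions"], key=lambda p: p["probability"])
    return best["label"], best["probability"]


# Abbreviated copy of the example response shown above.
sample_body = """
{
  "predictions": [
    {"label": "tiger_cat", "probability": 0.4184308648109436},
    {"label": "Egyptian_cat", "probability": 0.3616541028022766},
    {"label": "tabby", "probability": 0.1262882500886917}
  ],
  "success": true
}
"""

print(top_prediction(sample_body))
```

Taking the maximum explicitly, rather than reading the first entry, avoids relying on the order in which the server lists its predictions.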

Reference

import os
from flask import Flask, request, redirect, url_for, jsonify, Response
from werkzeug.utils import secure_filename
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing import image
import numpy as np
from PIL import Image
import io

app = Flask(__name__)
model = None  # set by load_model() before the server starts


def load_model():
    # Load VGG16 with ImageNet weights, including the classification head.
    global model
    model = VGG16(weights='imagenet', include_top=True)


@app.route('/predict', methods=['GET', 'POST'])
def upload_file():
    response = {'success': False}
    if request.method == 'POST':
        if request.files.get('file'):  # the image is uploaded under the field name "file"
            img_requested = request.files['file'].read()
            img = Image.open(io.BytesIO(img_requested))
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img = img.resize((224, 224))  # VGG16 input size
            img = image.img_to_array(img)
            img = np.expand_dims(img, axis=0)  # add a batch dimension
            inputs = preprocess_input(img)
            preds = model.predict(inputs)
            results = decode_predictions(preds)
            response['predictions'] = []
            for (imagenetID, label, prob) in results[0]:  # [0] because the batch holds one image
                row = {'label': label, 'probability': float(prob)}  # numpy floats are not JSON serializable
                response['predictions'].append(row)
            response['success'] = True
        # A POST without a file returns {'success': False}.
        return jsonify(response)
    # GET: show a minimal upload form.
    return '''
    <!doctype html>
    <title>Upload new File</title>
    <h1>Upload new File</h1>
    <form method=post enctype=multipart/form-data>
    <p><input type=file name=file>
    <input type=submit value=Upload>
    </form>
    '''


if __name__ == '__main__':
    load_model()
    # Run single-threaded; see https://github.com/keras-team/keras/issues/2397#issuecomment-377914683
    # This avoids model.predict running before the model is initialized.
    # To run on Heroku, model.predict should be run once after initialization.
    app.run(threaded=False)
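
The closing comments in app.py note that `model.predict` must not run before the model is initialized. One possible way to enforce that ordering, independent of the `__main__` guard, is a lock-protected lazy loader. The following is a minimal sketch and not part of the gist: `LazyModel` and `fake_loader` are hypothetical names, and the loader is a stand-in so the example runs without Keras. In the app, the loader would presumably be something like `lambda: VGG16(weights='imagenet', include_top=True)`. This sketch does not replace `threaded=False`; it only addresses the load-before-predict ordering.

```python
import threading


class LazyModel:
    """Build a model on first use; later calls reuse the same object."""

    def __init__(self, loader):
        self._loader = loader  # zero-argument callable that builds the model
        self._model = None
        self._lock = threading.Lock()

    def get(self):
        # The lock ensures that concurrent first calls trigger only one load.
        with self._lock:
            if self._model is None:
                self._model = self._loader()
            return self._model


load_count = 0


def fake_loader():
    # Stand-in for an expensive model load (assumption for this sketch).
    global load_count
    load_count += 1
    return object()


lazy = LazyModel(fake_loader)
first = lazy.get()   # triggers the load
second = lazy.get()  # reuses the already loaded object
```

With this pattern a request handler would call `lazy.get().predict(inputs)`, so the first request pays the loading cost and every later request reuses the same model.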
@er778899789

Hey, this is a nice Flask application for reference. I have a question. I can run my web service successfully, but I want to reload the service when I change my model. I know that setting debug=True reloads the app automatically when the code changes, and I can also restart the program myself.
However, the service is down for about a minute whenever I reload or restart the program. Is there a way to change the model without stopping the program?
