Skip to content

Instantly share code, notes, and snippets.

Created November 30, 2020 06:53
Show Gist options
  • Save PhanDuc/0dbc6a648202b8dc335f9f09019d539f to your computer and use it in GitHub Desktop.
Save PhanDuc/0dbc6a648202b8dc335f9f09019d539f to your computer and use it in GitHub Desktop.
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras.models import load_model
from typing import List
import io
import numpy as np
import sys
import uvicorn
# Load the trained Keras model once at import time so every request reuses it.
# Alternative relative location: "./saved_model"
filepath = '/media/Another/Computer_Science_Project/fastapi_learning/keras_fastapi/saved_model'
model = load_model(filepath)
# Input shape the model expects. Read from the model itself rather than from
# `model.layers[0]`: on functional models layer 0 is an InputLayer whose
# `input_shape` is a list, and the per-layer attribute is absent in Keras 3.
# NOTE(review): the route below indexes this as [1], [2], [3]; presumably the
# layout is (batch, height, width, channels) -- confirm for this model.
input_shape = model.input_shape
# Define the response schema.
class Prediction(BaseModel):
    """Response body returned by the prediction route."""

    # Name of the uploaded file, echoed back to the caller.
    filename: str
    # Content type reported for the upload.
    contenttype: str
    # Model output for the image, one float per output unit.
    prediction: List[float] = []
    # Index of the largest value in `prediction`.
    likely_class: int
# Create the FastAPI application.
app = FastAPI()


# Root route: exists only to point callers at the prediction endpoint.
@app.get('/')
def root_route():
    """Return a hint telling the caller to use the prediction route.

    NOTE(review): the message says "GET /prediction"; confirm it matches the
    HTTP method the prediction route is actually registered with.
    """
    return {'error': 'Use GET /prediction instead of the root route!'}
# Define the prediction route.
@app.post("/prediction/", response_model=Prediction)
async def prediction_route(file: UploadFile = File(...)):
    """Classify an uploaded image with the loaded model.

    Args:
        file: The uploaded file; its content type must start with "image/".

    Returns:
        A dict with the keys "filename", "contenttype", "prediction" (the
        model output as a list) and "likely_class" (argmax of that output).

    Raises:
        HTTPException: 400 when the upload is not an image, 500 when reading,
            preprocessing or prediction fails.
    """
    # Ensure that this is an image. The content type may be missing, in which
    # case the upload is rejected the same way as a non-image.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f'File \'{file.filename}\' is not an image.',
        )

    try:
        # Read the image contents.
        contents = await file.read()
        pil_image = Image.open(io.BytesIO(contents))

        # NOTE(review): assumes input_shape is (batch, height, width,
        # channels) -- confirm against the model.
        height, width, channels = input_shape[1], input_shape[2], input_shape[3]

        # Resize to the expected input size. PIL takes (width, height), the
        # reverse of the array layout used in the reshape below.
        pil_image = pil_image.resize((width, height))

        # Match the channel count the model expects, so RGBA, palette or
        # grayscale uploads do not break the reshape.
        if channels == 1:
            pil_image = pil_image.convert('L')
        elif channels == 3:
            pil_image = pil_image.convert('RGB')

        # Convert the image to numpy format.
        numpy_image = np.array(pil_image).reshape((height, width, channels))

        # Scale pixel values to [0, 1].
        numpy_image = numpy_image / 255.0

        # Generate the prediction for a batch of one.
        prediction_array = np.array([numpy_image])
        predictions = model.predict(prediction_array)
        prediction = predictions[0]
        # Cast to a plain int: np.argmax returns a numpy integer, while the
        # response schema declares `likely_class: int`.
        likely_class = int(np.argmax(prediction))

        return {
            "filename": file.filename,
            "contenttype": file.content_type,
            "prediction": prediction.tolist(),
            "likely_class": likely_class,
        }
    except Exception as e:
        # Report any processing failure as a server error with the message.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == '__main__':
    # NOTE(review): this line was truncated in the source; only `debug=True`
    # survived. Any host/port arguments that preceded it are lost -- restore
    # them if the defaults are not wanted. Newer uvicorn releases may reject
    # the `debug` keyword; confirm against the installed version.
    uvicorn.run(app, debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment