# Source: GitHub Gist deepak-karkala/a1b3e42b52cc44748bcdb708474b7b01
# Author: @deepak-karkala (created December 21, 2020 07:08)
# FLASK Webapp for Image Segmentation Model
import os, sys, io
sys.path.append(".")
import webapp
from flask import Flask
import flask
import numpy as np
import pandas as pd
from webapp.db import get_db, init_app
import tensorflow as tf
from tensorflow import keras
from PIL import Image
# Root directory of the web-app package, relative to the current working
# directory (presumably the server is launched from the repository root).
# Model and image paths are built from it by plain string concatenation,
# so the trailing slash is required.
BASE_PATH = "webapp/"
def load_model(model_path=None):
    """Load the pre-trained Keras image-segmentation model from disk.

    Args:
        model_path: Path to the saved model file. When ``None`` (the
            default), ``BASE_PATH + "static/models/model.h5"`` is used,
            which is the location this app has always loaded from.

    Returns:
        The model object returned by ``keras.models.load_model``.
    """
    if model_path is None:
        model_path = BASE_PATH + "static/models/model.h5"
    return keras.models.load_model(model_path)
# Load the model once at module import time, before create_app() runs.
# Importing this module therefore requires the model file to be present.
# NOTE(review): presumably consumed by get_model_output_segmentation_mask,
# which is not defined in the visible source — confirm.
model = load_model()
def create_app(test_config=None):
    """Create and configure the Flask application (application factory).

    Args:
        test_config: Optional mapping of configuration values. When ``None``,
            the instance ``config.py`` is loaded if it exists. Otherwise the
            mapping is applied on top of the defaults.

    Returns:
        The configured ``Flask`` app, with the ``/`` route registered and
        ``init_app`` applied to it.
    """
    app = Flask(__name__, instance_relative_config=True)
    # Disable the send-file cache max-age. Presumably this stops browsers
    # from showing a stale segmentation image — confirm.
    app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
    app.config.from_mapping(
        SECRET_KEY='dev',
        DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'),
    )

    if test_config is None:
        # Load the instance config, if it exists, when not testing.
        app.config.from_pyfile('config.py', silent=True)
    else:
        # Load the test config if passed in.
        app.config.from_mapping(test_config)

    # Ensure the instance folder (home of the SQLite database) exists.
    # Best-effort: an already-existing folder raises OSError, which is fine.
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    def _fetch_all_products(db):
        """Return every product row (id, category, image_path)."""
        return db.execute(
            'SELECT id, category, image_path FROM product'
        ).fetchall()

    def _render_page(all_images, selected_image=None,
                     product_segmented_image_path=None):
        """Render base.html with the gallery and optional prediction output."""
        return flask.render_template(
            'base.html',
            images=all_images,
            selected_image=selected_image,
            filter_product_category=None,
            product_segmented_image_path=product_segmented_image_path,
        )

    @app.route('/', methods=['GET', 'POST'])
    def hello():
        """Landing page: list all products; on POST, also run segmentation."""
        db = get_db()
        # Both GET and POST render the full gallery, so always fetch it.
        # (The original only fetched it on GET, so every POST crashed.)
        all_images = _fetch_all_products(db)

        if flask.request.method == 'POST':
            selected_image_id = flask.request.form['product_radio']
            # The id comes straight from the client. Bind it as a parameter
            # rather than concatenating it into the SQL (SQL injection).
            # Assumes get_db() returns a sqlite3 connection (qmark paramstyle).
            selected_image = db.execute(
                'SELECT id, category, image_path FROM product p WHERE p.id = ?',
                (selected_image_id,),
            ).fetchone()
            if selected_image is None:
                flask.abort(404)

            # Run inference on the model to get the segmentation mask.
            # NOTE(review): get_model_output_segmentation_mask is not defined
            # in the visible source; presumably it lives elsewhere — verify.
            full_image_path = (
                BASE_PATH + "static/images/" + selected_image["image_path"]
            )
            product_segmented_image_path = get_model_output_segmentation_mask(
                full_image_path
            )
            return _render_page(
                all_images, selected_image, product_segmented_image_path
            )

        # GET (and Flask's automatic HEAD): show the gallery only.
        return _render_page(all_images)

    init_app(app)
    return app
# Script entry point: build the app and start the development server.
# The Keras model has already been loaded at import time by the module-level
# `model = load_model()` statement, so only the server remains to start here.
if __name__ == "__main__":
    print("* Keras model loaded; starting Flask server... "
          "please wait until server has fully started")
    app = create_app()
    app.run(host='localhost', port=5000)
# (End of gist. Commenting on it on GitHub requires signing in.)