Skip to content

Instantly share code, notes, and snippets.

@dormeir999
Last active December 24, 2021 17:04
Show Gist options
  • Save dormeir999/9ab3a0f2ce20d380bdc89f0df2460c64 to your computer and use it in GitHub Desktop.
Save dormeir999/9ab3a0f2ce20d380bdc89f0df2460c64 to your computer and use it in GitHub Desktop.
Flask_deployment_ML_model_prediction
import imghdr
import sys
from io import BytesIO
import yaml
from flask import Flask, send_file
from flask_restful import reqparse, Api, Resource
import werkzeug
import os
from utils import last_changed_file
from zipfile import ZipFile
import json
import base64
from PIL import Image
# Flask application object, wrapped by a flask-restful Api; the Resource
# classes are attached to it with api.add_resource(...) at the end of the module.
app = Flask(__name__)
api = Api(app)
class Models(Resource):
    """List the known algorithms and inspect one pre-trained model.

    ``GET`` returns the algorithm names. ``POST`` returns either the weight
    files or the class names of a single pre-trained model.
    """

    def __init__(self):
        self.parser = reqparse.RequestParser()
        self.algorithms = ['FastFCN', 'BERT', 'LSTM', 'XGboost']

    def get(self):
        """Return the names of the algorithms the API advertises."""
        return {"algorithms": self.algorithms}

    def post(self):
        """Describe a pre-trained model of the requested algorithm.

        Request arguments:
            algorithm: algorithm name; only ``'FastFCN'`` is handled.
            Model_path: name of a model directory under
                ``<FastFCN_DEFAULT_DIR>/logs``.
            return_weights: if truthy, list the model's non-JSON files.
            return_classes: if truthy, return ``class_names`` from the
                model's JSON configuration file.

        Returns:
            ``{"weights": [...]}`` or ``{"classes": [...]}`` on success,
            otherwise an explanatory message string.
        """
        self.parser.add_argument('algorithm')
        self.parser.add_argument('Model_path')
        # NOTE(review): reqparse only documents the 'store' and 'append'
        # actions; with 'store_true' the raw request string appears to be
        # returned, so a client sending 'False' gets a truthy value.
        # Consider type=flask_restful.inputs.boolean -- confirm.
        self.parser.add_argument('return_weights', action='store_true')  # absent -> None
        self.parser.add_argument('return_classes', action='store_true')  # absent -> None
        args = self.parser.parse_args()
        algorithm = args.get("algorithm")
        model = args.get("Model_path")
        return_weights = args.get("return_weights")
        return_classes = args.get("return_classes")

        if algorithm != 'FastFCN':
            return f"Algorithm {algorithm} is not implemented on Awesome app"
        if not model:
            return "Model_path was not passed to the API"

        # ``configs`` is a module-level mapping not defined in this part of
        # the module.
        models_dir = os.path.join(configs['FastFCN_DEFAULT_DIR'], 'logs')
        # Only accept an existing entry of the logs folder. This also stops
        # client-supplied paths such as '../x' from escaping models_dir.
        if model not in os.listdir(models_dir):
            return f"Model {model} not in {algorithm}'s pre-trained models"
        model_dir = os.path.join(models_dir, model)

        if return_weights:
            weights = [name for name in os.listdir(model_dir) if not name.endswith(".json")]
            return {"weights": weights}
        if return_classes:
            # Sorted so the choice is deterministic when several JSONs exist.
            config_files = sorted(name for name in os.listdir(model_dir) if name.endswith(".json"))
            if not config_files:
                return f"Model {model} configuration was not found. Model is unsuitable for prediction."
            with open(os.path.join(model_dir, config_files[0])) as json_file:
                return {"classes": json.load(json_file)['class_names']}
        return "Add either 'return_weights':'True' or 'return_classes':'True' to your arguments"
class Configs(Resource):
    """Store and return the prediction settings file ``configs.json``.

    The file lives in ``<DEFAULT_TEST_DIR>/API``. ``Predict.post`` calls
    ``Configs.get(self)`` with a ``Predict`` instance, so ``get`` must rely
    on nothing but ``self.upload_path``.
    """

    def __init__(self):
        self.parser = reqparse.RequestParser()
        # ``configs`` is a module-level mapping not defined in this part of
        # the module.
        self.upload_path = os.path.join(configs['DEFAULT_TEST_DIR'], 'API')

    def get(self):
        """Return the parsed ``configs.json``.

        Returns ``None`` when the file (or the upload folder) does not exist.
        """
        config_path = os.path.join(self.upload_path, "configs.json")
        if not os.path.isfile(config_path):
            return None
        with open(config_path) as json_file:
            return json.load(json_file)

    def post(self):
        """Save the uploaded ``file`` form field as ``configs.json``.

        Returns a confirmation string, or an error string when no file was
        uploaded. The content is saved as-is; it is not validated as JSON.
        """
        self.parser.add_argument("file", type=werkzeug.datastructures.FileStorage, location='files')
        args = self.parser.parse_args()
        uploaded = args.get("file")
        if uploaded is None:
            # Previously this crashed with AttributeError (HTTP 500).
            return "No file was passed to the API"
        uploaded.save(os.path.join(self.upload_path, "configs.json"))
        return f"saved {uploaded.filename} succesfully"
class Predict(Resource):
    """Run the model on base64-encoded images posted to the API.

    ``GET`` lists the files currently in the upload folder. ``POST`` saves
    the posted images there, runs prediction and returns the result JSONs.
    It can optionally include the predicted images, or return a zip file.

    NOTE(review): ``predict``, ``is_png``, ``png_string_to_jpg_file``,
    ``create_results`` and ``zip_prediction`` are called below but are not
    defined in this class as shown. Presumably they are defined elsewhere in
    the original project -- confirm.
    """

    # Message ``self.predict()`` returns on success; the spelling must match
    # whatever predict() emits.
    _PREDICT_OK = "Finished sucessfully"

    def __init__(self):
        self.parser = reqparse.RequestParser()
        # ``configs`` is a module-level mapping not defined in this part of
        # the module.
        self.results_path = os.path.join(configs['DEFAULT_OUTPUT_PATH'], 'API')
        self.upload_path = os.path.join(configs['DEFAULT_TEST_DIR'], 'API')

    def get(self):
        """Return the names of the files in the upload folder."""
        return {"Last prediction files": os.listdir(self.upload_path)}

    def clean_folder(self, folder):
        """Best-effort removal of every entry in *folder* except ``configs.json``.

        A failure is reported once on stdout instead of raising. It does not
        stop the remaining files from being removed.
        """
        try:
            names = os.listdir(folder)
        except OSError:
            print(f"Couldn't delete all files in {folder}")
            return
        failed = False
        for name in names:
            if name == "configs.json":
                continue
            try:
                os.remove(os.path.join(folder, name))
            except OSError:
                failed = True
        if failed:
            print(f"Couldn't delete all files in {folder}")

    def post(self):
        """Predict the posted images and return the results.

        Request arguments:
            images: base64-encoded image(s); repeat the field for several.
            create_images: if truthy, also return the predicted images.
            return_zip: if truthy, return the file produced by
                ``zip_prediction()`` via ``send_file`` instead of JSON.

        Returns:
            One of the following:
            - ``{"Responses": [...]}``, whose items are the results JSONs.
            - ``[results_json, base64_predicted_image]`` pairs instead, when
              ``create_images`` is set.
            - The zip file, when ``return_zip`` is set.
            - An error message string.
        """
        self.parser.add_argument("images", action='append')
        # NOTE(review): reqparse only documents 'store' and 'append'; with
        # 'store_true' the raw string appears to be returned, so 'False' is
        # truthy. Consider type=flask_restful.inputs.boolean -- confirm.
        self.parser.add_argument('create_images', action='store_true')  # absent -> None
        self.parser.add_argument('return_zip', action='store_true')  # absent -> None
        args = self.parser.parse_args()
        images = args.get("images")
        is_create_images = args.get("create_images")
        is_return_zip = args.get("return_zip")

        # Validate the request before any folder is wiped.
        if images is None:
            return "No images were passed to the API"
        if not isinstance(images, list):
            images = [images]  # list(a_string) would split it into characters

        # Settings come from the uploaded configs.json. Configs.get only reads
        # ``self.upload_path``, which this class also defines.
        self.settings = Configs.get(self)
        if not isinstance(self.settings, dict) or 'weight' not in self.settings:
            return "configs.json with a 'weight' entry was not found. Upload it to /configs first."

        # Drop previous uploads (configs.json is kept) and previous results.
        self.clean_folder(self.upload_path)
        results_folder = os.path.join(self.results_path, self.settings['weight'].split(".")[0])
        if os.path.exists(results_folder):
            self.clean_folder(results_folder)

        self._save_images(images)

        predict_exit_message = self.predict()
        results_dir = os.path.join(self.results_path, last_changed_file(self.results_path))
        results_jsons = self._list_files(results_dir, ".json")

        if is_create_images:
            # Predicted images can only be drawn once prediction produced jsons.
            if predict_exit_message != self._PREDICT_OK:
                return predict_exit_message
            self.create_results()
        # NOTE(review): without create_images, a failed prediction is not
        # reported; whatever jsons are in results_dir are returned.
        if is_return_zip:
            return send_file(self.zip_prediction())
        if not is_create_images:
            return {"Responses": [self._read_json(path) for path in results_jsons]}

        predicted_images = self._list_files(results_dir, ".jpg")
        # Pairing relies on images and jsons sorting in the same order
        # (e.g. shared file stems) -- assumed, confirm against create_results.
        return {"Responses": [[self._read_json(json_path), self._encode_image(image_path)]
                              for image_path, json_path in zip(predicted_images, results_jsons)]}

    def _save_images(self, images):
        """Write each base64 image to ``<upload_path>/<index>.jpg``, converting PNGs."""
        for index, image in enumerate(images):
            filename = os.path.join(self.upload_path, f"{index}.jpg")
            if self.is_png(image):
                self.png_string_to_jpg_file(image, filename)
            else:
                with open(filename, "wb") as image_file:
                    image_file.write(base64.b64decode(image))

    @staticmethod
    def _list_files(folder, suffix):
        """Return the sorted full paths of files in *folder* ending with *suffix*.

        os.listdir order is arbitrary, so sorting keeps responses deterministic.
        """
        return sorted(os.path.join(folder, name) for name in os.listdir(folder) if name.endswith(suffix))

    @staticmethod
    def _read_json(path):
        """Load and return the JSON document stored at *path*."""
        with open(path) as json_file:
            return json.load(json_file)

    @staticmethod
    def _encode_image(path):
        """Return the bytes of the file at *path* as a base64 UTF-8 string."""
        with open(path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
# Map each Resource class to its URL.
api.add_resource(Predict, '/predict')
api.add_resource(Configs, '/configs')
api.add_resource(Models, '/models')
# NOTE(review): Base64 is not defined in the visible part of this module. It
# must be defined or imported elsewhere, or this line raises NameError.
api.add_resource(Base64, '/base64')

if __name__ == '__main__':
    # Werkzeug's interactive debugger can execute arbitrary code. The server
    # listens on every interface, so debug mode (and its reloader) is opt-in
    # via FLASK_DEBUG=1 rather than always on.
    use_debugger = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=use_debugger, port=8503, host="0.0.0.0")  # , ssl_context=('cert.pem', 'key.pem'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment