-
-
Save AkiyonKS/228eee9b431fd90dcbf008a691b65c33 to your computer and use it in GitHub Desktop.
main.py of train app
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import csv
import io
import os
import urllib.error
import urllib.parse
import urllib.request

import numpy as np
from PIL import Image
from flask import Flask, request, redirect, render_template, flash
from werkzeug.utils import secure_filename
import cv2
import tensorflow as tf

app = Flask(__name__)
# flash() stores its messages in the session, and Flask refuses to use the
# session without a secret key. Prefer a stable key from the environment; the
# random fallback is only valid for the lifetime of this process.
app.secret_key = os.environ.get("SECRET_KEY") or os.urandom(24)


def fetch_interpreter():
    """Load the trained model (.tflite) from ./static and return its interpreter.

    Returns:
        A ``tf.lite.Interpreter`` whose tensors have already been allocated.
    """
    print("fetch_interpreter")
    model_name = "model_train_20220721_123305.tflite"
    tflite = tf.lite.Interpreter(model_path="./static/" + model_name)
    # Allocate memory; this is required right after the model is loaded.
    tflite.allocate_tensors()
    print("interpreter: ", type(tflite))
    return tflite


# Load the model once at import time and keep its tensor metadata around so
# that every request can reuse it.
interpreter = fetch_interpreter()
output_details = interpreter.get_output_details()
input_details = interpreter.get_input_details()
input_shape = input_details[0]['shape']


# Extensions (lower-case, without the dot) accepted as image input.
_ALLOWED_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif'})


def allowed_file(filename):
    """Return True when *filename* has one of the accepted image extensions.

    The comparison is case-insensitive and only looks at the text after the
    last dot. An empty or missing (``None``) filename is rejected instead of
    raising.
    """
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in _ALLOWED_EXTENSIONS


def preprocess(img):
    """Pad *img* with black borders so that it becomes square.

    The original aspect ratio is kept: the shorter side is extended evenly on
    both ends (the extra pixel of an odd difference goes to the bottom/right)
    until it matches the longer side.

    Args:
        img: Image array whose first two axes are height and width. Both 2-D
            and 3-D (with a channel axis) arrays are accepted.

    Returns:
        The padded array returned by ``cv2.copyMakeBorder``.
    """
    # Only the first two axes are needed; slicing also accepts 2-D arrays,
    # which a three-way unpack of ``img.shape`` would reject.
    h, w = img.shape[:2]
    longest_edge = max(h, w)

    # At most one of these is non-zero, because one side already is the longest.
    pad_h = longest_edge - h
    pad_w = longest_edge - w
    top = pad_h // 2
    left = pad_w // 2

    return cv2.copyMakeBorder(img, top, pad_h - top, left, pad_w - left,
                              cv2.BORDER_CONSTANT, value=[0, 0, 0])


def fetch_img_and_resize(img_path, img_size):
    """Open the image at *img_path* with OpenCV, square-pad it and resize it.

    Args:
        img_path: Path of the image file to read.
        img_size: Target size as a two-element sequence passed to
            ``cv2.resize`` as ``dsize``.

    Returns:
        The resized image array.

    Raises:
        ValueError: If OpenCV could not read the file.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise ValueError("could not read image: {}".format(img_path))
    img = preprocess(img)
    # Some OpenCV builds only accept a tuple for dsize, so normalise lists.
    return cv2.resize(img, tuple(img_size))


def fetch_img_and_resize2(img_pil, img_size):
    """Convert a PIL image to an OpenCV array, square-pad it and resize it.

    Args:
        img_pil: Image object accepted by ``pil2cvbgr``.
        img_size: Target size as a two-element sequence passed to
            ``cv2.resize`` as ``dsize``.

    Returns:
        The resized image array.
    """
    img = preprocess(pil2cvbgr(img_pil))
    # Some OpenCV builds only accept a tuple for dsize, so normalise lists
    # such as ``[300] * 2``.
    return cv2.resize(img, tuple(img_size))


def pil2cvbgr(image):
    """Convert a PIL image into an OpenCV BGR ``uint8`` array.

    PIL images whose mode is not plain grayscale/RGB/RGBA (for example the
    palette mode used by GIF files, or CMYK) are first converted to RGB, so
    that palette indices are not mistaken for gray levels and CMYK is not
    mistaken for RGBA.

    Args:
        image: A PIL image, or anything ``np.array`` can turn into an array.

    Returns:
        A BGR array for 2-D, 3-channel and 4-channel inputs; any other array
        layout is returned unchanged.
    """
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("L", "RGB", "RGBA"):
        image = image.convert("RGB")

    new_image = np.array(image, dtype=np.uint8)
    if new_image.ndim == 2:  # grayscale
        new_image = cv2.cvtColor(new_image, cv2.COLOR_GRAY2BGR)
    elif new_image.shape[2] == 3:  # colour
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
    elif new_image.shape[2] == 4:  # colour with alpha
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGR)
    return new_image


def pred_train(img_pil, img_size, labels):
    """Run the trained model on *img_pil* and attach a score to every label.

    Args:
        img_pil: Image to classify (passed to ``fetch_img_and_resize2``).
        img_size: Two-element size the image is resized to before inference.
        labels: Dict mapping the class index as a string ("0", "1", ...) to a
            dict of label attributes. It is not modified.

    Returns:
        A new dict with the same keys, ordered from the highest to the lowest
        score, where each value is a copy of the label attributes plus a
        ``'pred_score'`` string formatted with five decimals. Labels that
        received no score are left out.
    """
    print("pred_train")
    img = fetch_img_and_resize2(img_pil, img_size)
    input_data = np.array([img], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    # Run inference.
    interpreter.invoke()
    # The result is stored under the index given by output_details.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print("output_data", output_data)

    scores = output_data.reshape(-1).tolist()
    # Rank on the numeric score; comparing the formatted strings would order
    # values lexicographically (wrong for negative or multi-digit scores).
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)

    result = {}
    for index, score in ranked:
        key = str(index)
        if key not in labels:
            print("no label for class index", key)
            continue
        # Copy the entry so the caller's label table stays untouched.
        result[key] = dict(labels[key], pred_score="{:.5f}".format(score))
    return result


def fetch_url_file(url, timeout=10):
    """Download *url* and return its body as bytes, or None on failure.

    The URL comes from user input, so only ``http`` and ``https`` are
    accepted; other schemes (``file:``, ``ftp:``, ``data:``, or none at all)
    are rejected without opening anything.
    NOTE(review): http(s) URLs pointing at internal hosts are still fetched;
    add host filtering if this service can reach private networks.

    Args:
        url: Address of the file to download.
        timeout: Seconds to wait on the connection before giving up.

    Returns:
        The response body, or None when the URL is rejected or the download
        fails.
    """
    try:
        scheme = urllib.parse.urlsplit(url or "").scheme.lower()
    except ValueError as e:  # malformed URL, e.g. an unbalanced IPv6 bracket
        print(e)
        return None
    if scheme not in ("http", "https"):
        print("unsupported url scheme:", repr(scheme))
        return None

    try:
        with urllib.request.urlopen(url, timeout=timeout) as web_file:
            return web_file.read()
    except (OSError, ValueError) as e:
        # URLError and socket timeouts are both OSError subclasses.
        print(e)
        return None


def fetch_labels(file_path="./static/labels_for_model.csv"):
    """Read the label CSV used for prediction and return it as a dict.

    The first row holds the column names. For every following row, the first
    cell becomes the key and the remaining cells are stored in a dict keyed
    by the remaining column names. Blank lines (including a trailing newline
    at the end of the file) are ignored, and quoted cells may contain commas.

    Args:
        file_path: Location of the CSV file.

    Returns:
        ``{first_cell: {column_name: cell, ...}, ...}``; empty when the file
        has no rows.
    """
    # The encoding is explicit so the result does not depend on the host locale.
    with open(file_path, newline="", encoding="utf-8") as f:
        rows = [row for row in csv.reader(f)
                if any(cell.strip() for cell in row)]
    if not rows:
        return {}
    colnames = rows[0][1:]
    return {row[0]: dict(zip(colnames, row[1:])) for row in rows[1:]}


def _url_basename(url):
    """Return the last path component of *url*, ignoring query and fragment."""
    try:
        return os.path.basename(urllib.parse.urlsplit(url).path)
    except ValueError:
        return ''


def _png_data_uri(img_pil):
    """Encode a PIL image as a base64 ``data:image/png`` URI."""
    # PNG cannot store every PIL mode (e.g. CMYK), so fall back to RGB.
    if img_pil.mode not in ("1", "L", "LA", "P", "RGB", "RGBA"):
        img_pil = img_pil.convert("RGB")
    buffer = io.BytesIO()
    img_pil.save(buffer, 'png')
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64,{}".format(encoded)


@app.route('/', methods=['GET', 'POST'])
def main():
    """Render the page; on POST, classify the submitted image first.

    The image is taken from the ``url`` form field when it names an allowed
    image type, otherwise from the uploaded ``file``. A missing upload, a
    failed download or an undecodable image flashes a message and redirects
    back to the form. The template receives the label table (``scores``), the
    image as a PNG data URI (``img``) and ``check`` set to "1" when a
    prediction was made.
    """
    img_size = 300  # edge length of the image fed to the model
    labels = fetch_labels()
    qr_b64data = ""
    check = "0"

    if request.method == 'POST':
        img_pil = None
        # The field may be absent altogether, so normalise None to ''.
        url = (request.form.get('url') or '').strip()
        try:
            if url and allowed_file(_url_basename(url)):
                print('get url')
                data = fetch_url_file(url)
                if not data:
                    flash('画像を読み込めませんでした')
                    return redirect(request.url)
                # Keep the downloaded image in memory.
                img_pil = Image.open(io.BytesIO(data))
            else:
                print('get file')
                file = request.files.get('file')
                if file is None or not file.filename:
                    flash('ファイルがありません')
                    return redirect(request.url)
                if allowed_file(file.filename):
                    img_pil = Image.open(file)
            if img_pil is not None:
                # Always re-encode as PNG so the data matches the declared
                # MIME type; this also forces the image to be decoded here.
                qr_b64data = _png_data_uri(img_pil)
        except (OSError, ValueError) as e:
            # Covers unreadable/truncated images and malformed URLs.
            print(e)
            flash('画像を読み込めませんでした')
            return redirect(request.url)

        if img_pil is not None:
            labels = pred_train(img_pil, (img_size, img_size), labels)
            print(labels)
            check = "1"

    return render_template("index.html", scores=labels, img=qr_b64data, check=check)


if __name__ == "__main__":
    # Listen on every interface; the port comes from $PORT (8080 if unset).
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment