Skip to content

Instantly share code, notes, and snippets.

@AkiyonKS
Created August 9, 2022 04:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AkiyonKS/228eee9b431fd90dcbf008a691b65c33 to your computer and use it in GitHub Desktop.
Save AkiyonKS/228eee9b431fd90dcbf008a691b65c33 to your computer and use it in GitHub Desktop.
main.py of train app
import os
import numpy as np
from PIL import Image
import io
import base64
from flask import Flask, request, redirect, render_template, flash
from werkzeug.utils import secure_filename
import cv2
import urllib.error
import urllib.request
import tensorflow as tf
app = Flask(__name__)
# flash() stores its messages in the session, and Flask raises a RuntimeError
# on session writes when no secret key is configured. Read the key from the
# SECRET_KEY environment variable; the random fallback keeps a single local
# process working, but it changes on every restart and differs between worker
# processes, so set SECRET_KEY explicitly in production.
app.secret_key = os.environ.get('SECRET_KEY') or os.urandom(24)
# Load the trained TensorFlow Lite model (.tflite)
def fetch_interpreter():
    """Build a TFLite interpreter for the model file shipped under ./static/.

    Returns:
        A ``tf.lite.Interpreter`` whose tensors have already been allocated.
    """
    print("fetch_interpreter")
    tflite_path = "./static/" + "model_train_20220721_123305.tflite"
    loaded = tf.lite.Interpreter(model_path=tflite_path)
    # Allocate tensor memory; this is required right after loading the model.
    loaded.allocate_tensors()
    print("interpreter: ", type(loaded))
    return loaded
# Load the model once at import time and keep its tensor metadata in module
# globals; pred_train() reads interpreter, input_details and output_details.
interpreter = fetch_interpreter()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Shape reported for the first input tensor (not referenced elsewhere in this file).
input_shape = input_details[0]['shape']
# Check whether the file name has one of the supported image extensions
def allowed_file(filename):
    """Return True when *filename* ends in .png, .jpg, .jpeg or .gif.

    The comparison is case-insensitive and looks only at the text after the
    last dot; names without a dot are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ('png', 'jpg', 'jpeg', 'gif')
# Image preprocessing: pad to a square with black borders, keeping the
# original aspect ratio
def preprocess(img):
    """Pad *img* with black borders so that its height and width are equal.

    The shorter side is padded symmetrically (the extra pixel of an odd
    difference goes to the bottom/right); the image content is not scaled.

    Args:
        img: Image array whose first two axes are (height, width). Only those
            two axes are read, so both 2-D and 3-D arrays are accepted.

    Returns:
        The padded array produced by ``cv2.copyMakeBorder``.
    """
    # Read only the first two axes: the channel count is not needed, and
    # unpacking three values would fail for 2-D (single-channel) arrays.
    height, width = img.shape[:2]
    longest_edge = max(height, width)

    # At most one of the two differences is non-zero.
    pad_h = longest_edge - height
    pad_w = longest_edge - width
    top = pad_h // 2
    bottom = pad_h - top
    left = pad_w // 2
    right = pad_w - left

    return cv2.copyMakeBorder(img, top, bottom, left, right,
                              cv2.BORDER_CONSTANT, value=[0, 0, 0])
# Open the image at the given path with OpenCV, pad it to a square and resize it
def fetch_img_and_resize(img_path, img_size):
    """Load an image file and return it padded to a square and resized.

    Args:
        img_path: Path of the image file, passed to ``cv2.imread``.
        img_size: Target size passed to ``cv2.resize`` as its ``dsize`` argument.

    Returns:
        The resized image array.

    Raises:
        OSError: If OpenCV could not read an image from *img_path*.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than with an AttributeError later on.
    if img is None:
        raise OSError("could not read image: {!r}".format(img_path))
    img = preprocess(img)
    return cv2.resize(img, img_size)
# Convert the given PIL image to OpenCV, pad it to a square and resize it
def fetch_img_and_resize2(img_pil, img_size):
    """Return *img_pil* as a padded, resized OpenCV array.

    Args:
        img_pil: Image object accepted by ``pil2cvbgr``.
        img_size: Target size passed to ``cv2.resize`` as its ``dsize`` argument.

    Returns:
        The resized image array.
    """
    squared = preprocess(pil2cvbgr(img_pil))
    return cv2.resize(squared, img_size)
# Convert a PIL image to an OpenCV BGR array
def pil2cvbgr(image):
    """Convert a PIL image to a 3-channel uint8 array in OpenCV's BGR order.

    Args:
        image: A PIL image. Objects without a ``mode`` attribute (for example
            an array that ``np.array`` can consume) are passed through to the
            array conversion unchanged.

    Returns:
        A ``numpy.ndarray`` in BGR channel order. Any alpha channel is dropped.
    """
    # Only "L", "RGB" and "RGBA" map directly onto the gray/RGB/RGBA branches
    # below. Other modes would be misread: "P" (typical for GIF) yields palette
    # indices, "CMYK" has four non-RGBA channels, "LA" has two channels and "1"
    # yields 0/1 values. Normalise them to RGB first.
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("L", "RGB", "RGBA"):
        image = image.convert("RGB")

    new_image = np.array(image, dtype=np.uint8)
    if new_image.ndim == 2:  # grayscale
        new_image = cv2.cvtColor(new_image, cv2.COLOR_GRAY2BGR)
    elif new_image.shape[2] == 3:  # colour
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
    elif new_image.shape[2] == 4:  # colour with alpha
        new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGR)
    return new_image
# Run a prediction with the trained model
def pred_train(img_pil, img_size, labels):
    """Score *img_pil* with the module-level TFLite interpreter.

    Args:
        img_pil: Image to classify, as accepted by ``fetch_img_and_resize2``.
        img_size: Target size forwarded to ``fetch_img_and_resize2``.
        labels: Dict keyed by the output index as a string ("0", "1", ...),
            each value being a dict of label attributes. The inner dicts are
            updated in place with a ``'pred_score'`` entry.

    Returns:
        A new dict with the same entries as *labels*, ordered from the highest
        to the lowest score. Each ``'pred_score'`` is a string with five
        decimal places.

    Raises:
        KeyError: If an output index has no entry in *labels*, or a label
            entry received no score.
    """
    print("pred_train")
    img = fetch_img_and_resize2(img_pil, img_size)
    # The model input is a batch containing this single image, as float32.
    input_data = np.array([img], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    # Run inference.
    interpreter.invoke()
    # The result is stored in the tensor named by output_details.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print("output_data", output_data)

    # Attach each score to its label as a fixed-precision string.
    for i, score in enumerate(output_data.reshape(-1).tolist()):
        labels[str(i)]['pred_score'] = "{:.5f}".format(score)

    # Rank on the numeric value: comparing the formatted strings would order
    # them lexicographically (e.g. "10.00000" below "9.00000").
    def numeric_score(item):
        return float(item[1]['pred_score'])

    return dict(sorted(labels.items(), key=numeric_score, reverse=True))
# Download a file from a URL
def fetch_url_file(url, timeout=10):
    """Download *url* and return its body as bytes, or None on failure.

    Args:
        url: Address to download. Only ``http`` and ``https`` URLs are fetched.
        timeout: Seconds to wait on the connection before giving up.

    Returns:
        The response body as ``bytes``, or ``None`` when the URL is malformed,
        uses another scheme, or the download fails. The reason is printed.
    """
    from http.client import HTTPException
    from urllib.parse import urlsplit

    # The URL comes from user input: refuse anything that is not plain
    # http(s) so that schemes such as file:// cannot read local files.
    try:
        scheme = urlsplit(url).scheme
    except ValueError as e:
        print(e)
        return None
    if scheme not in ("http", "https"):
        print("unsupported url: {!r}".format(url))
        return None

    try:
        with urllib.request.urlopen(url, timeout=timeout) as web_file:
            print("type(web_file)", type(web_file))
            data = web_file.read()
            print("type(data)", type(data))
            return data
    except (OSError, ValueError, HTTPException) as e:
        # URLError and socket timeouts are OSError subclasses; ValueError and
        # HTTPException cover malformed URLs and broken HTTP responses.
        print(e)
        return None
# Read the CSV of label information needed for prediction and return it as a dict
def fetch_labels():
    """Load ``./static/labels_for_model.csv`` into a dict of dicts.

    The first row is the header. For every following row, the first column is
    the key and the remaining columns are mapped to the header names of the
    same position. Blank lines (including a trailing newline at the end of the
    file) are skipped. The file is read as UTF-8.

    Returns:
        ``{first_column: {header: value, ...}, ...}``; an empty dict when the
        file has no rows.
    """
    import csv

    file_path = "./static/labels_for_model.csv"
    with open(file_path, newline="", encoding="utf-8") as f:
        # csv.reader yields an empty list for a blank line; dropping those
        # avoids creating a bogus '' label with no attributes.
        rows = [row for row in csv.reader(f) if row]
    if not rows:
        return {}
    colnames, *records = rows
    return {row[0]: dict(zip(colnames[1:], row[1:])) for row in records}
@app.route('/', methods=['GET', 'POST'])
def main():
    """Render the page and, on POST, classify the submitted image.

    The image is taken from the ``url`` form field when it is non-empty and
    ends in an allowed extension; otherwise from the uploaded ``file``. When
    no usable image is supplied, or it cannot be fetched or opened, a message
    is flashed and the client is redirected back to this page.

    Returns:
        The rendered ``index.html`` with ``scores`` (label dict, ranked after a
        prediction), ``img`` (data URI of the image, or "") and ``check``
        ("1" after a prediction, else "0"), or a redirect response.
    """
    img_size = 300  # image size used for prediction
    labels = fetch_labels()
    qr_b64data = ""
    check = "0"
    if request.method == 'POST':
        print(request)
        print(request.files)
        # A form without a "url" field yields None; treat it like an empty one
        # instead of passing None to os.path.basename.
        url = request.form.get('url') or ''
        img_pil = None
        img_bytesio = None
        try:
            if url and allowed_file(os.path.basename(url)):
                print('get url')
                data = fetch_url_file(url)
                if data is None:
                    flash('画像を取得できませんでした')
                    return redirect(request.url)
                # Keep the downloaded image in memory.
                img_bytesio = io.BytesIO(data)
                img_pil = Image.open(img_bytesio)
            else:
                print('get file')
                if 'file' not in request.files:
                    flash('ファイルがありません')
                    return redirect(request.url)
                upload = request.files['file']
                if upload.filename == '':
                    flash('ファイルがありません')
                    return redirect(request.url)
                if upload and allowed_file(upload.filename):
                    print("type(file)", type(upload))
                    img_pil = Image.open(upload)
                    print("type(img_pil)", type(img_pil))
                    img_bytesio = io.BytesIO()
                    img_pil.save(img_bytesio, 'png')
                    print("type(img_bytesio)", type(img_bytesio))
        except OSError as e:
            # Pillow reports unreadable or unsupported image data as OSError
            # (UnidentifiedImageError is a subclass).
            print(e)
            flash('画像を読み込めませんでした')
            return redirect(request.url)

        # Predict only when an image was actually opened.
        if img_pil is not None:
            labels = pred_train(img_pil, [img_size] * 2, labels)
            print(labels)
            # Build the image payload handed to index.html.
            qr_b64str = base64.b64encode(img_bytesio.getvalue()).decode("utf-8")
            qr_b64data = "data:image/png;base64,{}".format(qr_b64str)
            check = "1"
    return render_template("index.html", scores=labels, img=qr_b64data, check=check)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and defaults to 8080.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment