Created
February 28, 2019 14:29
-
-
Save YHaruoka/c2fd56d1b3a0e1885e4d92628b696477 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras.models import model_from_json | |
from keras.preprocessing.image import load_img, img_to_array | |
# ------------------------------------------------------------------------------------- | |
# モデルの読み込み部 | |
# ------------------------------------------------------------------------------------- | |
# 入力画像サイズ - 訓練時の画像サイズと合わせる | |
INPUT_IMAGE_SIZE = 224 | |
# GrayScaleのときに1、COLORのときに3にする - 訓練時のカラーチャンネル数と合わせる | |
COLOR_CHANNEL = 3 | |
# 確認したい画像へのフルパス | |
TEST_PATH = "..\\test_dataset\\test_image.jpg" | |
# 今回利用するアーキテクチャと重みのファイルへのパス | |
MODEL_ARC_PATH = 'model_architecture.json' | |
WEIGHTS_PATH = 'weights.hdf5' | |
# ------------------------------------------------------------------------------------- | |
# テスト画像入力部 | |
# ------------------------------------------------------------------------------------- | |
# 今回は1枚の画像だが複数画像対応も可能 | |
test_images = [] | |
# テスト画像の入力部分 | |
if COLOR_CHANNEL == 1: | |
img = load_img(TEST_PATH, color_mode = "grayscale", target_size=(INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)) | |
elif COLOR_CHANNEL == 3: | |
img = load_img(TEST_PATH, color_mode = "rgb", target_size=(INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)) | |
array = img_to_array(img) | |
test_images.append(array) | |
test_images = np.array(test_images) | |
# imageの画素値をint型からfloat型にする | |
test_images = test_images.astype('float32') | |
# 画素値を[0~255]⇒[0~1]とする | |
test_images = test_images / 255.0 | |
# ------------------------------------------------------------------------------------- | |
# モデル読み込み部分 | |
# ------------------------------------------------------------------------------------- | |
# JSONファイルからモデルのアーキテクチャを得る | |
model_arc_str = open(MODEL_ARC_PATH).read() | |
model = model_from_json(model_arc_str) | |
# モデル構成の確認 | |
model.summary() | |
# モデルの重みを得る | |
model.load_weights(WEIGHTS_PATH) | |
# ------------------------------------------------------------------------------------- | |
# テスト実行部 | |
# ------------------------------------------------------------------------------------- | |
# テストの実行 | |
result = model.predict(test_images, batch_size=1) | |
# 結果の表示 | |
print('result : ', result) | |
max_index = np.argmax(result) | |
print('max_index : ', max_index) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment