Last active
October 27, 2017 15:59
-
-
Save adash333/5d1fe3b0a17059a3ffe3f16063c59054 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from https://github.com/fchollet/keras/blob/keras-2/examples/mnist_mlp.py
# and https://qiita.com/hiroeorz@github/items/ecb39ed4042ebdc0a957
# and http://www.mathgram.xyz/entry/chainer/bake/part3
# and https://qiita.com/haru1977/items/17833e508fe07c004119
"""
以下のようなフォルダ、ファイル構造とする。
(Expected folder/file layout. Each folder name, 0-9, is used as that class's label.)

data/
    0/
        dog001.jpg
        dog002.jpg
        ...
    1/
        cat001.jpg
        cat002.jpg
        ...
    ...
    9/
        cow001.jpg
        cow002.jpg
        ...
"""
#1 Imports (Keras and helpers)
import os

import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Dropout
from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.utils import np_utils
from PIL import Image

try:
    # sklearn.model_selection is the supported home of train_test_split
    # (scikit-learn >= 0.18).
    from sklearn.model_selection import train_test_split
except ImportError:
    # Fallback for old scikit-learn; this module was removed in 0.20.
    from sklearn.cross_validation import train_test_split
#2 Data preparation (Keras)
# Build the training data. Each sub-folder of ./data holds one class, and the
# folder name ("0" ... "9") is that class's label.
image_list = []
label_list = []
# Sort the listings: os.listdir order is arbitrary, and a stable sample order
# keeps the seeded train/test split below reproducible.
for label_dir in sorted(os.listdir("data")):
    dir_path = os.path.join("data", label_dir)
    # Skip hidden entries (e.g. ".DS_Store" on macOS) and anything that is
    # not a folder; listing those as class folders would fail.
    if label_dir.startswith(".") or not os.path.isdir(dir_path):
        continue
    # Folder "0" -> label 0, folder "1" -> label 1, ..., folder "9" -> label 9.
    label = int(label_dir)
    for file_name in sorted(os.listdir(dir_path)):
        # Skip OS metadata files ("Thumbs.db" on Windows, hidden files such
        # as ".DS_Store" on macOS). PIL cannot open them as images.
        if file_name == "Thumbs.db" or file_name.startswith("."):
            continue
        filepath = os.path.join(dir_path, file_name)
        # Load the image, convert it to 8-bit grayscale ("L") and resize it
        # to 28x28 pixels; each pixel is an integer in 0-255. The context
        # manager closes the underlying file handle.
        with Image.open(filepath) as img:
            pixels = np.array(img.convert("L").resize((28, 28)))
        # Flatten to a 784-element float32 vector scaled into [0, 1].
        image_list.append(pixels.reshape(784).astype("float32") / 255.)
        # Record the label only after the image loaded successfully, so the
        # two lists always stay aligned.
        label_list.append(label)
# Convert to numpy arrays for Keras.
image_list = np.array(image_list)
label_list = np.array(label_list)
# Turn the class labels into one-hot vectors for the 10 output classes.
label_list = keras.utils.to_categorical(label_list, 10)
# Split into training and test data: 33% held out, fixed seed 111.
X_train, X_test, y_train, y_test = train_test_split(
    image_list,
    label_list,
    test_size=0.33,
    random_state=111,
)
for split_name, split_data in (('train', X_train), ('test', X_test)):
    print(split_data.shape[0], split_name + ' samples')
#3 Training settings (Keras)
# Mini-batch size used by model.fit.
batch_size = 128
# Number of output classes (folders "0" to "9").
num_classes = 10
# Previously 20; reduced to 3 to keep training runs short.
epochs = 3
# The MNIST loading and label conversion from the upstream mnist_mlp example
# are not used here: X_train/X_test/y_train/y_test come from the ./data
# folder loader earlier in this script.
#3 Model definition (Keras)
# Multilayer perceptron: two 512-unit ReLU hidden layers, each followed by
# 20% dropout, then a softmax output layer with one unit per class.
model = Sequential()
# Derive the input width from the prepared data (784 for flattened 28x28
# images) so it follows the loader if the resize size ever changes.
model.add(Dense(512, activation='relu', input_shape=(X_train.shape[1],)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
# Tie the output width to num_classes so it always matches the one-hot labels.
model.add(Dense(num_classes, activation='softmax'))
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
#4 Training (Keras)
# The held-out test split is also passed as validation data, so per-epoch
# test metrics are reported during training.
history = model.fit(
    X_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(X_test, y_test),
)
#5 Results (Keras)
# evaluate() returns [loss, accuracy] for the compiled metrics.
score = model.evaluate(X_test, y_test, verbose=0)
for caption, value in zip(('Test loss:', 'Test accuracy:'), score):
    print(caption, value)
#6 Save the training result (Keras)
# Save the architecture as JSON and the learned weights as HDF5.
# NOTE(review): the file names are left over from an apple/orange example;
# they are kept unchanged in case the prediction script loads them by name.
json_string = model.to_json()
# The context manager guarantees the JSON file is flushed and closed instead
# of relying on garbage collection of an anonymous file object.
with open('apple_orange_model.json', 'w') as json_file:
    json_file.write(json_string)
model.save_weights('apple_orange_weights.h5')
# predict.py
# Keras2_MNIST_MLP_predict
#7 推論(Keras) — the prediction code is not included in this listing.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment