'''
Trains a simple deep NN on the MNIST dataset.
Gets to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''

batch_size = 128
num_classes = 10 # 手書き数字を10クラスの数字に分類
epochs = 3 # 20にすると精度は98.4%に

# データをランダムに教師データとテストデータに分ける
# xが画像データ(28x28ピクセルの行列, uint8[28][28]の配列)
# yはカテゴリラベルデータ(0から9のinteger, uint8の配列)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 画像データ(訓練データ)を少し表示してみる
fig=plt.figure(figsize=(10, 5))
fig.subplots_adjust(left=0, bottom=0, right=0.5, top=0.5, wspace=0.1, hspace=0.1)
for i in range(50):
    ax=fig.add_subplot(5, 10, i+1, xticks=[], yticks=[])
    ax.imshow(x_train[i].reshape((28, 28)), cmap='gray')
plt.show()

# 28x28=784要素の2次元データを, 要素数784個の1次元データに変換
# (2次元データのまま処理をすすめる方法もある)
# 訓練データは60,000個。テストデータは10,000個
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)

# データをfloat32型に変換し,  [0,255]の値を[0,1.0]に規格化
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

# 各サンプルの個数を確認
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# カテゴリの値[0,9]をバイナリ値(10bitのいずれかのみが1)に変換
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# 入力層(784)=>中間層1(512)=>中間層2(512)=>出力層(10)
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))

# モデルの情報を表示
model.summary()

# NNモデル構築
# - loss: 目的関数
# - optimizer: 学習方法
#   - RMSpropは勾配降下法の修正版。これを使うにはヘッダ部で宣言必要
# - metrics: モデルの正確性の評価に使う指標(学習には使わないが履歴を残せる)
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

# 学習
hist = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))

# 評価
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# 学習経過のグラフ表示
hist_plot(hist)