Created
June 20, 2018 09:42
-
-
Save koshian2/de97a35e4e4f4860c86f90d0620fc76a to your computer and use it in GitHub Desktop.
BatchNorm with Fashion-MNIST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
from keras import backend as K
from keras.datasets import fashion_mnist
from keras.layers import Input, Activation, Conv2D, BatchNormalization, Flatten, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import to_categorical
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
# Scale pixel intensities into [0, 1]. This is min-max scaling, not
# standardization. Casting to float32 first avoids materialising float64
# copies of the whole dataset; Keras trains in float32 by default anyway.
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0
# Append a channel axis so each split becomes rank 4, (N, 28, 28, 1), which
# matches the Input(shape=(28, 28, 1)) used by create_model.
X_train, X_test = X_train[..., np.newaxis], X_test[..., np.newaxis]
# One-hot encode the integer labels. num_classes is explicit so the encoding
# always has 10 columns, matching the Dense(10) output layer.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
def create_model(nb_layers, use_bn, bn_freq=1):
    """Build an (uncompiled) CNN classifier for 28x28x1 inputs and 10 classes.

    The network stacks ``nb_layers`` 3x3 convolutions. The i-th layer
    (0-based) has ``2 * (i + 1)`` filters, and each layer is followed by an
    optional BatchNormalization and a ReLU. Flatten and a softmax Dense(10)
    layer come last.

    Args:
        nb_layers: Number of convolutional layers. No padding is requested, so
            with Keras' default ``padding="valid"`` each 3x3 conv trims the
            spatial size by 2. A 28x28 input therefore supports at most 13
            layers.
        use_bn: Whether to insert BatchNormalization layers at all.
        bn_freq: Insert BN after every ``bn_freq``-th conv layer, counting
            from 1. The default of 1 means after every layer. Only consulted
            when ``use_bn`` is true.

    Returns:
        A ``keras.models.Model`` mapping the image input to class
        probabilities.

    Raises:
        ValueError: If ``use_bn`` is true and ``bn_freq`` is less than 1.
    """
    # Fail early with a clear message. Otherwise bn_freq == 0 would raise
    # ZeroDivisionError mid-build, and negative values would place BN
    # layers in a confusing way.
    if use_bn and bn_freq < 1:
        raise ValueError(f"bn_freq must be >= 1 when use_bn is set, got {bn_freq}")

    # Named `inputs` (not `input`) so the builtin is not shadowed.
    inputs = Input(shape=(28, 28, 1))
    x = inputs
    # Hidden layers. The conv bias is redundant when BN follows, because BN
    # removes the per-channel mean. It is kept so that the with-BN and
    # without-BN models use identical conv layers.
    for i in range(nb_layers):
        x = Conv2D(filters=2 * (i + 1), kernel_size=(3, 3))(x)
        if use_bn and (i + 1) % bn_freq == 0:
            x = BatchNormalization()(x)
        x = Activation("relu")(x)
    # Classifier head.
    x = Flatten()(x)
    outputs = Dense(10, activation="softmax")(x)
    return Model(inputs=inputs, outputs=outputs)
# Experiment settings.
nb_epoch = 20
max_layers = 10  # depths 1..max_layers are each trained without and with BN

# Per-run training histories (the dicts from History.history) and training
# durations in seconds, appended in run order.
result = []
times = []

# max_layers depths x {without BN, with BN} runs. Run i uses i // 2 + 1 conv
# layers and BN exactly when i is odd, so runs are ordered
# (1, no BN), (1, BN), (2, no BN), ...
for i in range(2 * max_layers):
    nb_layers = i // 2 + 1
    use_bn = i % 2 == 1
    print(f"★★ i = {i}, #layers = {nb_layers}, bn = {use_bn} ★★")
    # Discard the global Keras state left by the previous model. Otherwise,
    # building many models in one process keeps accumulating memory.
    K.clear_session()
    model = create_model(nb_layers, use_bn)
    model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])
    # perf_counter is monotonic and high-resolution, so it is the right
    # clock for durations. time.time() can jump if the system clock changes.
    start_time = time.perf_counter()
    history = model.fit(X_train, y_train, batch_size=64,
                        epochs=nb_epoch, validation_data=(X_test, y_test)).history
    elapsed = time.perf_counter() - start_time
    result.append(history)
    times.append(elapsed)
# Persist the collected histories and timings with pickle. result.dat is
# written first, then times.dat; both can be reloaded with pickle.load.
for out_path, payload in (("result.dat", result), ("times.dat", times)):
    with open(out_path, "wb") as fp:
        pickle.dump(payload, fp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment