Created
June 20, 2018 12:12
-
-
Save koshian2/d658d9ae18600aaef2ea1891e23c047b to your computer and use it in GitHub Desktop.
BatchNorm with Fashion-MNIST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import fashion_mnist
from keras.layers import Input, Activation, Conv2D, BatchNormalization, Flatten, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import to_categorical

# Load Fashion-MNIST (60k train / 10k test grayscale 28x28 images).
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
# Scale pixel intensities from [0, 255] down to [0, 1].
X_train, X_test = X_train / 255.0, X_test / 255.0
# Append a channel axis so the arrays are rank 4: (N, 28, 28, 1).
X_train, X_test = X_train[:, :, :, np.newaxis], X_test[:, :, :, np.newaxis]
# One-hot encode the 10 class labels.
y_train, y_test = to_categorical(y_train), to_categorical(y_test)
# モデルの作成 | |
def create_model(nb_layers, use_bn, bn_freq=1):
    """Build a small CNN for Fashion-MNIST (28x28x1 input, 10-way softmax).

    Args:
        nb_layers: Number of Conv2D hidden layers.
        use_bn: Whether batch normalization is enabled at all.
        bn_freq: Apply BatchNormalization after every ``bn_freq``-th
            convolution (1-based layer index). A value larger than
            ``nb_layers`` effectively disables batch norm even when
            ``use_bn`` is True.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    # Input layer: single-channel 28x28 images.
    # (Renamed from `input`, which shadowed the builtin of the same name.)
    inputs = Input(shape=(28, 28, 1))
    x = inputs
    # Hidden layers: Conv -> (optional BN) -> ReLU. BN sits before the
    # activation; the filter count grows linearly (2, 4, 6, ...).
    for i in range(nb_layers):
        x = Conv2D(filters=2 * (i + 1), kernel_size=(3, 3))(x)
        if use_bn and ((i + 1) % bn_freq == 0):
            x = BatchNormalization()(x)
        x = Activation("relu")(x)
    # Flatten feature maps for the classifier head.
    x = Flatten()(x)
    # Output layer: 10-way softmax over the Fashion-MNIST classes.
    y = Dense(10, activation="softmax")(x)
    return Model(inputs=inputs, outputs=y)
# Epochs per training run.
nb_epoch = 20

# Accumulators: one history dict and one wall-clock time per configuration.
result = []
times = []

# Seven configurations: BN after every 1st..6th conv layer, plus a 7th run
# whose frequency (7) exceeds the 6 layers and so effectively disables BN.
for bn_freq in range(1, 8):
    print(f"★★ i = {bn_freq - 1}★★")
    # Fresh 6-layer model for this BN frequency.
    model = create_model(6, True, bn_freq)
    model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])
    # Train, timing the entire fit call.
    start_time = time.time()
    history = model.fit(X_train, y_train, batch_size=64,
                        epochs=nb_epoch, validation_data=(X_test, y_test)).history
    elapsed = time.time() - start_time
    result.append(history)
    times.append(elapsed)

# Persist histories and timings for later plotting/analysis.
with open("result_fq.dat", "wb") as fp:
    pickle.dump(result, fp)
with open("times_fq.dat", "wb") as fp:
    pickle.dump(times, fp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment