Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created June 20, 2018 12:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/d658d9ae18600aaef2ea1891e23c047b to your computer and use it in GitHub Desktop.
Save koshian2/d658d9ae18600aaef2ea1891e23c047b to your computer and use it in GitHub Desktop.
BatchNorm with Fashion-MNIST
import numpy as np
import matplotlib.pyplot as plt
import pickle
import time
from keras.datasets import fashion_mnist
from keras.models import Model
from keras.layers import Input, Activation, Conv2D, BatchNormalization, Flatten, Dense
from keras.optimizers import Adam
from keras.utils import to_categorical
# Load the Fashion-MNIST train/test splits bundled with Keras.
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# Normalization: divide every pixel value by 255.0
# (presumably the raw values are 0-255 -- confirm against the dataset docs).
X_train = X_train / 255.0
X_test = X_test / 255.0

# Make the image arrays rank 4 by appending a trailing axis of length 1.
X_train = X_train[:, :, :, np.newaxis]
X_test = X_test[:, :, :, np.newaxis]

# Convert the labels to one-hot vectors.
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
# Model construction
def create_model(nb_layers, use_bn, bn_freq=1):
    """Build a stack of Conv2D + ReLU layers topped by a 10-way softmax.

    Args:
        nb_layers: Number of Conv2D layers to stack. Layer ``i`` (0-based)
            uses ``2 * (i + 1)`` filters and a 3x3 kernel.
        use_bn: When true, BatchNormalization is inserted between the
            convolution and its ReLU on the selected layers.
        bn_freq: BatchNormalization is applied on every layer whose 1-based
            index is a multiple of this value. Only consulted when
            ``use_bn`` is true.

    Returns:
        An uncompiled ``Model`` taking a ``(28, 28, 1)`` input and producing
        10 softmax outputs.

    Raises:
        ValueError: If ``use_bn`` is true and ``bn_freq`` is 0.
    """
    # A zero frequency would otherwise surface as a bare ZeroDivisionError
    # from the modulo below, after part of the graph has already been built.
    if use_bn and bn_freq == 0:
        raise ValueError("bn_freq must be non-zero when use_bn is enabled")

    # Input layer (named so that it does not shadow the builtin ``input``).
    inputs = Input(shape=(28, 28, 1))
    x = inputs

    # Hidden layers
    for i in range(nb_layers):
        x = Conv2D(filters=2 * (i + 1), kernel_size=(3, 3))(x)
        # BN sits between the convolution and the activation, and only on
        # layers whose 1-based index is a multiple of bn_freq.
        if use_bn and (i + 1) % bn_freq == 0:
            x = BatchNormalization()(x)
        x = Activation("relu")(x)

    # Flatten the feature maps for the fully connected output layer.
    x = Flatten()(x)

    # Output layer
    y = Dense(10, activation="softmax")(x)

    # Model
    return Model(inputs=inputs, outputs=y)
# Constants
nb_epoch = 20

# Containers for the per-run outputs.
result = []  # one Keras training-history dict per run
times = []   # wall-clock seconds spent in fit() per run

# Six-layer model with BN enabled, seven runs in total: bn_freq takes the
# values 1 through 7 (at 7, no 1-based layer index in 1..6 is a multiple of
# bn_freq, so that run ends up with no BN layer).
for i in range(7):
    print(f"★★ i = {i}★★")

    # Build a fresh model for this BN frequency.
    bn_freq = i + 1
    model = create_model(6, True, bn_freq)

    # Compile
    model.compile(
        optimizer=Adam(),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Fit, timing only the training call itself.
    start_time = time.time()
    fit_output = model.fit(
        X_train,
        y_train,
        batch_size=64,
        epochs=nb_epoch,
        validation_data=(X_test, y_test),
    )
    history = fit_output.history
    elapsed = time.time() - start_time

    # Record this run.
    result.append(history)
    times.append(elapsed)

# Save both result lists as pickles.
for path, payload in (("result_fq.dat", result), ("times_fq.dat", times)):
    with open(path, "wb") as fp:
        pickle.dump(payload, fp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment