Created November 29, 2018 00:14
Insert BatchNorm into VGG16: rebuild the ImageNet-pretrained network layer by layer so each convolution becomes Conv → BatchNorm → ReLU, then fine-tune on CIFAR-10 (upscaled to 64×64) on a Colab TPU.
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import BatchNormalization, Activation, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
import tensorflow.keras.activations as activations
from tensorflow.keras.callbacks import History
from tensorflow.contrib.tpu.python.tpu import keras_support
import tensorflow.keras.backend as K
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np
import pickle, os
def create_normal_model():
    model = VGG16(include_top=False, input_shape=(64, 64, 3), weights="imagenet")
    x = GlobalAveragePooling2D()(model.layers[-1].output)
    x = Dense(10, activation="softmax")(x)
    # Do not freeze the pretrained weights, since BatchNorm will be inserted
    # later; "transfer learning" here only means ImageNet initialization
    return Model(model.inputs, x)
def create_batch_norm_model():
    model = create_normal_model()
    for i, layer in enumerate(model.layers):
        if i == 0:
            inputs = layer.input
            x = inputs
        else:
            if "conv" in layer.name:
                # Strip the built-in ReLU from the conv layer, then rewire it
                # as Conv -> BatchNorm -> ReLU
                layer.activation = activations.linear
                x = layer(x)
                x = BatchNormalization()(x)
                x = Activation("relu")(x)
            else:
                x = layer(x)
    bn_model = Model(inputs, x)
    return bn_model
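
# Hypothetical sanity check (a sketch, not part of the original gist): in the
# rebuilt model, each Conv2D should now be immediately followed by a
# BatchNormalization and then an Activation layer
def check_bn_insertion():
    m = create_batch_norm_model()
    names = [l.__class__.__name__ for l in m.layers]
    for i, n in enumerate(names):
        if n == "Conv2D":
            assert names[i + 1] == "BatchNormalization"
            assert names[i + 2] == "Activation"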
def generator(X, y, batch_size):
    # Upscaling the whole 32x32 dataset to 64x64 at once eats too much memory,
    # so use a custom generator that upscales one batch at a time
    while True:
        indices = np.arange(X.shape[0])
        np.random.shuffle(indices)
        for i in range(X.shape[0] // batch_size):
            current_indices = indices[i*batch_size:(i+1)*batch_size]
            X_batch = X[current_indices]
            # Nearest-neighbor 2x upscale
            X_batch = X_batch.repeat(2, axis=1).repeat(2, axis=2)
            X_batch = X_batch / 255.0
            y_batch = to_categorical(y[current_indices], 10)
            yield X_batch, y_batch
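
# Illustration only (assumption, not in the original gist): repeat() along the
# height and width axes is a nearest-neighbor resize, so a (N, 32, 32, 3)
# CIFAR batch becomes (N, 64, 64, 3). For example:
#   a = np.arange(4).reshape(1, 2, 2, 1)
#   a.repeat(2, axis=1).repeat(2, axis=2).shape  # -> (1, 4, 4, 1)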
def train(use_batch_norm):
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    if use_batch_norm:
        model = create_batch_norm_model()
    else:
        model = create_normal_model()
    # Use a low learning rate so the pretrained weights are not destroyed
    model.compile(tf.train.RMSPropOptimizer(1e-5), "categorical_crossentropy", ["acc"])
    model.summary()

    # Convert to a TPU model (TF 1.x contrib API, Colab TPU runtime)
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    hist = History()
    batch_size = 1024
    model.fit_generator(generator(X_train, y_train, batch_size),
                        steps_per_epoch=X_train.shape[0] // batch_size,
                        validation_data=generator(X_test, y_test, batch_size),
                        validation_steps=X_test.shape[0] // batch_size,
                        callbacks=[hist], epochs=100)

    with open(f"history_bn_{use_batch_norm}.dat", "wb") as fp:
        pickle.dump(hist.history, fp)
if __name__ == "__main__":
    K.clear_session()
    train(True)
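    # To produce the baseline history file ("history_bn_False.dat") as well,
    # a second run could call the non-BN variant (sketch, same TPU runtime):
    #   K.clear_session()
    #   train(False)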