Created
August 20, 2018 19:46
-
-
Save koshian2/99a6e48430f0ab0d7bb9e90d2b90353c to your computer and use it in GitHub Desktop.
DenseNet CIFAR10 in Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Conv2D, Activation, BatchNormalization, Concatenate, AveragePooling2D, Input, GlobalAveragePooling2D, Dense | |
from keras.models import Model | |
from keras.optimizers import Adam, SGD | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.callbacks import Callback, LearningRateScheduler | |
from keras.regularizers import l2 | |
import pickle | |
import numpy as np | |
import time | |
# Callback that records the wall-clock duration of every epoch.
class TimeHistory(Callback):
    """Keras callback that collects per-epoch wall-clock times in ``self.times``.

    Based on:
    https://stackoverflow.com/questions/43178668/record-the-computation-time-for-each-epoch-in-keras-during-model-fit

    Attributes set while training:
        times: list of elapsed seconds, one entry per finished epoch.
        epoch_time_start: ``time.time()`` stamp taken when the current
            epoch began.
    """

    # NOTE: ``logs`` defaults to None instead of the original ``{}`` so that
    # no mutable default object is shared between calls. None of the hooks
    # read ``logs``, so behaviour is otherwise unchanged.

    def on_train_begin(self, logs=None):
        """Reset the list of recorded durations at the start of training."""
        self.times = []

    def on_epoch_begin(self, batch, logs=None):
        """Remember when the epoch started.

        The first positional argument keeps its original name ``batch``;
        presumably Keras passes the epoch index here (see the Keras
        ``Callback.on_epoch_begin`` signature). It is not used.
        """
        self.epoch_time_start = time.time()

    def on_epoch_end(self, batch, logs=None):
        """Append the seconds elapsed since ``on_epoch_begin`` to ``times``."""
        self.times.append(time.time() - self.epoch_time_start)
class DenseNetSimple:
    """Simplified DenseNet for 32x32x3 images with a 10-class softmax output.

    The network is: 1x1 convolution -> DenseBlock -> (Transition ->
    DenseBlock) repeated for every further entry of ``blocks`` -> global
    average pooling -> dense softmax layer.

    Attributes:
        k: growth rate, i.e. the number of filters every dense layer adds.
        compression: ratio of channels kept by a transition layer.
        weight_decay: L2 regularisation factor applied to the convolutions
            inside dense blocks and transition layers.
        model: the built (not yet compiled) Keras ``Model``.
        elapsed: per-epoch training times in seconds; filled by ``train``.
    """

    # Number of filters of the 1x1 bottleneck convolution in a dense layer.
    BOTTLENECK_FILTERS = 128
    # Initial learning rate; shared by the optimizer and the decay schedule.
    BASE_LR = 0.001

    def __init__(self, growth_rate, compression_factor=0.5, blocks=(1, 2, 4, 3)):
        """Store the hyper-parameters and build the model.

        Args:
            growth_rate: number of filters added by each dense layer.
            compression_factor: fraction of channels kept by each
                transition layer.
            blocks: number of dense layers per dense block, one entry per
                block. The default is an immutable tuple (the original used
                a mutable list default).
        """
        # Growth rate: number of filters added per dense layer.
        self.k = growth_rate
        # Compression: ratio of filters kept by a transition layer.
        self.compression = compression_factor
        # L2 regularisation factor.
        self.weight_decay = 2e-4
        # Build the model (reads the attributes set above).
        self.model = self.make_model(blocks)
        # Per-epoch elapsed times, populated by train().
        self.elapsed = []

    def dense_block(self, input_tensor, input_channels, nb_blocks):
        """Stack ``nb_blocks`` dense layers on top of ``input_tensor``.

        Each layer runs BN -> ReLU -> 1x1 conv (bottleneck) -> BN -> ReLU ->
        3x3 conv with ``self.k`` filters and concatenates the result with
        the layer's own input.

        Args:
            input_tensor: Keras tensor entering the block.
            input_channels: channel count of ``input_tensor``.
            nb_blocks: number of dense layers to stack.

        Returns:
            Tuple ``(output_tensor, n_channels)`` where ``n_channels`` is
            ``input_channels + nb_blocks * self.k``.
        """
        x = input_tensor
        n_channels = input_channels
        for _ in range(nb_blocks):
            # Keep the trunk so it can be concatenated with the new features.
            main = x
            # Branch: BN -> ReLU.
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # Bottleneck 1x1 convolution.
            x = Conv2D(self.BOTTLENECK_FILTERS, (1, 1),
                       kernel_regularizer=l2(self.weight_decay))(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # 3x3 convolution; the number of filters is the growth rate.
            x = Conv2D(self.k, (3, 3), padding="same",
                       kernel_regularizer=l2(self.weight_decay))(x)
            # Merge the new features back into the trunk.
            x = Concatenate()([main, x])
            n_channels += self.k
        return x, n_channels

    def transition_layer(self, input_tensor, input_channels):
        """Compress the channels with a 1x1 conv, then 2x2 average-pool.

        Args:
            input_tensor: Keras tensor entering the layer.
            input_channels: channel count of ``input_tensor``.

        Returns:
            Tuple ``(output_tensor, n_channels)`` where ``n_channels`` is
            ``int(input_channels * self.compression)``.
        """
        n_channels = int(input_channels * self.compression)
        # 1x1 convolution reduces the channel count.
        x = Conv2D(n_channels, (1, 1),
                   kernel_regularizer=l2(self.weight_decay))(input_tensor)
        # 2x2 average pooling.
        x = AveragePooling2D((2, 2))(x)
        return x, n_channels

    def learning_decay(self, epoch):
        """Learning-rate schedule used by ``LearningRateScheduler``.

        Starts at ``BASE_LR`` and divides by 10 once ``epoch`` reaches 50%
        of ``self.nb_epochs`` and again at 75%. ``self.nb_epochs`` is set
        by ``train`` and must exist before this is called.

        Args:
            epoch: zero-based epoch index.

        Returns:
            The learning rate for that epoch.
        """
        lr = self.BASE_LR
        if epoch >= self.nb_epochs * 0.5:
            lr /= 10.0
        if epoch >= self.nb_epochs * 0.75:
            lr /= 10.0
        return lr

    def make_model(self, blocks):
        """Build the Keras model.

        ``blocks=[6, 12, 24, 16]`` follows the DenseNet-121 block layout.

        Args:
            blocks: number of dense layers per dense block.

        Returns:
            An uncompiled ``Model`` mapping a (32, 32, 3) input to 10
            softmax outputs.
        """
        # ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs = Input(shape=(32, 32, 3))
        # 16 initial filters so that compression never leaves a fraction.
        n = 16
        x = Conv2D(n, (1, 1))(inputs)
        # DenseBlock - Transition - DenseBlock - ...
        for index, nb_layers in enumerate(blocks):
            # A transition layer precedes every block except the first.
            if index != 0:
                x, n = self.transition_layer(x, n)
            x, n = self.dense_block(x, n, nb_layers)
        # Global average pooling (mean over each channel).
        x = GlobalAveragePooling2D()(x)
        # Output layer.
        output = Dense(10, activation="softmax")(x)
        return Model(inputs, output)

    def train(self, X_train, y_train, X_val, y_val, nb_epochs, batch_size=128):
        """Compile the model, train it with augmentation and save the logs.

        Writes the Keras history dict to ``history.dat`` and the list of
        per-epoch times to ``elapsed.dat`` (both pickled) in the current
        working directory.

        Args:
            X_train: training images; rescaled by 1/255 inside the
                generator, so presumably raw 0-255 pixel values.
            y_train: training labels matching the softmax output.
            X_val: validation images, rescaled the same way as X_train.
            y_val: validation labels.
            nb_epochs: number of epochs; also drives ``learning_decay``.
            batch_size: training batch size (default 128, the value that
                was previously hard-coded).

        Returns:
            The Keras history dict (the original returned nothing).
        """
        # Compile.
        self.model.compile(optimizer=Adam(lr=self.BASE_LR),
                           loss="categorical_crossentropy", metrics=["acc"])
        # Data augmentation for training; validation only gets normalised.
        # NOTE: the original called datagen.fit()/valgen.fit() here. Per the
        # Keras documentation fit() is only required for feature-wise
        # statistics or ZCA whitening; only sample-wise options are enabled
        # below, so those calls were dropped.
        datagen = ImageDataGenerator(
            rescale=1. / 255,
            samplewise_center=True,
            samplewise_std_normalization=True,
            width_shift_range=4. / 32,
            height_shift_range=4. / 32,
            horizontal_flip=True)
        valgen = ImageDataGenerator(
            rescale=1. / 255,
            samplewise_center=True,
            samplewise_std_normalization=True)
        # learning_decay() reads this attribute.
        self.nb_epochs = nb_epochs
        # Callbacks: epoch timing and learning-rate schedule.
        time_callback = TimeHistory()
        lr_decay = LearningRateScheduler(self.learning_decay)
        # Integer ceiling division: the original passed the float
        # len(X_train) / 128 as steps_per_epoch.
        steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size
        # Train with augmentation.
        history = self.model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            validation_data=valgen.flow(X_val, y_val),
            epochs=nb_epochs,
            callbacks=[time_callback, lr_decay]).history
        # Keep the timings on the instance as well as on disk.
        self.elapsed = time_callback.times
        # Save.
        with open("history.dat", "wb") as fp:
            pickle.dump(history, fp)
        with open("elapsed.dat", "wb") as fp:
            pickle.dump(time_callback.times, fp)
        return history
if __name__ == "__main__":
    # Number of training epochs. The original call to train() omitted this
    # required argument and therefore raised TypeError; 100 is a placeholder
    # value -- TODO confirm the intended epoch count.
    nb_epochs = 100
    # Growth rate k=16 with four dense blocks of 6, 12, 24 and 16 layers.
    densenet = DenseNetSimple(16, blocks=[6, 12, 24, 16])
    densenet.model.summary()
    # Load CIFAR-10.
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # X_test is intentionally NOT divided by 255 here: train() already applies
    # rescale=1./255 to both the training and the validation generator, so the
    # original division scaled the validation images twice.
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    densenet.train(X_train, y_train, X_test, y_test, nb_epochs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment