Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created August 20, 2018 19:46
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/99a6e48430f0ab0d7bb9e90d2b90353c to your computer and use it in GitHub Desktop.
Save koshian2/99a6e48430f0ab0d7bb9e90d2b90353c to your computer and use it in GitHub Desktop.
DenseNet CIFAR10 in Keras
from keras.layers import Conv2D, Activation, BatchNormalization, Concatenate, AveragePooling2D, Input, GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.optimizers import Adam, SGD
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import Callback, LearningRateScheduler
from keras.regularizers import l2
import pickle
import numpy as np
import time
# Callback used to measure elapsed time per epoch.
class TimeHistory(Callback):
    """Keras callback that records how long each training epoch takes.

    After training, ``self.times`` holds one duration in seconds per
    completed epoch, in epoch order.
    """
    # https://stackoverflow.com/questions/43178668/record-the-computation-time-for-each-epoch-in-keras-during-model-fit

    def on_train_begin(self, logs=None):
        """Reset the list of per-epoch durations when training starts."""
        self.times = []

    def on_epoch_begin(self, batch, logs=None):
        """Remember the start time of the epoch.

        The first positional argument is named ``batch`` for backward
        compatibility; this hook is the per-epoch one, so it carries the
        epoch index. It is not used here.
        """
        self.epoch_time_start = time.time()

    def on_epoch_end(self, batch, logs=None):
        """Append the duration (seconds) of the epoch that just finished."""
        self.times.append(time.time() - self.epoch_time_start)
class DenseNetSimple:
    """DenseNet-style classifier for 32x32 RGB images with 10 classes.

    The network is a 1x1 stem convolution followed by dense blocks that are
    separated by transition layers, then global average pooling and a
    10-way softmax.

    Attributes:
        k: growth rate, i.e. the number of filters each dense-block step adds.
        compression: ratio by which a transition layer shrinks the channels.
        weight_decay: L2 regularisation factor for the block convolutions.
        bottleneck_filters: filters of the 1x1 bottleneck convolution.
        model: the (uncompiled) Keras model built by ``make_model``.
        elapsed: per-epoch durations in seconds of the last ``train`` call.
        nb_epochs: epoch count of the current run; set by ``train`` and read
            by ``learning_decay``.
    """

    def __init__(self, growth_rate, compression_factor=0.5, blocks=None,
                 bottleneck_filters=128):
        """Build the model.

        Args:
            growth_rate: number of filters added per dense-block step.
            compression_factor: channel ratio kept by each transition layer.
            blocks: number of steps in each dense block; defaults to
                ``[1, 2, 4, 3]``.
            bottleneck_filters: filters of the 1x1 bottleneck convolution
                inside every dense-block step.
        """
        # A None sentinel avoids sharing one mutable default list between calls.
        if blocks is None:
            blocks = [1, 2, 4, 3]
        # Growth rate: number of filters added by each dense-block step.
        self.k = growth_rate
        # Compression factor: ratio of filters kept by a transition layer.
        self.compression = compression_factor
        # Regularisation strength.
        self.weight_decay = 2e-4
        # Width of the 1x1 bottleneck convolution.
        self.bottleneck_filters = bottleneck_filters
        # Build the model (needs the attributes above).
        self.model = self.make_model(blocks)
        # Elapsed time per epoch, filled in by train().
        self.elapsed = []

    # Dense block layer
    def dense_block(self, input_tensor, input_channels, nb_blocks):
        """Apply ``nb_blocks`` dense steps to ``input_tensor``.

        Each step runs BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3) and concatenates
        the result onto its own input.

        Returns:
            Tuple ``(output_tensor, n_channels)`` where ``n_channels`` is
            ``input_channels + nb_blocks * self.k``.
        """
        x = input_tensor
        n_channels = input_channels
        for _ in range(nb_blocks):
            # Main line before branching.
            main = x
            # Branch on the dense-block side.
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # Bottleneck 1x1 convolution.
            x = Conv2D(self.bottleneck_filters, (1, 1),
                       kernel_regularizer=l2(self.weight_decay))(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # 3x3 convolution; the number of filters is the growth rate.
            x = Conv2D(self.k, (3, 3), padding="same",
                       kernel_regularizer=l2(self.weight_decay))(x)
            # Merge with the main line.
            x = Concatenate()([main, x])
            n_channels += self.k
        return x, n_channels

    # Transition layer
    def transition_layer(self, input_tensor, input_channels):
        """Compress the channels with a 1x1 convolution, then 2x2 average-pool.

        Returns:
            Tuple ``(output_tensor, n_channels)`` where ``n_channels`` is
            ``int(input_channels * self.compression)``.
        """
        n_channels = int(input_channels * self.compression)
        # Compress with a 1x1 convolution.
        x = Conv2D(n_channels, (1, 1),
                   kernel_regularizer=l2(self.weight_decay))(input_tensor)
        # Average pooling.
        x = AveragePooling2D((2, 2))(x)
        return x, n_channels

    # Learning-rate decay
    def learning_decay(self, epoch):
        """Return the learning rate for ``epoch``.

        Starts at 0.001 and is divided by 10 once ``epoch`` reaches 50% of
        ``self.nb_epochs`` and again once it reaches 75%.
        """
        x = 0.001
        if epoch >= self.nb_epochs * 0.5:
            x /= 10.0
        if epoch >= self.nb_epochs * 0.75:
            x /= 10.0
        return x

    # Model construction
    def make_model(self, blocks):
        """Assemble the Keras model.

        Args:
            blocks: number of steps per dense block. ``[6, 12, 24, 16]``
                follows the DenseNet-121 layout.

        Returns:
            An uncompiled ``Model`` mapping a (32, 32, 3) input to 10 softmax
            outputs.
        """
        inputs = Input(shape=(32, 32, 3))
        # Start from 16 filters so that compression does not produce fractions.
        n = 16
        x = Conv2D(n, (1, 1))(inputs)
        # DenseBlock - TransitionLayer - DenseBlock ...
        for index, nb_steps in enumerate(blocks):
            # A transition precedes every dense block except the first.
            if index != 0:
                x, n = self.transition_layer(x, n)
            # Dense block.
            x, n = self.dense_block(x, n, nb_steps)
        # Global average pooling (mean over each whole channel).
        x = GlobalAveragePooling2D()(x)
        # Output layer.
        output = Dense(10, activation="softmax")(x)
        # Model.
        model = Model(inputs, output)
        return model

    # Training
    def train(self, X_train, y_train, X_val, y_val, nb_epochs):
        """Compile and train the model with data augmentation.

        Images are expected unscaled: both generators apply ``rescale=1./255``
        followed by sample-wise centring and std normalisation.

        Side effects: sets ``self.nb_epochs`` and ``self.elapsed``, and writes
        the pickled history to ``history.dat`` and the per-epoch durations to
        ``elapsed.dat`` in the working directory.

        Returns:
            The Keras training history dict.
        """
        batch_size = 128
        # Compile.
        self.model.compile(optimizer=Adam(lr=0.001),
                           loss="categorical_crossentropy", metrics=["acc"])
        # Data augmentation.
        # No generator.fit() call: it is only required for feature-wise
        # statistics or ZCA whitening, none of which are enabled here.
        datagen = ImageDataGenerator(
            rescale=1./255,
            samplewise_center=True,
            samplewise_std_normalization=True,
            width_shift_range=4./32,
            height_shift_range=4./32,
            horizontal_flip=True)
        valgen = ImageDataGenerator(
            rescale=1./255,
            samplewise_center=True,
            samplewise_std_normalization=True)
        # Read by learning_decay() during training.
        self.nb_epochs = nb_epochs
        # Callbacks.
        time_callback = TimeHistory()
        lr_decay = LearningRateScheduler(self.learning_decay)
        # Whole number of batches per epoch, counting the last partial batch.
        steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size
        # Training with augmentation.
        history = self.model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            validation_data=valgen.flow(X_val, y_val),
            epochs=nb_epochs,
            callbacks=[time_callback, lr_decay]).history
        self.elapsed = time_callback.times
        # Save.
        with open("history.dat", "wb") as fp:
            pickle.dump(history, fp)
        with open("elapsed.dat", "wb") as fp:
            pickle.dump(time_callback.times, fp)
        return history
if __name__ == "__main__":
    # Growth rate k=16 with the DenseNet-121 block layout.
    densenet = DenseNetSimple(16, blocks=[6,12,24,16])
    densenet.model.summary()
    # Load CIFAR-10.
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Images are passed unscaled: train() builds both generators with
    # rescale=1./255, so dividing X_test by 255 here would scale the
    # validation data twice while the training data is scaled once.
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    # train() requires the number of epochs; the original call omitted it and
    # raised TypeError.
    # NOTE(review): 100 is an assumed value — confirm the intended epoch count.
    nb_epochs = 100
    densenet.train(X_train, y_train, X_test, y_test, nb_epochs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment