# Simple DenseNet with CIFAR-10
# Originally published as GitHub Gist koshian2/70fe027d789c2181e7f9127924afa1af
# (created August 4, 2018).
from keras.layers import Conv2D, Activation, BatchNormalization, Concatenate, AveragePooling2D, Input, GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import Callback
import pickle
import numpy as np
import time
# Callback that records the elapsed wall-clock time of each training epoch.
class TimeHistory(Callback):
    """Keras callback collecting per-epoch durations (seconds) in ``self.times``."""

    def on_train_begin(self, logs=None):
        """Reset the collected epoch durations when training starts."""
        self.times = []

    def on_epoch_begin(self, batch, logs=None):
        """Remember when the current epoch started.

        Keras passes the epoch index as the first positional argument; the
        parameter keeps its original name ``batch`` for compatibility.
        """
        self.epoch_time_start = time.time()

    def on_epoch_end(self, batch, logs=None):
        """Append the seconds elapsed since the matching ``on_epoch_begin``."""
        self.times.append(time.time() - self.epoch_time_start)
class DenseNetSimple:
    """Small DenseNet-style CIFAR-10 classifier built with the Keras functional API.

    Layout: 1x1 stem conv -> DenseBlock -> (Transition -> DenseBlock)* ->
    GlobalAveragePooling -> 10-way softmax.
    """

    def __init__(self, growth_rate, compression_factor=0.5, blocks=(1, 2, 4, 3)):
        """Build the Keras model.

        Args:
            growth_rate: number of filters each dense layer adds (k).
            compression_factor: fraction of channels kept by the 1x1 conv of
                each transition layer.
            blocks: number of dense layers in each dense block; a transition
                layer is inserted between consecutive blocks.
                ``(6, 12, 24, 16)`` follows the DenseNet-121 configuration.
        """
        # Growth rate k: filters added by every dense layer.
        self.k = growth_rate
        # Compression ratio used by the transition layers.
        self.compression = compression_factor
        self.model = self.make_model(blocks)
        # Per-epoch training times (seconds); populated by train().
        self.elapsed = []

    def dense_block(self, input_tensor, input_channels, nb_blocks):
        """Append ``nb_blocks`` bottleneck dense layers to ``input_tensor``.

        Each layer is BN-ReLU-Conv1x1(128)-BN-ReLU-Conv3x3(k), and its output
        is concatenated onto its input, so the channel count grows by k per
        layer.

        Returns:
            ``(output_tensor, output_channels)``.
        """
        x = input_tensor
        n_channels = input_channels
        for _ in range(nb_blocks):
            # Keep the incoming tensor for the dense (concatenating) connection.
            shortcut = x
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # Bottleneck 1x1 conv, fixed at 128 filters here
            # (NOTE: the DenseNet-BC paper uses 4*k instead).
            x = Conv2D(128, (1, 1))(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # 3x3 conv producing k (growth rate) new feature maps.
            x = Conv2D(self.k, (3, 3), padding="same")(x)
            # Concatenate new features onto the shortcut (Concatenate's default axis=-1).
            x = Concatenate()([shortcut, x])
            n_channels += self.k
        return x, n_channels

    def transition_layer(self, input_tensor, input_channels):
        """Compress channels with a 1x1 conv, then 2x2 average-pool.

        Returns:
            ``(output_tensor, output_channels)`` with
            ``output_channels == int(input_channels * self.compression)``.
        """
        n_channels = int(input_channels * self.compression)
        x = Conv2D(n_channels, (1, 1))(input_tensor)
        x = AveragePooling2D((2, 2))(x)
        return x, n_channels

    def make_model(self, blocks):
        """Assemble the full network for 32x32x3 inputs and 10 output classes.

        Args:
            blocks: dense layers per dense block, e.g. ``(6, 12, 24, 16)`` for a
                DenseNet-121-like layout.

        Returns:
            An uncompiled ``keras.models.Model``.
        """
        inputs = Input(shape=(32, 32, 3))
        # Start from 16 filters so the channel counts avoid fractional values.
        n = 16
        x = Conv2D(n, (1, 1))(inputs)
        # DenseBlock - TransitionLayer - DenseBlock - ...
        for i, nb_layers in enumerate(blocks):
            if i != 0:
                x, n = self.transition_layer(x, n)
            x, n = self.dense_block(x, n, nb_layers)
        # Global average pooling (mean over each channel).
        x = GlobalAveragePooling2D()(x)
        output = Dense(10, activation="softmax")(x)
        return Model(inputs, output)

    def train(self, X_train, y_train, X_val, y_val, batch_size=128, epochs=1,
              history_path="history.dat"):
        """Compile and train the model with data augmentation.

        Args:
            X_train, y_train: training images and one-hot labels. Images are
                rescaled by 1/255 by the data generator, so pass them
                unnormalized.
            X_val, y_val: validation data, already scaled to match.
            batch_size: mini-batch size.
            epochs: number of training epochs.
            history_path: file the Keras training history is pickled to.

        Returns:
            The Keras ``History.history`` dict (also pickled to ``history_path``).
            Per-epoch durations are stored in ``self.time`` and ``self.elapsed``.
        """
        self.model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["acc"])
        # NOTE(review): the original ImageDataGenerator arguments were lost in
        # the source (unclosed call). rescale=1/255 is required because the
        # caller leaves X_train unnormalized; the shift/flip settings are the
        # usual CIFAR augmentation - confirm against the intended setup.
        datagen = ImageDataGenerator(
            rescale=1.0 / 255,
            width_shift_range=4.0 / 32,
            height_shift_range=4.0 / 32,
            horizontal_flip=True,
        )
        time_callback = TimeHistory()
        # Integer ceiling so the final partial batch is still seen each epoch.
        steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size
        history = self.model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            validation_data=(X_val, y_val),
            epochs=epochs,
            callbacks=[time_callback],
        ).history
        # Keep both attribute names populated: ``time`` was set originally,
        # ``elapsed`` is the attribute initialized in __init__.
        self.time = time_callback.times
        self.elapsed = self.time
        with open(history_path, "wb") as fp:
            pickle.dump(history, fp)
        return history
if __name__ == "__main__":
    # Load CIFAR-10. X_train is left unnormalized; it is presumably rescaled by
    # the ImageDataGenerator inside DenseNetSimple.train() (NOTE(review):
    # confirm, otherwise training and validation inputs differ in scale).
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X_test = (X_test / 255.0).astype("float32")
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)

    # Training-time benchmark over (growth_rate, dense layers per block).
    # [6, 12, 24, 16] corresponds to the DenseNet-121 layout.
    params = [[16, [1, 2, 4, 3]],
              [32, [1, 2, 4, 3]],
              [16, [2, 4, 8, 5]],
              [16, [3, 6, 12, 8]],
              [16, [6, 12, 24, 16]]]
    elapsed = []
    for growth_rate, blocks in params:
        densenet = DenseNetSimple(growth_rate=growth_rate, blocks=blocks)
        densenet.train(X_train, y_train, X_test, y_test)
        # Record this configuration's per-epoch times; previously the list
        # was never filled, so elapsed.dat always held [].
        elapsed.append(densenet.time)
    with open("elapsed.dat", "wb") as fp:
        pickle.dump(elapsed, fp)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment