Created
August 4, 2018 12:30
-
-
Save koshian2/70fe027d789c2181e7f9127924afa1af to your computer and use it in GitHub Desktop.
Simple DenseNet with CIFAR-10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Conv2D, Activation, BatchNormalization, Concatenate, AveragePooling2D, Input, GlobalAveragePooling2D, Dense | |
from keras.models import Model | |
from keras.optimizers import Adam | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.callbacks import Callback | |
import pickle | |
import numpy as np | |
import time | |
class TimeHistory(Callback):
    """Keras callback that records how long each training epoch takes.

    After training, ``self.times`` holds one float (seconds) per completed
    epoch, in order. Durations use ``time.perf_counter()``, a monotonic
    high-resolution clock, so they are not skewed by system clock changes.

    Adapted from
    https://stackoverflow.com/questions/43178668/record-the-computation-time-for-each-epoch-in-keras-during-model-fit
    """

    def on_train_begin(self, logs=None):
        # Reset per training run so a reused callback never mixes timings.
        self.times = []

    def on_epoch_begin(self, batch, logs=None):
        # Keras passes the epoch index as the first argument; the original
        # parameter name ``batch`` is kept for backward compatibility.
        self.epoch_time_start = time.perf_counter()

    def on_epoch_end(self, batch, logs=None):
        self.times.append(time.perf_counter() - self.epoch_time_start)
class DenseNetSimple:
    """A compact DenseNet-style classifier for 32x32x3 images with 10 classes.

    Architecture: 1x1 stem conv (16 filters) -> DenseBlock, then repeated
    (TransitionLayer -> DenseBlock) -> global average pooling -> 10-way softmax.

    Args:
        growth_rate: Number of feature maps each dense layer adds (``k``).
        compression_factor: Fraction of channels kept by each transition layer.
        blocks: Number of dense layers in each dense block. ``(6, 12, 24, 16)``
            matches the DenseNet-121 layout.
        bottleneck_channels: Filters in the 1x1 bottleneck conv inside every
            dense layer (128 by default, as in the original code).
    """

    def __init__(self, growth_rate, compression_factor=0.5, blocks=(1, 2, 4, 3),
                 bottleneck_channels=128):
        # Growth rate: filters added by each dense layer.
        self.k = growth_rate
        # Compression factor: ratio of channels kept by transition layers.
        self.compression = compression_factor
        self.bottleneck_channels = bottleneck_channels
        self.model = self.make_model(blocks)
        # Kept for backward compatibility; train() stores timings in self.time.
        self.elapsed = []

    def dense_block(self, input_tensor, input_channels, nb_blocks):
        """Stack ``nb_blocks`` bottleneck dense layers onto ``input_tensor``.

        Each layer runs BN-ReLU-Conv1x1(bottleneck)-BN-ReLU-Conv3x3(k) and
        concatenates its ``k`` new maps onto its input (last axis, i.e. the
        channel axis for the channels-last input used by this model).

        Returns:
            (output_tensor, output_channels), where
            ``output_channels == input_channels + nb_blocks * k``.
        """
        x = input_tensor
        n_channels = input_channels
        for _ in range(nb_blocks):
            shortcut = x
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # 1x1 bottleneck convolution.
            x = Conv2D(self.bottleneck_channels, (1, 1))(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
            # 3x3 convolution producing growth_rate new feature maps.
            x = Conv2D(self.k, (3, 3), padding="same")(x)
            # Dense connectivity: append the new maps to all previous ones.
            x = Concatenate()([shortcut, x])
            n_channels += self.k
        return x, n_channels

    def transition_layer(self, input_tensor, input_channels):
        """Compress channels with a 1x1 conv, then 2x2 average-pool.

        Returns:
            (output_tensor, output_channels), where
            ``output_channels = max(1, int(input_channels * compression))``.
        """
        # Never compress to zero filters (Conv2D cannot have 0 filters).
        n_channels = max(1, int(input_channels * self.compression))
        # NOTE: the reference DenseNet applies BN before this conv; this
        # simplified variant deliberately does not.
        x = Conv2D(n_channels, (1, 1))(input_tensor)
        x = AveragePooling2D((2, 2))(x)
        return x, n_channels

    def make_model(self, blocks):
        """Build the Keras model with ``blocks[i]`` dense layers in block ``i``."""
        inputs = Input(shape=(32, 32, 3))
        # Start from 16 filters so channel counts stay whole after compression.
        n = 16
        x = Conv2D(n, (1, 1))(inputs)
        # DenseBlock - TransitionLayer - DenseBlock - ...
        for i, nb_layers in enumerate(blocks):
            if i > 0:
                x, n = self.transition_layer(x, n)
            x, n = self.dense_block(x, n, nb_layers)
        # Global average pooling: per-channel mean over the spatial dims.
        x = GlobalAveragePooling2D()(x)
        outputs = Dense(10, activation="softmax")(x)
        return Model(inputs, outputs)

    def train(self, X_train, y_train, X_val, y_val, batch_size=128, epochs=1,
              history_path="history.dat"):
        """Compile and train the model with on-the-fly data augmentation.

        ``X_train`` is fed through an ImageDataGenerator with
        ``rescale=1/255``, so it should contain raw 0-255 pixel values.
        ``X_val`` is passed to Keras unchanged, so it should already be
        scaled to [0, 1].

        Per-epoch durations (seconds) are stored in ``self.time``, and the
        Keras history dict is pickled to ``history_path`` (overwritten).
        """
        self.model.compile(optimizer=Adam(), loss="categorical_crossentropy",
                           metrics=["acc"])
        datagen = ImageDataGenerator(
            rescale=1. / 255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            channel_shift_range=50,
            horizontal_flip=True)
        time_callback = TimeHistory()
        # Integer ceiling division: cover every sample, including the last
        # partial batch, and pass Keras an int rather than a float.
        steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size
        history = self.model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            validation_data=(X_val, y_val),
            epochs=epochs,
            callbacks=[time_callback]).history
        self.time = time_callback.times
        with open(history_path, "wb") as fp:
            pickle.dump(history, fp)
if __name__ == "__main__":
    # Load CIFAR-10. Training images stay as raw uint8 because the
    # augmentation generator used in DenseNetSimple.train rescales by 1/255;
    # test images bypass that generator, so they are scaled here instead.
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X_test = (X_test / 255.0).astype("float32")
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    # Timing experiment: (growth_rate, dense layers per block) configurations.
    # [6, 12, 24, 16] corresponds to the DenseNet-121 layout.
    configs = [
        [16, [1, 2, 4, 3]],
        [32, [1, 2, 4, 3]],
        [16, [2, 4, 8, 5]],
        [16, [3, 6, 12, 8]],
        [16, [6, 12, 24, 16]],
    ]
    elapsed = []
    for growth_rate, blocks in configs:
        # NOTE: every train() call writes history.dat, so only the last
        # configuration's history survives on disk.
        net = DenseNetSimple(growth_rate=growth_rate, blocks=blocks)
        net.train(X_train, y_train, X_test, y_test)
        elapsed.append(net.time)
    with open("elapsed.dat", "wb") as fp:
        pickle.dump(elapsed, fp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment