Last active
June 4, 2019 14:32
-
-
Save koshian2/6bf021b7fc012a96e4138acc31aa28f6 to your computer and use it in GitHub Desktop.
Comparison of Conv2D, DepthwiseConv2D, and SeparableConv2D on CIFAR-10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle
import time

from keras.callbacks import Callback
from keras.datasets import cifar10
from keras.layers import (
    Activation,
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    DepthwiseConv2D,
    GlobalAveragePooling2D,
    Input,
    SeparableConv2D,
)
from keras.models import Model
from keras.utils import to_categorical
def create_block(mode, input, ch):
    """Build a two-layer convolution block of the requested type.

    The block applies, twice in a row, a 3x3 "same"-padded convolution
    followed by BatchNormalization and a ReLU activation. ``mode`` selects
    the convolution layer:

    * 0 -- ``Conv2D`` with ``ch`` filters;
    * 1 -- ``DepthwiseConv2D`` (``ch`` is not used; with the default depth
      multiplier the block keeps the channel count of ``input``);
    * 2 -- ``SeparableConv2D`` with ``ch`` filters.

    Args:
        mode: convolution variant, one of 0, 1, 2.
        input: Keras tensor fed into the block. (The name shadows the
            builtin but is kept for compatibility with existing callers.)
        ch: number of output channels for modes 0 and 2.

    Returns:
        The output tensor of the final ReLU activation.

    Raises:
        ValueError: if ``mode`` is not 0, 1 or 2.
    """
    # Validate before building any layer: an unknown mode used to fall
    # through the branch chain and return None, failing later and obscurely.
    if mode not in (0, 1, 2):
        raise ValueError(
            "mode must be 0 (Conv2D), 1 (DepthwiseConv2D) or "
            "2 (SeparableConv2D), got {!r}".format(mode))

    x = input
    # Two identical conv -> BN -> ReLU stages; layers are created and
    # applied in the same order as the original hand-unrolled version.
    for _ in range(2):
        if mode == 0:
            x = Conv2D(ch, 3, padding="same")(x)
        elif mode == 1:
            x = DepthwiseConv2D(3, padding="same")(x)
        else:
            x = SeparableConv2D(ch, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x
def create_adjusting_block(input, ch):
    """Project ``input`` to ``ch`` channels: 1x1 Conv2D, BatchNorm, ReLU.

    Args:
        input: Keras tensor to project.
        ch: number of output channels of the 1x1 convolution.

    Returns:
        The output tensor of the ReLU activation.
    """
    pipeline = (
        Conv2D(ch, 1),
        BatchNormalization(),
        Activation("relu"),
    )
    tensor = input
    for layer in pipeline:
        tensor = layer(tensor)
    return tensor
def create_model(mode, input_shape=(32, 32, 3), channels=(64, 128, 256),
                 num_classes=10):
    """Build the stacked CIFAR-10 classifier for one convolution variant.

    Each stage is a 1x1 "adjusting" block (``create_adjusting_block``)
    followed by a two-layer block of the chosen type (``create_block``).
    Consecutive stages are separated by 2x2 average pooling. The head is
    global average pooling plus a softmax ``Dense`` layer.

    Args:
        mode: convolution variant forwarded to ``create_block``
            (0 = Conv2D, 1 = DepthwiseConv2D, 2 = SeparableConv2D).
        input_shape: shape of a single input image; defaults to 32x32x3.
        channels: channel count of each stage, in order. The default
            reproduces the original 64 -> 128 -> 256 network.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inputs = Input(input_shape)
    x = inputs
    for stage, ch in enumerate(channels):
        if stage > 0:
            # Downsample between stages, never before the first one.
            x = AveragePooling2D(2)(x)
        # The 1x1 conv sets the channel count to `ch`. This is what lets the
        # depthwise variant widen between stages, since its DepthwiseConv2D
        # layers keep the incoming channel count.
        x = create_adjusting_block(x, ch)
        x = create_block(mode, x, ch)
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation="softmax")(x)
    return Model(inputs, outputs)
class Timer(Callback):
    """Keras callback that records the elapsed time of each training epoch.

    After ``fit`` finishes, ``time_history[i]`` holds the seconds recorded
    for epoch ``i``. The first entry is measured from ``on_train_begin``;
    each later entry is measured from the end of the previous epoch.
    """

    def __init__(self):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.reset()

    def reset(self):
        """Clear the recorded durations and restart the clock."""
        self.time_history = []
        # perf_counter is monotonic, so wall-clock adjustments cannot
        # produce negative or skewed durations the way time.time() can.
        self.start_time = time.perf_counter()

    def on_train_begin(self, logs=None):
        # Start fresh on every fit() so a reused callback does not mix runs.
        self.reset()

    def on_epoch_end(self, epoch, logs=None):
        # One clock reading closes this interval and opens the next.
        now = time.perf_counter()
        self.time_history.append(now - self.start_time)
        self.start_time = now
def train(mode, batch_size=128, epochs=100):
    """Train one convolution variant on CIFAR-10 and return its history.

    Args:
        mode: convolution variant passed to ``create_model``
            (0 = Conv2D, 1 = DepthwiseConv2D, 2 = SeparableConv2D).
        batch_size: mini-batch size for ``model.fit`` (default 128).
        epochs: number of training epochs (default 100).

    Returns:
        The ``history`` dict produced by ``model.fit`` (per-epoch training
        and validation metrics). It has an extra ``"times"`` key holding the
        per-epoch durations in seconds recorded by ``Timer``.
    """
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Scale pixels to [0, 1] in float32. Dividing the uint8 arrays directly
    # would allocate float64 copies (twice the memory) although the network
    # computes in float32 by default; the resulting values are the same.
    X_train = X_train.astype("float32") / 255.0
    X_test = X_test.astype("float32") / 255.0
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    model = create_model(mode)
    model.compile("adam", "categorical_crossentropy", ["acc"])

    timer = Timer()
    history = model.fit(X_train, y_train, batch_size=batch_size,
                        validation_data=(X_test, y_test),
                        callbacks=[timer], epochs=epochs).history
    history["times"] = timer.time_history
    return history
def train_convs(path="history_conv_compare.dat"):
    """Train every convolution variant and pickle their training histories.

    The results dict maps ``"conv"``, ``"depthwise"`` and ``"separable"``
    (modes 0, 1 and 2) to the history returned by ``train``. The file at
    ``path`` is rewritten after each variant finishes, so completed runs are
    kept even if a later run fails.

    Args:
        path: destination file for the pickled results dict.

    Returns:
        The results dict that was written to ``path``.
    """
    result = {}
    for mode, tag in enumerate(["conv", "depthwise", "separable"]):
        result[tag] = train(mode)
        # Checkpoint after every variant: each training run is long, and
        # writing only at the very end lost all results on a crash.
        with open(path, "wb") as fp:
            pickle.dump(result, fp)
    return result
if __name__ == "__main__":
    # Script entry point: train all three variants and pickle the results.
    train_convs()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment