Skip to content

Instantly share code, notes, and snippets.

@koshian2
Last active June 4, 2019 14:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save koshian2/6bf021b7fc012a96e4138acc31aa28f6 to your computer and use it in GitHub Desktop.
Save koshian2/6bf021b7fc012a96e4138acc31aa28f6 to your computer and use it in GitHub Desktop.
Comparison of Conv2D, DepthwiseConv2D, and SeparableConv2D on CIFAR-10
from keras.layers import Conv2D, DepthwiseConv2D, SeparableConv2D, BatchNormalization, Activation, Input, GlobalAveragePooling2D, AveragePooling2D, Dense
from keras.models import Model
from keras.callbacks import Callback
from keras.datasets import cifar10
from keras.utils import to_categorical
import time, pickle
def create_block(mode, input, ch):
    """Stack two (3x3 convolution -> batch-norm -> ReLU) stages on ``input``.

    Args:
        mode: Selects the convolution layer used in both stages:
            0 = ``Conv2D``, 1 = ``DepthwiseConv2D``, 2 = ``SeparableConv2D``.
        input: Tensor the first convolution is applied to.
        ch: Filter count passed to ``Conv2D`` / ``SeparableConv2D``. Unused
            when ``mode`` is 1, because ``DepthwiseConv2D`` is constructed
            without a filter count here.

    Returns:
        The output tensor of the second ReLU activation.

    Raises:
        ValueError: If ``mode`` is not 0, 1 or 2.
    """
    # Validate up front: previously an unknown mode fell through the
    # if/elif chain and returned None, which only blew up later in the
    # caller with an unrelated-looking error.
    if mode not in (0, 1, 2):
        raise ValueError(
            "mode must be 0 (conv), 1 (depthwise) or 2 (separable), "
            "got {!r}".format(mode)
        )

    def make_conv():
        # Build a fresh 3x3 "same"-padded convolution of the selected kind.
        if mode == 0:
            return Conv2D(ch, 3, padding="same")
        if mode == 1:
            return DepthwiseConv2D(3, padding="same")
        return SeparableConv2D(ch, 3, padding="same")

    x = input
    for _ in range(2):
        x = make_conv()(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x
def create_adjusting_block(input, ch):
    """Apply a 1x1 convolution with ``ch`` filters, batch-norm and ReLU.

    Args:
        input: Tensor to transform.
        ch: Number of filters of the 1x1 convolution.

    Returns:
        The output tensor of the ReLU activation.
    """
    pipeline = (
        Conv2D(ch, 1),
        BatchNormalization(),
        Activation("relu"),
    )
    out = input
    for layer in pipeline:
        out = layer(out)
    return out
def create_model(mode):
    """Build the classifier whose main blocks use the conv type ``mode``.

    Three stages with 64, 128 and 256 channels are chained, with 2x2 average
    pooling between consecutive stages, followed by global average pooling
    and a 10-way softmax layer.

    Args:
        mode: Convolution variant forwarded to ``create_block``.

    Returns:
        The (uncompiled) Keras ``Model``.
    """
    inputs = Input((32, 32, 3))
    features = inputs
    for stage, width in enumerate((64, 128, 256)):
        if stage > 0:
            # Downsample between stages (not before the first one).
            features = AveragePooling2D(2)(features)
        # Channel-adjusting 1x1 block, needed for the depthwise variant.
        features = create_adjusting_block(features, width)
        features = create_block(mode, features, width)
    pooled = GlobalAveragePooling2D()(features)
    outputs = Dense(10, activation="softmax")(pooled)
    return Model(inputs, outputs)
class Timer(Callback):
    """Keras callback that records how long each training epoch takes.

    State is (re)initialised by ``reset``:
        time_history: list of elapsed seconds, one entry appended per
            ``on_epoch_end`` call.
        start_time: ``time.perf_counter()`` reading taken at the last reset
            or epoch end; only differences between readings are meaningful.
    """

    def __init__(self):
        # Let the Keras base class initialise its own state before ours.
        super().__init__()
        self.reset()

    def reset(self):
        """Clear the recorded durations and restart the clock."""
        self.time_history = []
        # perf_counter is meant for measuring intervals; unlike time.time()
        # it is not tied to the adjustable wall clock.
        self.start_time = time.perf_counter()

    def on_train_begin(self, logs=None):
        """Start from a clean history whenever a training run begins."""
        self.reset()

    def on_epoch_end(self, epoch, logs=None):
        """Append the time elapsed since the previous mark, then re-mark.

        The interval runs from the previous ``on_epoch_end`` call (or from
        ``on_train_begin`` for the first epoch) up to this call.
        """
        # Take a single reading so no time is lost between the two uses.
        now = time.perf_counter()
        self.time_history.append(now - self.start_time)
        self.start_time = now
def train(mode, epochs=100, batch_size=128):
    """Train one CIFAR-10 model variant and return its training history.

    Args:
        mode: Convolution variant forwarded to ``create_model``.
        epochs: Number of training epochs (default 100, the value that was
            previously hard-coded).
        batch_size: Mini-batch size (default 128, the value that was
            previously hard-coded).

    Returns:
        The ``history`` dict of the object returned by ``model.fit``, with an
        extra ``"times"`` key holding the per-epoch durations in seconds
        recorded by ``Timer``.
    """
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Cast to float32 before scaling by 1/255: dividing the raw arrays by a
    # Python float directly would materialise float64 copies of the whole
    # dataset, twice the size of float32.
    # NOTE(review): assumes load_data returns integer (uint8) images, as the
    # Keras dataset docs state.
    X_train = X_train.astype("float32") / 255.0
    X_test = X_test.astype("float32") / 255.0
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    model = create_model(mode)
    model.compile("adam", "categorical_crossentropy", ["acc"])

    timer = Timer()
    history = model.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        callbacks=[timer],
        epochs=epochs,
    ).history
    history["times"] = timer.time_history
    return history
def train_convs():
    """Train all three variants in turn and pickle their histories together.

    The histories are stored in one dict keyed by variant tag and written to
    ``history_conv_compare.dat`` after every run has finished.
    """
    tags = ("conv", "depthwise", "separable")
    # A tag's position in the tuple is the mode number handed to train().
    histories = {tag: train(mode) for mode, tag in enumerate(tags)}
    with open("history_conv_compare.dat", "wb") as out_file:
        pickle.dump(histories, out_file)
if __name__ == "__main__":
    # Run the full three-way comparison only when executed as a script.
    train_convs()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment