Created
December 21, 2018 15:05
-
-
Save koshian2/b932eadadcb1934abfbcd5f049c2515d to your computer and use it in GitHub Desktop.
Conv2D, DepthwiseConv2D, SeparableConv2D compare on CIFAR-10 by Optuna
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, SeparableConv2D, BatchNormalization, Activation, Input, GlobalAveragePooling2D, Dense, AveragePooling2D | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.callbacks import Callback, History | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
import tensorflow.keras.backend as K | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
import time, optuna, os | |
def create_conv_layers(mode, input, ch, k=3):
    """Apply one convolution of the requested type, then BatchNorm + ReLU.

    mode: one of "conv", "depthwise", "separable".
    input: input tensor.
    ch: output channel count (DepthwiseConv2D ignores it by design).
    k: kernel size (default 3).
    """
    assert mode in ["conv", "depthwise", "separable"]
    # Dispatch table instead of an if/elif chain; each entry builds a
    # fresh layer only when selected.
    builders = {
        "conv": lambda: Conv2D(ch, k, padding="same"),
        "depthwise": lambda: DepthwiseConv2D(k, padding="same"),
        "separable": lambda: SeparableConv2D(ch, k, padding="same"),
    }
    out = builders[mode]()(input)
    out = BatchNormalization()(out)
    return Activation("relu")(out)
def create_single_module(modes, input, ch):
    """Chain one conv/BN/ReLU group per entry of ``modes`` at width ``ch``.

    Each entry of ``modes`` names the convolution type for that layer
    ("conv", "depthwise", or "separable").
    """
    out = input
    for layer_mode in modes:
        out = create_conv_layers(layer_mode, out, ch)
    return out
def create_models(modes):
    """Build a 3-stage CIFAR-10 classifier from the given module layout.

    Each stage widens the channels with a 1x1 conv, applies the
    user-specified module, and (for the first two stages) halves the
    spatial size with 2x2 average pooling. A global-average-pooled
    softmax head produces the 10-class output.
    """
    input = Input((32, 32, 3))
    x = input
    for stage, width in enumerate([64, 128, 256]):
        # Always insert a 1x1 conv first so the channel count grows even
        # when the module uses DepthwiseConv2D (which cannot change it).
        x = create_conv_layers("conv", x, width, 1)
        x = create_single_module(modes, x, width)
        if stage < 2:
            x = AveragePooling2D(2)(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(10, activation="softmax")(x)
    return Model(input, x)
class OptunaCallback(Callback):
    """Keras callback that reports the validation error to Optuna after
    every epoch and aborts the trial when the pruner asks for it.
    """

    def __init__(self, trial):
        # Fix: the original never called the base-class initializer,
        # leaving Keras Callback attributes (self.params, self.model)
        # unset until Keras assigns them.
        super().__init__()
        self.trial = trial

    def on_epoch_end(self, epoch, logs):
        """Report 1 - val_acc for this epoch; raise TrialPruned if the
        study's pruner decides the trial is unpromising.

        Raises:
            optuna.structs.TrialPruned: when the pruner signals a stop.
        """
        # Optuna minimizes the objective, so report the error rate.
        current_val_error = 1.0 - logs["val_acc"]
        self.trial.report(current_val_error, step=epoch)
        if self.trial.should_prune(epoch):
            print("trial pruned at epoch", epoch)
            raise optuna.structs.TrialPruned()
def train_optuna():
    """Search conv-layer-type combinations on CIFAR-10 with Optuna.

    Each trial picks a type ("conv", "depthwise", "separable") for each
    of the module's three layers, trains the model on a Colab TPU, and
    is scored by validation error plus a parameter-count penalty.
    Results are written to conv_layers_optuna_params.csv.
    """
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Scale pixels to [0, 1] and one-hot encode the 10 class labels.
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    def objectives(trial):
        """Optuna objective: returns validation error + size penalty."""
        # Drop the previous trial's graph so models don't accumulate.
        K.clear_session()
        # One categorical choice (layer type) per layer of the module.
        layer1 = trial.suggest_categorical("layer1", ["conv", "depthwise", "separable"])
        layer2 = trial.suggest_categorical("layer2", ["conv", "depthwise", "separable"])
        layer3 = trial.suggest_categorical("layer3", ["conv", "depthwise", "separable"])
        layers = [layer1, layer2, layer3]
        model = create_models(layers)
        n_params = model.count_params()
        model.compile(tf.train.AdamOptimizer(), "categorical_crossentropy", ["acc"])
        # Convert the model for TPU execution (TF 1.x tf.contrib API,
        # Colab-specific: reads the TPU address from the environment).
        tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
        # NOTE(review): "prauner" looks like a typo for "pruner" — local
        # name only, behavior unaffected.
        prauner = OptunaCallback(trial)
        hist = History()
        model.fit(X_train, y_train, batch_size=1024, validation_data=(X_test, y_test), epochs=100, verbose=0,
                  callbacks=[prauner, hist])
        history = hist.history
        # Record the best accuracy and model size for later analysis.
        trial.set_user_attr("val_acc", max(history["val_acc"]))
        trial.set_user_attr("n_params", n_params)
        # Penalize model size: a 1M-parameter model gets a 20% penalty
        # (n_params / 250000 * 0.05).
        return 1.0 - max(history["val_acc"]) + (n_params / 250000 * 0.05)

    study = optuna.create_study()
    study.optimize(objectives, n_trials=50)
    print(study.best_params)
    print(study.best_value)
    print(study.best_trial)
    # Dump every trial (params, value, user attrs) for offline analysis.
    df = study.trials_dataframe()
    df.to_csv("conv_layers_optuna_params.csv")
# Script entry point: run the full Optuna search when executed directly.
if __name__ == "__main__":
    train_optuna()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment