Skip to content

Instantly share code, notes, and snippets.

@tkota0726
Forked from koshian2/optuna_cifar_keras.py
Created April 28, 2019 14:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tkota0726/2b2d4cb94eb6b5cd80d243d9bed51923 to your computer and use it in GitHub Desktop.
Save tkota0726/2b2d4cb94eb6b5cd80d243d9bed51923 to your computer and use it in GitHub Desktop.
Optuna Keras
import tensorflow as tf
from tensorflow.keras.applications import InceptionV3, VGG16, MobileNet
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import History, Callback
import tensorflow.keras.backend as K
from tensorflow.contrib.tpu.python.tpu import keras_support
from keras.utils import to_categorical
from keras.datasets import cifar10
import numpy as np
import os, pickle, glob, zipfile
import optuna
def create_network(network):
    """Build a 10-class classifier on top of an ImageNet-pretrained backbone.

    Args:
        network: Backbone name; one of "inception", "vgg" or "mobilenet".

    Returns:
        A Keras ``Model`` taking (128, 128, 3) inputs and ending in a
        10-unit softmax layer.

    Raises:
        ValueError: If ``network`` is not one of the supported names.
    """
    supported = ("inception", "vgg", "mobilenet")
    # Use an explicit exception rather than ``assert``. Asserts are stripped
    # under ``python -O``, and then an unknown name would only surface later
    # as a confusing error.
    if network not in supported:
        raise ValueError(
            "network must be one of {}, got {!r}".format(supported, network))

    # Minimum input resolutions: InceptionV3 75+, VGG16 32+, MobileNet 128+,
    # so every backbone is built for 128x128 inputs.
    backbone_cls = {
        "inception": InceptionV3,
        "vgg": VGG16,
        "mobilenet": MobileNet,
    }[network]
    net = backbone_cls(include_top=False, weights="imagenet",
                       input_shape=(128, 128, 3))

    # Freeze every layer except the last 5, so only those are fine-tuned.
    for layer in net.layers[:-5]:
        layer.trainable = False

    # Classification head: global average pooling + 10-way softmax.
    x = GlobalAveragePooling2D()(net.layers[-1].output)
    x = Dense(10, activation="softmax")(x)
    return Model(net.inputs, x)
class OptunaCallback(Callback):
    """Keras callback that reports per-epoch validation error to an Optuna trial.

    After every epoch it reports ``1 - val_acc`` as the trial's intermediate
    value. If the trial should be pruned, it aborts training by raising
    ``TrialPruned``.
    """

    def __init__(self, trial):
        """Store the Optuna ``trial`` that receives the reports."""
        # Run the base-class initializer so the attributes Keras sets on
        # every callback are present before training starts.
        super().__init__()
        self.trial = trial

    def on_epoch_end(self, epoch, logs=None):
        """Report this epoch's validation error and prune if requested.

        Args:
            epoch: Epoch index supplied by Keras; used as the report step.
            logs: Metric dict supplied by Keras. It must contain ``"val_acc"``
                (the model is compiled with the ``"acc"`` metric and
                validation data).

        Raises:
            optuna.structs.TrialPruned: When the trial's pruner says to stop.
        """
        logs = logs or {}
        current_val_error = 1.0 - logs["val_acc"]
        self.trial.report(current_val_error, step=epoch)
        # Pruning check.
        # NOTE(review): ``should_prune(step)`` and ``optuna.structs`` are the
        # older Optuna API. Newer releases use ``should_prune()`` and
        # ``optuna.TrialPruned`` — confirm against the installed version.
        if self.trial.should_prune(epoch):
            raise optuna.structs.TrialPruned()
def generator(X, y, batch_size):
    """Yield (images, one-hot labels) batches forever, reshuffling each pass.

    Each pass over the data draws a fresh random permutation and emits
    ``X.shape[0] // batch_size`` full batches; leftover samples that do not
    fill a batch are skipped.

    For each batch:
    - Images are enlarged 4x along axes 1 and 2 by repeating pixels.
    - Images are then divided by 255.0.
    - Labels are one-hot encoded over 10 classes.
    """
    n_samples = X.shape[0]
    n_full_batches = n_samples // batch_size
    while True:
        order = np.arange(n_samples)
        np.random.shuffle(order)
        for start in range(0, n_full_batches * batch_size, batch_size):
            picked = order[start:start + batch_size]
            upscaled = X[picked].repeat(4, axis=1).repeat(4, axis=2)
            yield upscaled / 255.0, to_categorical(y[picked], 10)
def train(network, optimizer, learning_rate, trial, batch_size=1024, epochs=50):
    """Fine-tune a pretrained backbone on CIFAR-10 using a Colab TPU.

    Args:
        network: Backbone name passed to ``create_network``.
        optimizer: One of "sgd", "momentum", "rmsprop" or "adam".
        learning_rate: Learning rate for the chosen optimizer.
        trial: Optuna trial used by ``OptunaCallback`` for reporting/pruning.
        batch_size: Samples per batch for both training and validation.
        epochs: Maximum number of training epochs.

    Returns:
        The ``History.history`` dict recorded during ``fit_generator``.

    Raises:
        ValueError: If ``optimizer`` is not a supported name.
        KeyError: If the ``COLAB_TPU_ADDR`` environment variable is unset.
        Any exception raised by ``OptunaCallback`` (e.g. pruning) propagates.
    """
    # The factories are lambdas, so no TensorFlow object is created until a
    # valid optimizer name has been chosen.
    optimizer_factories = {
        "sgd": lambda lr: tf.train.GradientDescentOptimizer(lr),
        "momentum": lambda lr: tf.train.MomentumOptimizer(lr, 0.9),
        "rmsprop": lambda lr: tf.train.RMSPropOptimizer(lr),
        "adam": lambda lr: tf.train.AdamOptimizer(lr),
    }
    # Fail fast, before loading data or building the network. An unknown
    # name must not silently leave the model uncompiled.
    if optimizer not in optimizer_factories:
        raise ValueError("optimizer must be one of {}, got {!r}".format(
            sorted(optimizer_factories), optimizer))

    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    train_gen = generator(X_train, y_train, batch_size)
    test_gen = generator(X_test, y_test, batch_size)

    model = create_network(network)
    model.compile(optimizer_factories[optimizer](learning_rate),
                  "categorical_crossentropy", ["acc"])

    # Convert the compiled model into a TPU model. This uses the TF 1.x
    # contrib API; the TPU address comes from Colab's environment.
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    # Per-network directory. Creation is idempotent and race-free.
    os.makedirs(network, exist_ok=True)

    hist = History()
    truncate = OptunaCallback(trial)
    model.fit_generator(train_gen, X_train.shape[0] // batch_size,
                        callbacks=[hist, truncate],
                        validation_data=test_gen,
                        validation_steps=X_test.shape[0] // batch_size,
                        epochs=epochs)
    return hist.history
def optuna_finding(network, n_trials=50, csv_path="cifar.csv"):
    """Search the optimizer and learning rate for ``network`` with Optuna.

    How each trial works:
    - It samples an optimizer name and a log-uniform learning rate in
      [1e-7, 1].
    - It trains via ``train``.
    - It is scored by its best validation error, ``1 - max(val_acc)``.

    The study created by ``optuna.create_study()`` minimises this score.
    Afterwards the best parameters, value and trial are printed, and every
    trial is written to ``csv_path``.

    Args:
        network: Backbone name forwarded to ``train``.
        n_trials: Number of Optuna trials to run (default 50).
        csv_path: Destination CSV for ``study.trials_dataframe()``
            (default "cifar.csv").
    """
    def objective(trial):
        # Hyperparameters under search: optimizer and learning rate.
        optimizer = trial.suggest_categorical("optimizer", ["sgd", "momentum", "rmsprop", "adam"])
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-7, 1e0)
        # Reset Keras global state so graphs from earlier trials do not
        # accumulate.
        K.clear_session()
        hist = train(network, optimizer, learning_rate, trial)
        return 1.0 - np.max(hist["val_acc"])

    study = optuna.create_study()
    study.optimize(objective, n_trials=n_trials)
    print(study.best_params)
    print(study.best_value)
    print(study.best_trial)
    trial_df = study.trials_dataframe()
    trial_df.to_csv(csv_path)
if __name__ == "__main__":
    # Script entry point: run the hyperparameter search for the MobileNet
    # backbone.
    backbone_name = "mobilenet"
    optuna_finding(backbone_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment