-
-
Save tkota0726/2b2d4cb94eb6b5cd80d243d9bed51923 to your computer and use it in GitHub Desktop.
Optuna Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.applications import InceptionV3, VGG16, MobileNet | |
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.callbacks import History, Callback | |
import tensorflow.keras.backend as K | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
from keras.utils import to_categorical | |
from keras.datasets import cifar10 | |
import numpy as np | |
import os, pickle, glob, zipfile | |
import optuna | |
def create_network(network):
    """Build a 10-class classifier on top of an ImageNet-pretrained backbone.

    Args:
        network: Backbone name; one of "inception", "vgg" or "mobilenet".

    Returns:
        A Keras ``Model`` that maps 128x128x3 inputs to a 10-way softmax.
        Every backbone layer except the last five has ``trainable`` set to
        False.

    Raises:
        ValueError: If ``network`` is not one of the supported names.
    """
    supported = ("inception", "vgg", "mobilenet")
    # Explicit check rather than ``assert`` so the validation is not stripped
    # when Python runs with -O (which would leave the backbone unselected).
    if network not in supported:
        raise ValueError(
            "network must be one of %r, got %r" % (supported, network)
        )

    # Minimum supported resolutions noted by the original author:
    # InceptionV3 >= 75, VGG16 >= 32, MobileNet >= 128, so 128 is used for all.
    input_shape = (128, 128, 3)
    if network == "inception":
        backbone_cls = InceptionV3
    elif network == "vgg":
        backbone_cls = VGG16
    else:
        backbone_cls = MobileNet
    net = backbone_cls(include_top=False, weights="imagenet", input_shape=input_shape)

    # Freeze everything except the last 5 layers.
    for layer in net.layers[:-5]:
        layer.trainable = False

    x = GlobalAveragePooling2D()(net.layers[-1].output)
    x = Dense(10, activation="softmax")(x)
    return Model(net.inputs, x)
class OptunaCallback(Callback):
    """Keras callback that reports validation error to Optuna and prunes trials.

    At the end of every epoch the value ``1 - val_acc`` is reported to the
    wrapped trial; if the trial says it should be pruned, training is aborted
    by raising ``optuna.structs.TrialPruned``.
    """

    def __init__(self, trial):
        """Remember the Optuna trial that intermediate values are reported to.

        Args:
            trial: The trial object handed to the Optuna objective function.
        """
        # Let the base Callback initialise its own state before adding ours.
        super(OptunaCallback, self).__init__()
        self.trial = trial

    def on_epoch_end(self, epoch, logs=None):
        """Report the epoch's validation error and prune the trial if asked.

        Args:
            epoch: Index of the epoch that just finished; used as the
                reporting step.
            logs: Metrics mapping supplied by Keras; must contain "val_acc".

        Raises:
            KeyError: If "val_acc" is missing from ``logs``.
            optuna.structs.TrialPruned: If the trial should be pruned.
        """
        logs = logs or {}
        current_val_error = 1.0 - logs["val_acc"]
        self.trial.report(current_val_error, step=epoch)
        # Pruning decision.
        if self.trial.should_prune(epoch):
            raise optuna.structs.TrialPruned()
def generator(X, y, batch_size):
    """Yield shuffled, preprocessed mini-batches forever.

    Each pass reshuffles the sample indices and yields
    ``X.shape[0] // batch_size`` full batches; any trailing partial batch is
    dropped. Inputs are upscaled by repeating every element 4 times along
    axes 1 and 2 (e.g. 32x32 images become 128x128) and divided by 255.0.
    Labels are one-hot encoded into 10 classes.

    Args:
        X: Array of samples indexed along axis 0.
        y: Array of integer class labels aligned with ``X``.
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuples ``(X_batch, y_batch)``.

    Raises:
        ValueError: If ``batch_size`` is not in ``1..X.shape[0]``. Without
            this check a too-large batch size would yield nothing and spin
            forever in the outer loop, and zero would divide by zero.
    """
    n_samples = X.shape[0]
    if not 0 < batch_size <= n_samples:
        raise ValueError(
            "batch_size must be between 1 and the number of samples (%d), got %r"
            % (n_samples, batch_size)
        )
    steps_per_epoch = n_samples // batch_size

    while True:
        indices = np.arange(n_samples)
        np.random.shuffle(indices)
        for i in range(steps_per_epoch):
            current_indices = indices[i * batch_size:(i + 1) * batch_size]
            X_select = X[current_indices]
            # Upscale by a factor of 4 along axes 1 and 2.
            X_select = X_select.repeat(4, axis=1).repeat(4, axis=2)
            X_batch = X_select / 255.0
            y_batch = to_categorical(y[current_indices], 10)
            yield X_batch, y_batch
def train(network, optimizer, learning_rate, trial):
    """Train a pretrained-backbone classifier on CIFAR-10 using a Colab TPU.

    Args:
        network: Backbone name passed to ``create_network``; also used as the
            name of a directory that is created if missing.
        optimizer: One of "sgd", "momentum", "rmsprop" or "adam".
        learning_rate: Learning rate handed to the chosen optimizer.
        trial: Optuna trial used by ``OptunaCallback`` for reporting/pruning.

    Returns:
        The ``history`` dict collected by the Keras ``History`` callback.

    Raises:
        ValueError: If ``optimizer`` is not a supported name.
        KeyError: If the ``COLAB_TPU_ADDR`` environment variable is not set.
    """
    # Factories are lazy so the optimizer is only built once it is validated.
    optimizer_factories = {
        "sgd": lambda lr: tf.train.GradientDescentOptimizer(lr),
        "momentum": lambda lr: tf.train.MomentumOptimizer(lr, 0.9),
        "rmsprop": lambda lr: tf.train.RMSPropOptimizer(lr),
        "adam": lambda lr: tf.train.AdamOptimizer(lr),
    }
    # Fail fast: previously an unknown name silently left the model uncompiled.
    if optimizer not in optimizer_factories:
        raise ValueError(
            "optimizer must be one of %r, got %r"
            % (sorted(optimizer_factories), optimizer)
        )

    batch_size = 1024
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    train_gen = generator(X_train, y_train, batch_size)
    test_gen = generator(X_test, y_test, batch_size)

    model = create_network(network)
    model.compile(
        optimizer_factories[optimizer](learning_rate),
        "categorical_crossentropy",
        ["acc"],
    )

    # Convert the compiled Keras model into a TPU model.
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    # Create the per-network directory without a check-then-create race.
    # NOTE(review): nothing in this function writes into it.
    os.makedirs(network, exist_ok=True)

    hist = History()
    truncate = OptunaCallback(trial)
    model.fit_generator(
        train_gen,
        X_train.shape[0] // batch_size,
        callbacks=[hist, truncate],
        validation_data=test_gen,
        validation_steps=X_test.shape[0] // batch_size,
        epochs=50,
    )
    return hist.history
def optuna_finding(network):
    """Search optimizer and learning rate for ``network`` with Optuna.

    Runs 50 trials, prints the best parameters, best value and best trial,
    and writes the trials table to "cifar.csv".

    Args:
        network: Backbone name forwarded to ``train``.
    """

    def objective(trial):
        """Return the best validation error (1 - max val_acc) of one trial."""
        # Hyperparameters: search over the optimizer and the learning rate.
        optimizer_name = trial.suggest_categorical(
            "optimizer", ["sgd", "momentum", "rmsprop", "adam"]
        )
        lr = trial.suggest_loguniform("learning_rate", 1e-7, 1e0)

        # Reset the Keras backend session before training this trial's model.
        K.clear_session()
        history = train(network, optimizer_name, lr, trial)
        best_val_acc = np.max(history["val_acc"])
        return 1.0 - best_val_acc

    study = optuna.create_study()
    study.optimize(objective, n_trials=50)

    for attr in ("best_params", "best_value", "best_trial"):
        print(getattr(study, attr))

    study.trials_dataframe().to_csv("cifar.csv")
if __name__ == "__main__":
    # Script entry point: run the hyperparameter search for MobileNet.
    optuna_finding(network="mobilenet")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment