Skip to content

Instantly share code, notes, and snippets.

@toshihikoyanase
Created August 7, 2019 05:34
Show Gist options
  • Save toshihikoyanase/9df25c7fcea140bf0ec24fc31a717b20 to your computer and use it in GitHub Desktop.
Save toshihikoyanase/9df25c7fcea140bf0ec24fc31a717b20 to your computer and use it in GitHub Desktop.
An example of pruning trials when they report NaN as an intermediate value.
import math
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
import optuna
class NaNValuePruner(optuna.pruners.BasePruner):
    """Pruner that stops a trial when it reports NaN as an intermediate value.

    If the value reported at the inspected step is not NaN, the decision is
    delegated to the optional wrapped pruner. Without a wrapped pruner the
    trial is kept.
    """

    def __init__(self, pruner=None):
        # type: (Optional[optuna.pruners.BasePruner]) -> None
        """Store the optional pruner consulted for non-NaN values.

        Args:
            pruner: Pruner to delegate to when the inspected value is not
                NaN, or ``None`` to never prune such trials.
        """
        self.pruner = pruner

    def prune(self, storage, study_id, trial_id, step):
        # type: (optuna.storages.BaseStorage, int, int, int) -> bool
        """Decide whether the trial should be pruned at ``step``.

        Args:
            storage: Storage object queried via ``get_trial(trial_id)``.
            study_id: Study identifier, forwarded to the wrapped pruner.
            trial_id: Identifier of the trial being inspected.
            step: Step whose intermediate value is checked.

        Returns:
            ``True`` if the value at ``step`` is NaN, otherwise the wrapped
            pruner's decision, or ``False`` when there is no wrapped pruner
            or nothing has been reported yet.
        """
        intermediate_values = storage.get_trial(trial_id).intermediate_values
        if not intermediate_values:
            # Nothing reported yet: there is no basis for a pruning decision.
            return False

        # A step without a reported value is treated as "not NaN" instead of
        # failing with a KeyError on the lookup.
        value = intermediate_values.get(step)
        if value is not None and math.isnan(value):
            return True

        if self.pruner is not None:
            return self.pruner.prune(storage, study_id, trial_id, step)
        return False
# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):
    """Fit an SGD classifier on iris incrementally and return its test error.

    The classifier is trained for 10 ``partial_fit`` steps. After each step
    the test error (``1.0 - accuracy``) is reported to the trial as an
    intermediate value; with probability 0.1 that value is replaced by NaN
    so that NaN handling in the pruner can be demonstrated.

    Args:
        trial: Trial object used to sample ``alpha``, report intermediate
            values and query whether the trial should be pruned.

    Returns:
        The test error (``1.0 - accuracy``) after the last training step.

    Raises:
        optuna.structs.TrialPruned: If ``trial.should_prune()`` is true
            after reporting an intermediate value.
    """
    # Imported once per call rather than on every loop iteration.
    import random

    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, test_x, train_y, test_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25)

    alpha = trial.suggest_loguniform('alpha', 1e-5, 1e-1)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(10):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(test_x, test_y)
        # Randomly inject NaN to exercise the NaN-handling path.
        if random.random() < 0.1:
            intermediate_value = float('nan')
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.structs.TrialPruned()

    return 1.0 - clf.score(test_x, test_y)
if __name__ == '__main__':
    # Run a short study with the NaN-aware pruner and summarize the outcome.
    study = optuna.create_study(pruner=NaNValuePruner())
    study.optimize(objective, n_trials=20)

    pruned_trials = [t for t in study.trials if t.state == optuna.structs.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.structs.TrialState.COMPLETE]

    print(study.trials_dataframe())

    print('Study statistics: ')
    print(' Number of finished trials: ', len(study.trials))
    print(' Number of pruned trials: ', len(pruned_trials))
    print(' Number of complete trials: ', len(complete_trials))

    print('Best trial:')
    if complete_trials:
        trial = study.best_trial
        print(' Value: ', trial.value)
        print(' Params: ')
        for key, value in trial.params.items():
            print(' {}: {}'.format(key, value))
    else:
        # NaNs are injected at random, so every trial may have been pruned;
        # asking the study for its best trial would raise in that case.
        print(' No trial completed.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment