@hvy
Last active December 7, 2020 04:45
Optuna example that optimizes a simple quadratic function in parallel using joblib with arbitrary arguments to the objective function.

```python
"""
Optuna example that optimizes a simple quadratic function in parallel using `joblib`, allowing
arbitrary arguments to be passed to the objective function.

Run the example as follows.

    $ python quadratic_joblib_simple.py

If you need to rerun the example and thus delete previous studies, you can use the Optuna CLI.

    $ optuna delete-study --study-name joblib-quadratic --storage "sqlite:///example.db"

See also
https://optuna.readthedocs.io/en/latest/faq.html#how-to-define-objective-functions-that-have-own-arguments
"""

from joblib import Parallel, delayed
import optuna


def print_study(study):
    print('Number of finished trials: ', len(study.trials))
    print('Best trial:')
    trial = study.best_trial
    print('  Value: ', trial.value)
    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))


def optimize(n_trials, min_x, max_x):
    # Each worker loads the shared study from the SQLite storage created in `__main__`.
    study = optuna.load_study(study_name='joblib-quadratic', storage='sqlite:///example.db')
    # You can either use a lambda (as shown here) or define a class that holds the arguments and
    # implements `__call__`.
    study.optimize(lambda trial: objective(trial, min_x, max_x), n_trials=n_trials)


# The objective function takes not only the trial but also additional arguments.
def objective(trial, min_x, max_x):
    # `suggest_float` replaces the deprecated `suggest_uniform`.
    x = trial.suggest_float('x', min_x, max_x)
    return (x - 2) ** 2


if __name__ == '__main__':
    study = optuna.create_study(study_name='joblib-quadratic', storage='sqlite:///example.db')

    # `Study.optimize` arguments.
    n_trials = 10

    # Arbitrary arguments to the objective function.
    min_x = -100
    max_x = 100

    # `joblib` arguments: the number of workers, each of which runs `n_trials` trials.
    n_iterables = 3

    Parallel(n_jobs=-1)(
        delayed(optimize)(n_trials, min_x, max_x) for _ in range(n_iterables))

    # The study reads its trials from the shared storage, so it sees the trials of every worker.
    assert len(study.trials) == n_trials * n_iterables
    print_study(study)
```
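The comment in `optimize` mentions a second way to pass extra arguments: a class that stores them and implements `__call__`. Below is a minimal sketch of that approach, together with a `functools.partial` variant. The names `QuadraticObjective` and `make_partial_objective` are hypothetical, and the objective is assumed to be the same quadratic used in the script.

```python
import functools


class QuadraticObjective:
    """Callable objective that keeps its extra arguments as attributes.

    Hypothetical name; an alternative to the lambda used in `optimize`.
    """

    def __init__(self, min_x, max_x):
        self.min_x = min_x
        self.max_x = max_x

    def __call__(self, trial):
        # Optuna calls the objective with a single `trial` argument;
        # the search bounds come from the instance instead.
        x = trial.suggest_float("x", self.min_x, self.max_x)
        return (x - 2) ** 2


def objective(trial, min_x, max_x):
    # Same objective as in the script, with the bounds as explicit arguments.
    x = trial.suggest_float("x", min_x, max_x)
    return (x - 2) ** 2


def make_partial_objective(min_x, max_x):
    # Hypothetical helper: `functools.partial` binds the bounds by keyword,
    # leaving `trial` as the only positional argument Optuna must supply.
    return functools.partial(objective, min_x=min_x, max_x=max_x)
```

Either object can be passed directly to `Study.optimize`, for example `study.optimize(QuadraticObjective(min_x, max_x), n_trials=n_trials)`. Unlike a lambda, an instance of a module-level class (or a `functools.partial` over a module-level function) can be pickled with the standard `pickle` module, which may matter if you send the objective itself to worker processes instead of building it inside each worker as the script does.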
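Because every `delayed(optimize)(...)` call carries its own arguments, the workers need not run the same number of trials. The sketch below assumes you want a fixed total trial budget rather than `n_trials` per worker; `split_trials` is a hypothetical helper, not part of Optuna or joblib.

```python
def split_trials(total_trials, n_workers):
    """Split a total trial budget into per-worker counts that differ by at most one.

    Hypothetical helper for illustration only.
    """
    if n_workers <= 0:
        raise ValueError("n_workers must be positive")
    if total_trials < 0:
        raise ValueError("total_trials must be non-negative")
    base, remainder = divmod(total_trials, n_workers)
    # The first `remainder` workers each run one extra trial.
    return [base + 1 if i < remainder else base for i in range(n_workers)]
```

The counts could then be dispatched with `Parallel(n_jobs=-1)(delayed(optimize)(k, min_x, max_x) for k in split_trials(31, 3))`, in which case the script's final check would become `assert len(study.trials) == 31`.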