Skip to content

Instantly share code, notes, and snippets.

@belltailjp
Last active September 8, 2021 09:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save belltailjp/e7a33d985ade8cefcb939212d39002fe to your computer and use it in GitHub Desktop.
Save belltailjp/e7a33d985ade8cefcb939212d39002fe to your computer and use it in GitHub Desktop.
Optuna hyperparameter search driven by pytorch-pfn-extras (ppe) config
import optuna
# https://optuna.org/
def objective(trial):
    """Toy Optuna objective: minimize ``(x - 2) ** 2`` with ``x`` drawn from [-10, 10].

    ``trial.suggest_float`` is the replacement Optuna recommends for the
    deprecated ``trial.suggest_uniform`` (deprecated since Optuna v3.0).

    Args:
        trial: Optuna trial used to sample the parameter ``'x'``.

    Returns:
        The squared distance of the sampled ``x`` from 2 (0 is optimal).
    """
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2
def main():
    """Optimize ``objective`` over 100 Optuna trials and print the best params."""
    trial_budget = 100
    study = optuna.create_study()
    study.optimize(objective, n_trials=trial_budget)
    best = study.best_params
    print(best)


if __name__ == '__main__':
    main()
import yaml
import optuna
import pytorch_pfn_extras as ppe
class Net:
    """Stand-in "model" whose score is the squared distance of ``value`` from 2."""

    def __init__(self, value):
        self.value = value

    def get_value(self):
        """Return ``(value - 2) ** 2``; lower is better, 0 when ``value == 2``."""
        offset = self.value - 2
        return offset ** 2
# YAML handed to ppe.config.Config. ``type: Net`` names an entry of the
# ``types`` registry, and ``value`` is meant to reach ``Net(value=...)``.
# The keys under ``net:`` MUST be indented: without the nesting, YAML
# parses ``net`` as null and ``type``/``value`` as top-level keys.
cfg_yml = """
net:
  type: Net
  value: 0.4
"""
# Registry mapping the ``type:`` names used in the YAML to Python callables.
types = dict(Net=Net)
def run_train(net):
    """Stand-in for a training run; returns ``net.get_value()`` as the score.

    NOTE(review): for the toy ``Net`` this score is a squared error (a loss,
    lower is better), not an accuracy as the original comment said.
    """
    score = net.get_value()
    return score
def main():
    """Build ``Net`` from ``cfg_yml`` through ppe.config and print its score."""
    raw = yaml.safe_load(cfg_yml)
    config = ppe.config.Config(raw, types)
    model = config['/net']
    print(run_train(model))


if __name__ == "__main__":
    main()
import yaml
import optuna
import pytorch_pfn_extras as ppe
import pytorch_pfn_extras.config_types as config_types
class Net:
    """Toy model: its score is how far ``value`` is from 2, squared."""

    def __init__(self, value):
        self.value = value

    def get_value(self):
        """Return the squared error ``(value - 2) ** 2`` (minimized at 2)."""
        return pow(self.value - 2, 2)
# Search-space config. ``value`` is itself a typed node, so each trial's
# Config presumably builds it by calling the ``optuna_suggest_float`` entry
# (from ``config_types.optuna_types(trial)``) with name/low/high.
# The nesting MUST be preserved: flattened, YAML turns ``net`` and ``value``
# into nulls and the optuna node is never reached.
cfg_yml = """
net:
  type: Net
  value:
    type: optuna_suggest_float
    name: x
    low: -10.0
    high: 10.0
"""
# Project-level ``type:`` names; the per-trial optuna types are merged in later.
types = dict(Net=Net)
def run_train(net):
    """Stand-in for a training run: report ``net.get_value()`` as the result.

    NOTE(review): with the toy ``Net`` this value is a loss (lower is better),
    which is what the Optuna study minimizes, not an accuracy.
    """
    return net.get_value()
def main():
    """Run an Optuna study where every trial builds ``Net`` from ``cfg_yml``.

    Settings that have no place in the YAML (trial count, sampler, storage)
    are chosen here. The best parameters found are printed at the end.
    """
    # These params cannot be included in the config yaml
    n_trials = 100
    sampler_type = optuna.samplers.TPESampler
    sampler_kwargs = {}
    # True: store the study in ./my_study.db (SQLite); False: in memory only.
    # This named flag replaces a hard-coded `if True:` whose else-branch was
    # unreachable.
    use_sqlite = True

    def objective(trial):
        # Parse and build a fresh Config per trial so the optuna_* types
        # are bound to *this* trial.
        cfg = yaml.safe_load(cfg_yml)
        cfg = ppe.config.Config(cfg, {**types, **config_types.optuna_types(trial)})
        net = cfg['/net']
        return run_train(net)

    # SQLite creates my_study.db in the cwd, or opens it if it already exists.
    # No study_name / load_if_exists is passed, so every run starts a *new*
    # study inside that file; earlier runs are not resumed.
    storage = 'sqlite:///my_study.db' if use_sqlite else None
    # run_train returns a loss, so minimize it. This is Optuna's default
    # direction, spelled out so the intent is explicit.
    study = optuna.create_study(
        sampler=sampler_type(**sampler_kwargs),
        storage=storage,
        direction='minimize',
    )
    study.optimize(objective, n_trials=n_trials)
    print(study.best_params)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment