Skip to content

Instantly share code, notes, and snippets.

@kstoneriv3
Created April 9, 2023 22:20
Show Gist options
  • Save kstoneriv3/1e877830e535d9ebea0f9d55ed50734c to your computer and use it in GitHub Desktop.
Save kstoneriv3/1e877830e535d9ebea0f9d55ed50734c to your computer and use it in GitHub Desktop.
Benchmark script for batched (parallel) sampling with BoTorch
from concurrent.futures import ProcessPoolExecutor, wait
import fire
from matplotlib import pyplot as plt
import numpy as np
import optuna
import os
import time
import torch
# Benchmark configuration:
#   N_WORKERS -- size of the process pool, i.e. how many trials run concurrently.
#   N_TRIALS  -- number of one-trial jobs submitted per benchmark run.
N_WORKERS, N_TRIALS = 10, 40
def func(trial, *, n_dims=10, low=-5, high=5, max_sleep=60.0):
    """Objective: sum of squares of ``n_dims`` integer parameters, after a random delay.

    The random ``time.sleep`` emulates an expensive evaluation of variable
    duration, so trials in the worker pool finish at different times.

    Args:
        trial: Optuna trial used to suggest the parameters ``x_0 .. x_{n_dims-1}``.
        n_dims: Number of integer parameters to suggest.
        low: Lower bound passed to ``trial.suggest_int`` for every parameter.
        high: Upper bound passed to ``trial.suggest_int`` for every parameter.
        max_sleep: Upper bound (seconds) of the uniformly random delay;
            ``0`` disables the delay (useful for quick checks).

    Returns:
        ``sum(x_i ** 2)`` over the suggested values; with the default bounds
        the minimum ``0`` is reached at the origin.
    """
    xs = [trial.suggest_int(f"x_{i}", low, high) for i in range(n_dims)]
    # np.random.rand() is uniform on [0, 1), so the delay is in [0, max_sleep).
    time.sleep(max_sleep * np.random.rand())
    return sum(x ** 2 for x in xs)
def get_study(consider_running_trials=False, seed=None):
    """Create the shared benchmark study, or load it if it already exists.

    Every caller opens the same journal file (``./journal.log``, relative to
    the current working directory). The parent process and all workers
    therefore operate on the same ``"study_0"``.

    Args:
        consider_running_trials: Forwarded to ``BoTorchSampler``; controls
            whether the sampler accounts for trials that are still running.
        seed: Seed forwarded to ``BoTorchSampler``; ``None`` leaves it unseeded.

    Returns:
        The ``optuna.Study`` backed by the journal file.
    """
    # NOTE: to run the sampler on a GPU, pass ``device="cuda:0"`` to
    # BoTorchSampler and cap per-process memory, e.g.
    # torch.cuda.set_per_process_memory_fraction(0.8 / N_WORKERS, device="cuda:0").
    storage = optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage("./journal.log"),
    )
    sampler = optuna.integration.BoTorchSampler(
        consider_running_trials=consider_running_trials,
        seed=seed,
    )
    # The objective never reports intermediate values, so pruning should play
    # no role. Passing NopPruner makes that explicit instead of silently
    # falling back to create_study's default MedianPruner.
    pruner = optuna.pruners.NopPruner()
    study = optuna.create_study(
        study_name="study_0",
        storage=storage,
        sampler=sampler,
        pruner=pruner,
        load_if_exists=True,
    )
    return study
def optimize(i, consider_running_trials):
    """Worker entry point: evaluate exactly one trial of the shared study.

    Args:
        i: Index of the job; also used as the sampler seed, so each job's
            sampler gets a distinct seed.
        consider_running_trials: Forwarded to ``get_study``.
    """
    worker_study = get_study(consider_running_trials=consider_running_trials, seed=i)
    worker_study.optimize(func, n_trials=1)
def run(consider_running_trials):
    """Run one benchmark repetition and return the resulting study.

    The run starts from an empty journal. It then submits ``N_TRIALS``
    single-trial jobs to a pool of ``N_WORKERS`` processes, and all jobs
    share the same study.

    Args:
        consider_running_trials: Forwarded to every worker's sampler.

    Returns:
        The study opened in the parent process. It reads the trials that the
        workers wrote to the shared journal.
    """
    # Start from a clean study. Only a missing file is expected here. Any
    # other failure (permissions, a directory at that path, ...) would leave
    # the previous run's trials in the journal; because the study is opened
    # with load_if_exists=True, those would silently skew the results.
    try:
        os.remove("./journal.log")
    except FileNotFoundError:
        pass
    study = get_study()
    with ProcessPoolExecutor(max_workers=N_WORKERS) as executor:
        futures = executor.map(
            optimize, range(N_TRIALS), [consider_running_trials] * N_TRIALS
        )
        # Consuming the results blocks until every job has finished and
        # re-raises any exception raised inside a worker.
        list(futures)
    return study
def benchmark(n_iter=10):
    """Compare BoTorchSampler with and without ``consider_running_trials``.

    Each setting is run ``n_iter`` times. The function prints the best value
    of every run, then the per-setting mean and standard error of the mean.

    Args:
        n_iter: Number of repetitions per setting; must be at least 1.

    Raises:
        ValueError: If ``n_iter`` is smaller than 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    # Workers are spawned rather than forked (CUDA cannot be used in forked
    # subprocesses). force=True keeps a repeated call from raising
    # "context has already been set".
    torch.multiprocessing.set_start_method("spawn", force=True)
    results = {
        consider_running_trials: [
            run(consider_running_trials).best_value for _ in range(n_iter)
        ]
        for consider_running_trials in (False, True)
    }
    print(results)
    print("mean:")
    print({k: np.mean(v) for k, v in results.items()})
    # std / sqrt(n) is the standard error of the mean, not the spread of the
    # runs, so it is labelled accordingly.
    print("standard error:")
    print({k: np.std(v) / np.sqrt(n_iter) for k, v in results.items()})
if __name__ == "__main__":
    # Run 20 repetitions per setting. For a command-line interface, replace
    # this call with ``fire.Fire(benchmark)`` and pass e.g. ``--n_iter=20``.
    benchmark(20)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment