Skip to content

Instantly share code, notes, and snippets.

@c-bata
Created May 17, 2023 09:36
Show Gist options
Save c-bata/0e739d661e21a5bc4c2ddf2141bf6a9e to your computer and use it in GitHub Desktop.
import sys
import numpy as np
from kurobako import problem
from kurobako.problem import Problem
from typing import List
from typing import Optional
class RastriginEvaluator(problem.Evaluator):
    """Evaluate the Rastrigin function at one fixed point.

    The point is captured from ``params`` when the evaluator is built; a call
    to :meth:`evaluate` returns ``10 * n + sum(x_i**2 - 10 * cos(2 * pi * x_i))``
    as a one-element list. ``None`` entries in ``params`` become NaN because
    the coordinates are converted with ``dtype=float``.
    """

    def __init__(self, params: List[Optional[float]]):
        self.n = len(params)
        coords = np.array(params, dtype=float)
        self.x = coords
        # No evaluation has happened yet.
        self._current_step = 0

    def evaluate(self, next_step: int) -> List[float]:
        """Compute the objective; ``next_step`` is ignored (single-step problem)."""
        self._current_step = 1
        x = self.x
        per_dimension = x**2 - 10 * np.cos(2 * np.pi * x)
        return [10 * self.n + np.sum(per_dimension)]

    def current_step(self) -> int:
        """Return 0 before the first evaluation and 1 afterwards."""
        return self._current_step
class RastriginProblem(problem.Problem):
    """kurobako problem that hands out Rastrigin evaluators; it holds no state."""

    def create_evaluator(
        self, params: List[Optional[float]]
    ) -> Optional[problem.Evaluator]:
        """Build a fresh :class:`RastriginEvaluator` for the given parameters."""
        evaluator = RastriginEvaluator(params)
        return evaluator
class RastriginProblemFactory(problem.ProblemFactory):
    """Factory for Rastrigin problems of a fixed dimensionality.

    Args:
        dim: number of decision variables, named ``x1`` .. ``x{dim}``; each
            is searched over the continuous range [-5.12, 5.12].
    """

    def __init__(self, dim):
        self.dim = dim

    def create_problem(self, seed: int) -> Problem:
        # The objective has no randomness, so ``seed`` is not needed.
        return RastriginProblem()

    def specification(self) -> problem.ProblemSpec:
        """Describe the search space and the single objective value."""
        lower, upper = -5.12, 5.12
        variables = []
        for index in range(1, self.dim + 1):
            variables.append(
                problem.Var(f"x{index}", problem.ContinuousRange(lower, upper))
            )
        return problem.ProblemSpec(
            name=f"Rastrigin (dim={self.dim})",
            params=variables,
            values=[problem.Var("Rastrigin")],
        )
if __name__ == "__main__":
    # The dimensionality comes from the sole CLI argument; default is 2-D.
    if len(sys.argv) == 2:
        dim = int(sys.argv[1])
    else:
        dim = 2
    problem.ProblemRunner(RastriginProblemFactory(dim)).run()
import argparse
import os
import subprocess
def _tee(cmd: list, filename: str) -> None:
    """Run ``cmd``, echo its stdout, and append that stdout to ``filename``.

    Shell-free equivalent of ``cmd | tee -a filename``, except that a non-zero
    exit status of ``cmd`` raises instead of being masked by ``tee``.
    """
    proc = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, text=True)
    # Flush so the echo is not reordered after later child-process output.
    print(proc.stdout, end="", flush=True)
    with open(filename, "a") as f:
        f.write(proc.stdout)


def _read_specs(filename: str) -> list:
    """Return the non-empty lines of ``filename`` as separate argv entries.

    NOTE(review): assumes kurobako writes each spec as one JSON line, which the
    previous ``$(cat file)`` word-splitting already relied on.
    """
    with open(filename) as f:
        return [line for line in f.read().splitlines() if line.strip()]


def run(args: argparse.Namespace) -> None:
    """Benchmark Optuna samplers on a 2-D Rastrigin problem with kurobako.

    Produces ``problems.json``, ``solvers.json``, ``studies.json``,
    ``results.json``, ``report.md`` and an ``images/`` directory under
    ``args.out_dir``. Commands are executed without a shell, so paths with
    spaces or shell metacharacters are passed through intact.

    Args:
        args: parsed CLI options; reads ``path_to_kurobako``, ``out_dir``,
            ``budget``, ``n_runs``, ``seed``, ``n_concurrency`` and ``n_jobs``.

    Raises:
        subprocess.CalledProcessError: if any external command exits non-zero,
            so later steps never run on empty or partial inputs.
    """
    # An empty path_to_kurobako yields plain "kurobako" (looked up on PATH).
    kurobako_cmd = os.path.join(args.path_to_kurobako, "kurobako")
    subprocess.run([kurobako_cmd, "--version"], check=True)

    os.makedirs(args.out_dir, exist_ok=True)
    study_json_fn = os.path.join(args.out_dir, "studies.json")
    solvers_filename = os.path.join(args.out_dir, "solvers.json")
    problems_filename = os.path.join(args.out_dir, "problems.json")
    result_filename = os.path.join(args.out_dir, "results.json")
    report_filename = os.path.join(args.out_dir, "report.md")

    # Ensure all spec files start empty: the steps below append to them.
    for filename in [study_json_fn, solvers_filename, problems_filename]:
        with open(filename, "w"):
            pass

    # Create Rastrigin-2D bench problem.
    _tee(
        [kurobako_cmd, "problem", "command", "python", "problem_rastrigin.py", "2"],
        problems_filename,
    )

    # Create Optuna solvers. No shell is involved, so the JSON kwargs are
    # written plainly instead of with backslash-escaped quotes.
    for name, sampler, sampler_kwargs in [
        ("random", "RandomSampler", "{}"),
        ("vanilla-cma-es", "CmaEsSampler", "{}"),
        ("bipop-cma-es", "CmaEsSampler", '{"restart_strategy":"bipop"}'),
        ("ipop-cma-es", "CmaEsSampler", '{"restart_strategy":"ipop"}'),
    ]:
        _tee(
            [
                kurobako_cmd, "solver", "--name", name, "optuna",
                "--loglevel", "debug",
                "--sampler", sampler, "--sampler-kwargs", sampler_kwargs,
                "--pruner", "NopPruner", "--pruner-kwargs", "{}",
            ],
            solvers_filename,
        )

    # Create study. Each spec line becomes its own argument; no word
    # splitting or globbing can mangle the JSON.
    studies_cmd = [
        kurobako_cmd, "studies", "--budget", str(args.budget),
        "--solvers", *_read_specs(solvers_filename),
        "--problems", *_read_specs(problems_filename),
        "--repeats", str(args.n_runs), "--seed", str(args.seed),
        "--concurrency", str(args.n_concurrency),
    ]
    with open(study_json_fn, "w") as out:
        subprocess.run(studies_cmd, stdout=out, check=True)

    # Run every study.
    with open(study_json_fn) as studies, open(result_filename, "w") as out:
        subprocess.run(
            [kurobako_cmd, "run", "--parallelism", str(args.n_jobs)],
            stdin=studies, stdout=out, check=True,
        )

    # Markdown summary of the results.
    with open(result_filename) as results, open(report_filename, "w") as out:
        subprocess.run(
            [kurobako_cmd, "report"], stdin=results, stdout=out, check=True
        )

    # Plot optimization curves via the kurobako Docker image. The bind mount
    # needs an absolute host path (the old "$(pwd)/out_dir" broke when out_dir
    # was absolute). Pre-creating the directory keeps Docker from creating it
    # as root.
    images_dir = os.path.abspath(os.path.join(args.out_dir, "images"))
    os.makedirs(images_dir, exist_ok=True)
    with open(result_filename) as results:
        subprocess.run(
            [
                "docker", "run", "-v", f"{images_dir}:/images/", "--rm", "-i",
                "sile/kurobako", "plot", "curve", "--errorbar", "--xmin", "10",
            ],
            stdin=results, check=True,
        )
if __name__ == "__main__":
    # (flag, type, default) for every CLI option, in help-output order.
    _cli_options = [
        ("--path-to-kurobako", str, ""),
        ("--budget", int, 5000),
        ("--n-runs", int, 10),
        ("--n-jobs", int, 10),
        ("--n-concurrency", int, 1),
        ("--seed", int, 0),
        ("--out-dir", str, "tmp/benchmark_report"),
    ]
    parser = argparse.ArgumentParser()
    for flag, kind, default in _cli_options:
        parser.add_argument(flag, type=kind, default=default)
    args = parser.parse_args()
    run(args)
@c-bata
Copy link
Author

c-bata commented May 17, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment