Skip to content

Instantly share code, notes, and snippets.

@keisuke-umezawa
Created April 9, 2023 07:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keisuke-umezawa/d26f2ad40f52d4436265e6f88b1df036 to your computer and use it in GitHub Desktop.
Save keisuke-umezawa/d26f2ad40f52d4436265e6f88b1df036 to your computer and use it in GitHub Desktop.
Human-in-the-loop (HITL) optimization with Optuna
import os
import textwrap
import time
from typing import NoReturn
import optuna
from PIL import Image
from optuna.trial import TrialState
from optuna_dashboard import ObjectiveChoiceWidget, save_note
from optuna_dashboard import register_objective_form_widgets
from optuna_dashboard import set_objective_names
from optuna_dashboard.artifact import get_artifact_path, upload_artifact
from optuna_dashboard.artifact.file_system import FileSystemBackend
# Optuna storage shared with optuna-dashboard. With three slashes this is a
# relative SQLite path (SQLAlchemy URL convention), i.e. relative to the
# current working directory, not to this script.
url = "sqlite:///db.sqlite3"
storage = optuna.storages.RDBStorage(url=url)

# Both directories sit next to this script; main() creates them on start-up.
# `artifact_path` backs the artifact store, `tmp_path` is a scratch directory.
artifact_path, tmp_path = (
    os.path.join(os.path.dirname(__file__), subdir) for subdir in ("artifact", "tmp")
)
artifact_backend = FileSystemBackend(base_path=artifact_path)
def suggest_and_generate_image(study: optuna.Study) -> None:
    """Ask one new trial, render its color swatch and attach it to the trial note.

    A trial is created with ``study.ask()`` and three integers ``r``, ``g``
    and ``b`` in ``[0, 255]`` are suggested. A 320x240 RGB image filled with
    that color is written into ``tmp_path``, uploaded to ``artifact_backend``
    and embedded in the trial's Markdown note.

    The trial is not told a value here, so it stays RUNNING; the caller's loop
    counts such RUNNING trials, and the objective value is expected to come
    from the dashboard form (presumably filled in by a human).

    Args:
        study: Study from which the new trial is asked.
    """
    # Ask new parameters.
    trial = study.ask()
    r = trial.suggest_int("r", 0, 255)
    g = trial.suggest_int("g", 0, 255)
    b = trial.suggest_int("b", 0, 255)

    # Generate the image inside tmp_path (the directory main() creates, anchored
    # to this script's directory) rather than a CWD-relative "tmp/", so running
    # the script from another working directory does not fail on save().
    image_path = os.path.join(tmp_path, f"sample-{trial.number}.png")
    image = Image.new("RGB", (320, 240), color=(r, g, b))
    image.save(image_path)

    # Upload the image as an artifact and get the path used to reference it.
    # Named `artifact_url` so it does not shadow the module-level `artifact_path`.
    artifact_id = upload_artifact(artifact_backend, trial, image_path)
    artifact_url = get_artifact_path(trial, artifact_id)

    # Save a Markdown note embedding the generated image.
    note = textwrap.dedent(
        f"""\
        ## Trial {trial.number}
        ![generated-image]({artifact_url})
        """
    )
    save_note(trial, note)
def start_preferential_optimization(
    *,
    n_batch: int = 8,
    seed: int = 42,
    poll_interval: float = 1.0,
) -> NoReturn:
    """Run the human-in-the-loop optimization loop forever.

    Creates (or reloads) the "Preferential Optimization" study in ``storage``,
    registers the dashboard form used to score each trial, then keeps at most
    ``n_batch`` trials RUNNING. Whenever fewer are running, a new trial is
    asked and its image generated; otherwise the loop sleeps and polls again.

    Args:
        n_batch: Maximum number of RUNNING (not yet scored) trials to keep.
        seed: Seed passed to the TPE sampler.
        poll_interval: Seconds to sleep between checks while the batch is full.

    Raises:
        ValueError: If ``n_batch`` is less than 1 (the loop could never
            create a trial and would sleep forever).
    """
    if n_batch < 1:
        raise ValueError(f"n_batch must be >= 1, got {n_batch}")

    # Create Study. constant_liar=True makes TPE account for RUNNING trials
    # when suggesting, so trials of one batch are not sampled on top of each other.
    sampler = optuna.samplers.TPESampler(constant_liar=True, seed=seed)
    study = optuna.create_study(
        study_name="Preferential Optimization",
        storage=storage,
        sampler=sampler,
        load_if_exists=True,
    )

    # HACK: swap the cached storage wrapper for its underlying RDB backend.
    # Presumably done so that trial updates written by another process (the
    # dashboard) are visible to get_trials() below rather than served stale
    # from the cache. Relies on private Optuna internals (`_storage`,
    # `_CachedStorage`, `_backend`) and may break on Optuna upgrades.
    orig_storage = study._storage
    if isinstance(orig_storage, optuna.storages._cached_storage._CachedStorage):
        study._storage = orig_storage._backend

    # Configure how optuna-dashboard asks for the objective value. The study
    # direction is the default, and "Good" maps to the lowest value (-1).
    set_objective_names(study, ["Looks like sunset color?"])
    register_objective_form_widgets(
        study,
        widgets=[
            ObjectiveChoiceWidget(
                choices=["Good 👍", "So-so👌", "Bad 👎"],
                values=[-1, 0, 1],
                description="Please input your score!",
            ),
        ],
    )

    # Start Preferential Optimization: keep the batch of unscored trials full.
    while True:
        running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))
        if len(running_trials) >= n_batch:
            print("sleep")
            time.sleep(poll_interval)
            continue
        suggest_and_generate_image(study)
def main() -> None:
    """Create the working directories, then run the optimization loop.

    ``artifact_path`` and ``tmp_path`` are created if missing. The call to
    ``start_preferential_optimization()`` loops forever, so this function
    does not return normally.
    """
    # makedirs(..., exist_ok=True) replaces the exists()/mkdir() pair. It
    # avoids the check-then-create race and also creates missing parents.
    for path in (artifact_path, tmp_path):
        os.makedirs(path, exist_ok=True)

    # Run optimize loop.
    start_preferential_optimization()
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment