Created
October 20, 2023 10:39
-
-
Save vtalpaert/7eed87a9614659f837bb6eb0cade001e to your computer and use it in GitHub Desktop.
This gist is an overcomplicated way to reproduce `study.optimize(objective, n_trials=n_trials, n_jobs=more_than_one)` for CUDA libraries that do not play well with a background-thread context; instead of threads, it uses multiprocessing. This is useful when you encounter errors such as `CUDA_ERROR_ILLEGAL_ADDRESS` or `CUBLAS_STATUS_NOT_INITIALIZED`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import multiprocessing
import os
import time

import optuna
def objective(trial):
    """Score one Optuna trial.

    Placeholder implementation for the gist: every trial gets the
    constant score 1; replace with a real evaluation.
    """
    score = 1
    return score
def optimize(study_name, n_trials, sleep):
    """Worker entry point: run ``n_trials`` trials against an existing study.

    Args:
        study_name: Name of the study already created by the parent process.
        n_trials: Number of trials this worker should execute.
        sleep: Seconds to wait before touching the journal storage.
    """
    print(
        f"Starting study {study_name} for {n_trials} trials, waiting {sleep}")
    # Stagger the workers: if every process opened the journal at the
    # same moment, they would all sample identical trial parameters.
    time.sleep(sleep)
    journal = optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage(f"optuna/{study_name}.log")
    )
    # Load the study the parent created — never re-create it here.
    loaded = optuna.load_study(study_name=study_name, storage=journal)
    # n_jobs=1 keeps the CUDA library (pykeops in the author's case) out of
    # background threads; parallelism comes from the process pool instead.
    loaded.optimize(objective, n_trials=n_trials, n_jobs=1)
if __name__ == "__main__": | |
n_trials = 1000 | |
n_jobs = 20 | |
study_name = "my-gist" | |
# don't create more jobs than cpu threads | |
_n_jobs = min(n_jobs, os.cpu_count() - 1) | |
# wait between jobs | |
_wait = 5 # [s] | |
# create storage | |
os.makedirs("optuna", exist_ok=True) | |
storage = optuna.storages.JournalStorage( | |
optuna.storages.JournalFileStorage(f"optuna/{study_name}.log") | |
) | |
# create study, adapt load_if_exists | |
study = optuna.create_study( | |
direction="maximize", | |
study_name=study_name, | |
storage=storage, | |
load_if_exists=True, | |
) | |
multiprocessing.set_start_method('spawn') | |
with multiprocessing.Pool(_n_jobs) as pool: | |
pool.starmap(optimize, [ | |
(study_name, int(n_trials / _n_jobs), job_id * _wait) | |
for job_id in range(n_jobs)]) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
See also https://pytorch.org/docs/stable/multiprocessing.html#spawning-subprocesses for a simpler, PyTorch-specific one-liner.