Last active
December 8, 2021 18:46
-
-
Save amysteier/f5eefade7e4fe74f5633b944aa2b7cbe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import subprocess | |
def create_study(study_name, dataset, trial_job_cnt, trials_per_job, api_key, storage, sampler):
    """Create an Optuna study, seed it with the default params, and start workers.

    The study is created with ``direction="maximize"``. The hyperparameters
    currently set in the module-level ``config`` (first model's synthetics
    params) are enqueued as a trial, then ``trial_job_cnt`` separate
    ``Optuna_Trials.py`` processes are started. The processes are not waited
    on; this function returns as soon as they have been spawned.

    Args:
        study_name: Name given to the study; also passed to every worker.
        dataset: Passed through to every worker unchanged.
        trial_job_cnt: Number of worker processes to start.
        trials_per_job: Passed to every worker (as a string) as its trial
            count argument.
        api_key: Passed through to every worker unchanged.
        storage: Storage given to ``optuna.create_study``; also passed to
            every worker.
        sampler: Sampler given to ``optuna.create_study``.

    Returns:
        The study object returned by ``optuna.create_study``.
    """
    # Local import so this block does not depend on the file's import section.
    import sys

    study = optuna.create_study(
        study_name=study_name,
        storage=storage,
        sampler=sampler,
        direction="maximize",
    )

    # Tell Optuna to start with our default config settings
    default_params = config['models'][0]['synthetics']['params']
    seeded_keys = (
        "vocab_size",
        "reset_states",
        "rnn_units",
        "learning_rate",
        "gen_temp",
        "dropout_rate",
    )
    study.enqueue_trial({key: default_params[key] for key in seeded_keys})

    # Each of the "trial_job_cnt" worker processes is handed "trials_per_job"
    # as its trial count. sys.executable is used instead of a bare "python" so
    # the workers run under the same interpreter/environment as this process
    # rather than whatever "python" happens to resolve to on PATH.
    # NOTE(review): api_key is passed as a command-line argument, which is
    # typically visible in process listings -- consider another channel if
    # Optuna_Trials.py can be changed to accept one.
    worker_cmd = [
        sys.executable,
        "Optuna_Trials.py",
        study_name,
        str(trials_per_job),
        dataset,
        api_key,
        storage,
    ]
    for _ in range(trial_job_cnt):
        # Fire and forget: the workers keep running after this function returns.
        subprocess.Popen(worker_cmd)

    return study
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment