Skip to content

Instantly share code, notes, and snippets.

@willwhitney
Last active September 27, 2020 09:48
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save willwhitney/e1509c86522896c6930d2fe9ea49a522 to your computer and use it in GitHub Desktop.
Script for running grids of experiments on Slurm.
#----------------------------------------------
# Things to change before running:
#   code_dir        -- full path of the directory that contains your source dir
#   true_source_dir -- change 'TD3' to the name of your source dir
#   job_source_dir  -- where to put this job's copy of the source dir
# The per-job copy is made with `cp -R ./*`, so run this script from inside
# the source dir. Pass --dry-run to skip submission; pass --clear to overwrite
# copies that already exist.
#----------------------------------------------
import itertools
import os
import shlex
import sys
# Command-line switches, detected by plain membership in sys.argv:
#   --dry-run  do not submit the generated scripts with sbatch
#   --clear    overwrite an already existing per-job copy of the source code
dry_run = '--dry-run' in sys.argv
clear = '--clear' in sys.argv

# Output directories (relative to the current working directory) for the
# jobs' stdout/stderr files and for the generated sbatch scripts.
# exist_ok=True replaces the racy "check os.path.exists, then create" pattern.
os.makedirs("slurm_logs", exist_ok=True)
os.makedirs("slurm_scripts", exist_ok=True)
# ---------------------------------------------------------------------------
# Experiment configuration.
#
# Each dict in `grids` maps a flag of the training script to the list of
# values to sweep; every combination within one grid becomes one job. The
# key "main_file" picks the script to run (<main_file>.py). Key order matters:
# it fixes the order of the flags on the command line and of the parts of
# each job's name.
# ---------------------------------------------------------------------------
code_dir = '/private/home/willwhitney/code'
basename = "PFnew_start_traj1"

# Environments swept by every grid below.
_ENV_NAMES = ('Pusher-v2', 'Striker-v2', 'Thrower-v2')


def _shared_settings():
    """Return a fresh dict of the flags (and values) common to every grid."""
    return {
        # "start_timesteps": [0],
        "max_timesteps": [1e7],
        "eval_freq": [5e3],
        "render_freq": [1e5],
        "seed": list(range(8)),
    }


grids = [
    # raw
    {
        "main_file": ['main'],
        "env_name": list(_ENV_NAMES),
        **_shared_settings(),
    },
    # learned embedding
    {
        "main_file": ['main_embedded'],
        "env_name": list(_ENV_NAMES),
        "decoder": [
            # "white_qvel_traj8_z7",
            "white_qvel_traj1_z7",
        ],
        **_shared_settings(),
    },
]
# Expand each grid into its Cartesian product: one job (a flag -> value dict)
# per combination, with the flags kept in the grid's key order.
jobs = []
for grid in grids:
    flag_names = list(grid)
    for combination in itertools.product(*grid.values()):
        jobs.append(dict(zip(flag_names, combination)))

_action = "NOT starting" if dry_run else "Starting"
print("{} {} jobs:".format(_action, len(jobs)))
# A flag "varies" when it takes more than one distinct value across all grids;
# a grid that lacks the flag contributes the placeholder value "<<NONE>>".
# Only varying flags are appended to job names.
all_keys = {key for g in grids for key in g}
merged = {
    key: {value for g in grids for value in g.get(key, ["<<NONE>>"])}
    for key in all_keys
}
varying_keys = {key for key, values in merged.items() if len(values) > 1}

# Flags used by this launcher itself and never forwarded as --flag arguments.
excluded_flags = {'main_file'}
# ---------------------------------------------------------------------------
# Write one sbatch script per job and submit it unless --dry-run was given.
# ---------------------------------------------------------------------------
# Directory every job cd's into before running its command.
true_source_dir = code_dir + '/TD3'


def _flags_and_name(job):
    """Build the command-line flags and the job name for one job.

    Args:
        job: dict mapping flag name -> value for a single job.

    Returns:
        (flagstring, jobname). Every flag not in `excluded_flags` becomes
        " --flag value"; a bool flag becomes a bare " --flag" when True and is
        dropped (with a warning) when False. Every flag in `varying_keys` is
        appended to `basename` as "_<flag><value>". The flag string ends with
        " --name <jobname>".
    """
    jobname = basename
    flagstring = ""
    for flag, value in job.items():
        # construct the string of arguments to be passed to the script
        if flag not in excluded_flags:
            if isinstance(value, bool):
                if value:
                    flagstring += " --" + flag
                else:
                    print("WARNING: Excluding 'False' flag " + flag)
            else:
                # The command ends up in a bash script: quote the value so
                # spaces or shell metacharacters cannot split or inject words.
                flagstring += " --" + flag + " " + shlex.quote(str(value))
        # construct the job's name
        if flag in varying_keys:
            jobname += "_" + flag + str(value)
    flagstring += " --name " + shlex.quote(jobname)
    return flagstring, jobname


def _snapshot_source(job_source_dir):
    """Copy the contents of the current directory into `job_source_dir`.

    If the directory already exists it is left untouched unless --clear was
    given, in which case its files are overwritten. `./*` is expanded by the
    shell, so hidden files (e.g. .git) are not copied.
    NOTE(review): this copies the directory the launcher is run from, which
    presumably should be the source dir (true_source_dir) — confirm.
    """
    try:
        os.makedirs(job_source_dir)
    except FileExistsError:
        # with the 'clear' flag, we're starting fresh:
        # overwrite the code that's already here
        if not clear:
            return
        print("Overwriting existing files.")
    status = os.system('cp -R ./* ' + shlex.quote(job_source_dir))
    if status != 0:
        # Otherwise a failed copy only shows up later as a crashing job.
        print("WARNING: copying the source into {} failed (status {})".format(
            job_source_dir, status))


def _write_slurm_script(path, jobname, jobcommand):
    """Write an sbatch script at `path` that runs `jobcommand` under srun.

    The job's stdout/stderr go to slurm_logs/<jobname>.out / .err. These are
    relative paths, which Slurm resolves against the job's working directory
    (by default the submission directory). The script cd's into
    true_source_dir before running the command.
    """
    lines = [
        "#!/bin/bash",
        "#SBATCH --job-name=" + jobname,
        # append to existing log files instead of truncating them
        "#SBATCH --open-mode=append",
        "#SBATCH --output=slurm_logs/" + jobname + ".out",
        "#SBATCH --error=slurm_logs/" + jobname + ".err",
        "#SBATCH --export=ALL",
        # deliver SIGUSR1 600 s before the time limit; presumably the training
        # script handles it (e.g. to checkpoint) — confirm
        "#SBATCH --signal=USR1@600",
        "#SBATCH --time=1-00",  # days-hours: 1 day
        "#SBATCH -N 1",
        "#SBATCH --mem=32gb",
        "#SBATCH -c 4",
        "#SBATCH --gres=gpu:1",
        "cd " + shlex.quote(true_source_dir),
        "srun " + jobcommand,
    ]
    with open(path, 'w') as slurmfile:
        slurmfile.write("\n".join(lines) + "\n")


for job in jobs:
    flagstring, jobname = _flags_and_name(job)

    # jobname may contain '/' (if a flag value does), so create parent dirs.
    slurm_script_path = 'slurm_scripts/' + jobname + '.slurm'
    os.makedirs(os.path.dirname(slurm_script_path), exist_ok=True)
    os.makedirs(os.path.dirname('slurm_logs/' + jobname), exist_ok=True)

    # Per-job copy of the source (presumably so later edits don't affect
    # queued jobs). Skipped on --dry-run. A copy taken during a dry run would
    # otherwise be silently reused, possibly stale, by the real launch unless
    # --clear is passed.
    job_source_dir = code_dir + '/TD3-clones/' + jobname
    if not dry_run:
        _snapshot_source(job_source_dir)

    jobcommand = "python " + shlex.quote(
        "{}/{}.py".format(job_source_dir, job['main_file'])) + flagstring
    job_start_command = "sbatch " + shlex.quote(slurm_script_path)
    print(jobcommand)
    _write_slurm_script(slurm_script_path, jobname, jobcommand)

    if not dry_run:
        # '&' runs sbatch in the background so submission doesn't block the loop.
        os.system(job_start_command + " &")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment