Last active: April 7, 2021 03:43
Save xmodar/6047e9a989fb1a5be513eb542f33ca99 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Utilities for argparse arguments.""" | |
import os | |
import sys | |
from argparse import Namespace | |
from collections import OrderedDict | |
from itertools import product, chain | |
from typing import Union, Dict | |
__all__ = ['parse_grid'] | |
def parse_ranges(key):
    """Parse a range specifier string into a list of integers.

    ``key`` is a comma-separated list of items of the form
    ``low[-high[:step]]``; ``high`` is inclusive and ``step`` defaults to 1.
    Negative numbers are not supported because ``-`` separates the bounds.

    Example:
        '1,3,4-8:2,9' -> [1, 3, 4, 6, 8, 9]

    Args:
        key: the specifier; an empty string (or None) yields ``[]``.

    Returns:
        A list of ints in the order the items appear (duplicates are kept).
        The return type is always ``list``, also for a single item.

    Raises:
        ValueError: if a bound or step is not an integer, or step is 0.
    """
    if not key:
        return []
    if ',' in key:
        return list(chain.from_iterable(map(parse_ranges, key.split(','))))
    low, _, rest = key.partition('-')
    low = int(low)
    if not rest:  # "low" or "low-" selects a single index
        return [low]
    high, _, step = rest.partition(':')
    return list(range(low, int(high) + 1, int(step) if step else 1))
def grid_from_arg_list(arg_list=None, prefix='--grid'):
    """Extract grid arguments from ``arg_list`` in place.

    Recognized arguments are removed from ``arg_list``; all others are kept:

    * ``PREFIX SPEC`` or ``PREFIX=SPEC`` sets the experiment-index spec
      (parsed with ``parse_ranges``); if repeated, the last one wins.
    * ``PREFIX_NAME v1 v2 ...`` or ``PREFIX_NAME=v1,v2,...`` adds values to
      the grid of argument ``NAME``. Space-separated values are collected
      until the next token starting with ``-`` (so negative numbers must use
      the ``=`` form); a bare ``PREFIX_NAME`` adds no values. Repeating
      ``NAME`` appends to its list.

    Example:
        ['main.py', '--arg1=1', '--arg2', '2', '--grid=0-2', '--grid_arg3=1,2,3']
        returns ({'arg3': ['1', '2', '3']}, [0, 1, 2])
        and arg_list becomes ['main.py', '--arg1=1', '--arg2', '2']

    Args:
        arg_list: list of command-line tokens, modified in place. Defaults to
            ``sys.argv`` itself.
        prefix: the option prefix that marks grid arguments.

    Returns:
        ``(grid, indices)``: an ``OrderedDict`` from argument name to a list
        of string values, and ``parse_ranges`` applied to the index spec
        (``''`` when none was given).
    """
    if arg_list is None:
        arg_list = sys.argv
    key = ''
    grid = OrderedDict()
    i = 0
    while i < len(arg_list):
        argument = arg_list[i]
        if argument == prefix or argument.startswith(prefix + '='):
            arg_list.pop(i)
            if '=' in argument:
                # split only once so the spec itself is kept intact
                key = argument.split('=', 1)[1]
            else:
                key = arg_list.pop(i)  # the spec is the next token
        elif argument.startswith(prefix + '_'):
            arg_list.pop(i)
            if '=' in argument:
                # split only once: values may themselves contain '='
                argument, values = argument.split('=', 1)
                values = values.split(',')
            else:
                values = []
                # pop following tokens until the next option-like token
                while i < len(arg_list) and not arg_list[i].startswith('-'):
                    values.append(arg_list.pop(i))
            name = argument[len(prefix) + 1:]
            grid.setdefault(name, []).extend(values)
        else:
            i += 1  # not a grid argument: keep it and move on
    return grid, parse_ranges(key)
def dict_to_arg_list(args: Dict[str, Union[str, bool]]):
    """Render a mapping of option names to values as command-line tokens.

    ``True`` becomes a bare ``--name`` flag, ``False`` is dropped entirely,
    and any other value becomes ``--name=value`` (formatted with ``str``).
    The insertion order of ``args`` is preserved.
    """
    return [
        f'--{name}' if isinstance(setting, bool) else f'--{name}={setting}'
        for name, setting in args.items()
        if not isinstance(setting, bool) or setting
    ]
def grid_experiments(grid, parser, check=None):
    """Build an iterator factory and a name formatter for a grid.

    Args:
        grid: ``OrderedDict`` mapping argument names to lists of candidate
            values (as produced by ``grid_from_arg_list``). An empty list
            denotes a flag and is expanded to ``[True, False]``. Neither
            ``grid`` nor its lists are modified.
        parser: argument parser whose ``parse_known_args`` converts each
            combination's raw values into parsed values.
        check: optional predicate on the parsed namespace; combinations for
            which it returns a falsy value are skipped.

    Returns:
        ``(iterator, get_name)``: ``iterator()`` yields one ``Namespace`` per
        accepted combination (cartesian product in grid order) holding only
        the grid arguments; ``get_name(experiment)`` formats an experiment
        (``Namespace`` or dict) as column-aligned ``key=value`` pairs.
    """
    assert isinstance(grid, OrderedDict), 'grid must be ordered'
    # copy the value lists as well: a shallow dict copy would let the flag
    # expansion below leak into the caller's grid
    grid = OrderedDict(
        (key, list(values) if values else [True, False])
        for key, values in grid.items()
    )
    # widest "key=value" string per argument, used to align names in columns
    length = {
        key: 1 + len(key) + max(len(str(value)) for value in values)
        for key, values in grid.items()
    }

    def get_name(experiment):
        """Format the grid arguments of ``experiment`` as aligned pairs."""
        val = experiment if isinstance(experiment, dict) else vars(experiment)
        return ' '.join(f'{k}={val[k]}'.ljust(length[k]) for k in grid)

    def iterator():
        """Yield a Namespace of parsed grid arguments per accepted combo."""
        for values in product(*grid.values()):
            raw = dict(zip(grid, values))
            # parse_known_args: only the grid options are given here
            args, _ = parser.parse_known_args(dict_to_arg_list(raw))
            if check is None or check(args):
                parsed = vars(args)
                yield Namespace(**{k: parsed[k] for k in grid})

    return iterator, get_name
def parse_grid(parser, check=None, arg_list=None, prefix='--grid'):
    """Parse grid arguments.

    Example:
        parser = ArgumentParser()
        parser.add_argument(...)
        grid, exps, indices, get_args, get_name = parse_grid(parser)
        exclude = ('arg1', 'arg2')
        assert all(x not in grid for x in exclude), f'cannot grid on {exclude}'
        args = get_args(exps[indices[0]])  # first experiment
        # do common setup here
        index = f'{{:0{len(str(len(exps)))}d}}'
        for i in indices:
            args = get_args(exps[i])
            print(index.format(i), get_name(args))
            # use args here

    Args:
        parser: the program's argument parser.
        check: optional predicate on a parsed namespace; grid combinations
            for which it is falsy are dropped.
        arg_list: command-line tokens (without the program name), a dict of
            option values, or a ``Namespace``. Defaults to ``sys.argv``.
            Grid arguments are removed in place from a given list or from
            ``sys.argv``.
        prefix: option prefix marking grid arguments.

    Returns:
        ``(grid, exps, indices, get_args, get_name)`` where ``exps`` is the
        list of accepted grid combinations (``[Namespace()]`` when there is
        no grid), ``indices`` the selected experiment indices (all of them
        by default), ``get_args(experiment=None, enforce=None)`` parses the
        remaining arguments merged with an experiment, and ``get_name``
        formats an experiment for display.

    Environment:
        GRID_OVERRIDE: when set to a non-zero integer and ``enforce`` is not
            passed, grid values are placed before the explicit arguments
            instead of after them (for argparse's default ``store`` action
            the later occurrence wins).
    """
    # normalize dict/Namespace input to a token list *before* extracting the
    # grid: grid_from_arg_list iterates and pops list items
    if isinstance(arg_list, Namespace):
        arg_list = vars(arg_list)
    if isinstance(arg_list, dict):
        arg_list = dict_to_arg_list(arg_list)
    grid, indices = grid_from_arg_list(arg_list, prefix)
    exps, get_name = grid_experiments(grid, parser, check)
    exps = list(exps())
    if indices:
        assert 0 <= min(indices) <= max(indices) < len(exps), \
            f'grid indices must be within [0, {len(exps) - 1}]'
    else:
        assert not grid or exps, 'no valid grid combination'
        if not exps:
            exps = [Namespace()]  # no grid: a single default experiment
        indices = range(len(exps))
    if arg_list is None:
        # grid_from_arg_list already stripped the grid tokens from sys.argv
        arg_list = sys.argv[1:]

    def get_args(experiment=None, enforce=None):
        """Parse ``arg_list`` merged with the arguments of ``experiment``."""
        if experiment is None:
            experiment = Namespace()
        if enforce is None:
            enforce = not int(os.environ.get('GRID_OVERRIDE', '0'))
        grid_args = dict_to_arg_list(vars(experiment))
        # NOTE: dict_to_arg_list drops False flags, so a False grid value
        # cannot override an explicit flag in arg_list
        if enforce:  # put at the end of arg_list
            new_arg_list = arg_list + grid_args
        else:
            new_arg_list = grid_args + arg_list
        return parser.parse_args(new_arg_list)

    return grid, exps, indices, get_args, get_name
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Grid definition for an nm_vae experiment.
# With SUBMIT unset or empty, the grid (plus any extra arguments) is passed
# to run.sh in this shell; otherwise it is submitted to SLURM as a job array
# running sample_run.sh.
name=$(basename -s .sh "$0")  # get the name of the script

method=(
    --grid_optimal_sigma
    --grid_mode MC NM
    --grid_num_samples 1
    --max_epochs 10
)
data=(
    --grid_dataset CIFAR10 CIFAR100 SVHN CelebA MNIST FreyFace
    --batch_size 64
)
model=(
    --grid_model MixedVAE
    --grid_latent_dim 16 32 64 128 256 512
)
# quote array expansions so every element stays a single argument
grid=( "${method[@]}" "${data[@]}" "${model[@]}" )
job=( sample_run.sh --info --train --log_dir "$name" "${grid[@]}" )

if [ -z "$SUBMIT" ]
then
    source run.sh "${grid[@]}" "$@"
else
    mkdir -p slurm "$name"
    submit=( sbatch --job-name "nm_vae_$name" --export=ALL )
    # 0-143 must match the number of grid combinations:
    # optimal_sigma (flag, True/False) 2 x mode 2 x num_samples 1 x dataset 6
    # x model 1 x latent_dim 6 = 144; update it whenever the grid changes
    "${submit[@]}" --time=0-40:00 --array=0-143 "${job[@]}"
fi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash --login
#SBATCH --job-name nm_vae
#SBATCH --output slurm/%x.%3a.%A.out
#SBATCH --error slurm/%x.%3a.%A.err
#SBATCH --time 10:00:00
#SBATCH --gres gpu:1
#SBATCH --cpus-per-gpu 6
#SBATCH --mem-per-gpu 32G
#SBATCH --account conf-gpu-2020.11.23
#SBATCH --mail-type FAIL,TIME_LIMIT,TIME_LIMIT_90
#SBATCH --mail-user modar.alfadly@kaust.edu.sa
#SBATCH --exclude gpu208-02

# Run `python -m nm_vae` with this script's arguments inside the conda env
# named by $CONDA_ENV (default: nm_vae). As a SLURM array task, the task id
# is additionally passed as `--grid <id>`.

# init conda and activate env (conda should already be in path)
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "${CONDA_ENV:-nm_vae}"

# if we are a SLURM array task ($SLURM_ARRAY_TASK_ID is non-empty),
# select the grid experiment with that index; an array keeps the id a
# single argument without relying on word splitting
task=()
if [ -n "$SLURM_ARRAY_TASK_ID" ]
then
    task=( --grid "$SLURM_ARRAY_TASK_ID" )
fi

python -m nm_vae "${task[@]}" "$@"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.