Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active April 7, 2021 03:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/6047e9a989fb1a5be513eb542f33ca99 to your computer and use it in GitHub Desktop.
Save xmodar/6047e9a989fb1a5be513eb542f33ca99 to your computer and use it in GitHub Desktop.
"""Utilities for argparse arguments."""
import os
import sys
from argparse import Namespace
from collections import OrderedDict
from itertools import product, chain
from typing import Union, Dict
# Public API of this module: only parse_grid is exported by ``import *``;
# the remaining functions are helpers used by parse_grid.
__all__ = ['parse_grid']
def parse_ranges(key):
    """Parse a range-specifier string into a list of integers.

    The specifier is a comma-separated list of items of the form
    ``low[-high[:step]]``. ``high`` is inclusive and ``step`` defaults to 1.
    Empty items are ignored. Values must be non-negative because ``-`` is the
    range separator.

    Example:
        '1,3,4-8:2,9' -> [1, 3, 4, 6, 8, 9]

    Args:
        key: the specifier string; empty string or ``None`` yields ``[]``.

    Returns:
        A ``list`` of integers in specifier order. It is always a list, even
        for a single item.

    Raises:
        ValueError: if an item is not a valid integer or a step is zero.
    """
    if not key:
        return []
    result = []
    for item in key.split(','):
        if not item:
            continue
        low, _, rest = item.partition('-')
        low = int(low)
        if rest:
            high, _, step = rest.partition(':')
            high, step = int(high), (int(step) if step else 1)
        else:
            high, step = low, 1
        # ``high`` is inclusive, so the stop goes one past it in the step's
        # direction (range() itself raises ValueError for a zero step).
        stop = high + 1 if step > 0 else high - 1
        result.extend(range(low, stop, step))
    return result
def grid_from_arg_list(arg_list=None, prefix='--grid'):
    """Extract grid arguments from ``arg_list`` in place.

    Two kinds of tokens are removed from ``arg_list``:

    * ``{prefix} SPEC`` or ``{prefix}=SPEC`` selects which experiments to run.
      ``SPEC`` is parsed by ``parse_ranges``. If given more than once, the
      last one wins.
    * ``{prefix}_NAME v1 v2 ...`` or ``{prefix}_NAME=v1,v2,...`` adds values
      for argument ``NAME``. Repeated occurrences accumulate. In the
      space-separated form, values stop at the next token starting with
      ``-``, so negative numbers must use the ``=`` form. ``{prefix}_NAME``
      with no values records an empty list.

    Example:
        ['main.py', '--arg1=1', '--arg2', '2', '--grid=0-2', '--grid_arg3=1,2,3']
        returns ({'arg3': ['1', '2', '3']}, indices 0, 1, 2)
        and arg_list becomes ['main.py', '--arg1=1', '--arg2', '2']

    Args:
        arg_list: list of argument strings, modified in place. ``None`` means
            ``sys.argv``, which is then modified in place.
        prefix: option prefix that marks grid arguments.

    Returns:
        ``(grid, indices)``: ``grid`` is an ``OrderedDict`` mapping each
        ``NAME`` to its list of raw string values in first-seen order, and
        ``indices`` is the result of ``parse_ranges`` on ``SPEC`` (empty when
        no ``{prefix}`` option was given).

    Raises:
        ValueError: if ``{prefix}`` is the last token and has no value.
    """
    if arg_list is None:
        arg_list = sys.argv
    key = ''
    grid = OrderedDict()
    i = 0
    while i < len(arg_list):
        argument = arg_list[i]
        if argument == prefix or argument.startswith(prefix + '='):
            arg_list.pop(i)
            if '=' in argument:
                key = argument.split('=', 1)[1]
            elif i < len(arg_list):
                key = arg_list.pop(i)
            else:
                raise ValueError(f'{prefix} expects a range specifier')
        elif argument.startswith(prefix + '_'):
            arg_list.pop(i)
            # split only on the first '=' so values may themselves contain '='
            option, has_value, joined = argument.partition('=')
            if has_value:
                values = joined.split(',')
            else:
                values = []
                while i < len(arg_list) and not arg_list[i].startswith('-'):
                    values.append(arg_list.pop(i))
            grid.setdefault(option[len(prefix) + 1:], []).extend(values)
        else:
            i += 1
    return grid, parse_ranges(key)
def dict_to_arg_list(args: Dict[str, Union[str, bool]]):
    """Turn a ``{name: value}`` mapping into command-line tokens.

    A ``True`` value becomes a bare ``--name`` flag. A ``False`` value is
    dropped entirely. Any other value becomes ``--name=value`` using its
    ``str()`` form. Output order follows the mapping's iteration order.
    """
    return [
        f'--{name}' if isinstance(val, bool) else f'--{name}={val}'
        for name, val in args.items()
        if val is not False  # disabled flags produce no token at all
    ]
def grid_experiments(grid, parser, check=None):
    """Build an iterator over every valid combination of grid values.

    Args:
        grid: dict (e.g. ``OrderedDict``) mapping argument names to lists of
            values. An empty list denotes a boolean flag and is expanded to
            ``[True, False]``. The mapping's order fixes the order of the
            cartesian product. Neither the mapping nor its lists are modified.
        parser: ``argparse`` parser used to convert each combination's raw
            values via ``parse_known_args``.
        check: optional predicate on the parsed ``Namespace``. Combinations
            for which it returns a falsy value are skipped.

    Returns:
        ``(iterator, get_name)``. ``iterator()`` yields one ``Namespace`` per
        valid combination, holding only the grid arguments. ``get_name`` takes
        an experiment (``Namespace`` or ``dict``) and formats its grid
        arguments as column-aligned ``key=value`` cells.

    Raises:
        TypeError: if ``grid`` is not a dict.
    """
    if not isinstance(grid, dict):
        raise TypeError('grid must be an ordered mapping (dict/OrderedDict)')
    # copy the mapping AND its lists so the caller's grid is left untouched
    grid = OrderedDict(
        (key, list(values) or [True, False]) for key, values in grid.items()
    )
    # width of the widest "key=value" cell per argument, for aligned names
    length = {
        key: 1 + len(key) + max(len(str(value)) for value in values)
        for key, values in grid.items()
    }

    def get_name(experiment):
        """Format an experiment's grid arguments as aligned key=value cells."""
        val = experiment if isinstance(experiment, dict) else vars(experiment)
        return ' '.join(f'{k}={val[k]}'.ljust(length[k]) for k in grid)

    def iterator():
        """Yield a Namespace of parsed grid arguments per valid combination."""
        for values in product(*grid.values()):
            raw = dict(zip(grid, values))
            # NOTE(review): only the grid arguments are passed here, so a
            # parser with required arguments would exit — confirm with callers.
            args, _ = parser.parse_known_args(dict_to_arg_list(raw))
            if check is None or check(args):
                parsed = vars(args)
                yield Namespace(**{k: parsed[k] for k in grid})

    return iterator, get_name
def parse_grid(parser, check=None, arg_list=None, prefix='--grid'):
    """Parse grid arguments.

    Grid arguments (``{prefix}_NAME ...``) are removed from ``arg_list``. All
    their combinations are enumerated through ``parser``, and
    ``{prefix} SPEC`` selects which combinations to run.

    Example:
        parser = ArgumentParser()
        parser.add_argument(...)
        grid, exps, indices, get_args, get_name = parse_grid(parser)
        exclude = ('arg1', 'arg2')
        assert all(x not in grid for x in exclude), f'cannot grid on {exclude}'
        args = get_args(exps[indices[0]])  # first experiment
        # do common setup here
        index = f'{{:0{len(str(len(exps)))}d}}'
        for i in indices:
            args = get_args(exps[i])
            print(index.format(i), get_name(args))
            # use args here

    Args:
        parser: the program's ``argparse`` parser.
        check: optional predicate on a parsed grid ``Namespace``.
            Combinations for which it is falsy are dropped.
        arg_list: arguments to parse. ``None`` means ``sys.argv``, which is
            modified in place to drop grid arguments. A list is likewise
            modified in place. A ``dict`` or ``Namespace`` is first converted
            with ``dict_to_arg_list``.
        prefix: option prefix marking grid arguments.

    Returns:
        ``(grid, exps, indices, get_args, get_name)``:

        * ``grid``: the extracted ``{name: [values]}`` mapping.
        * ``exps``: list of ``Namespace`` objects, one per valid combination
          (a single empty ``Namespace`` if there are none and no grid).
        * ``indices``: the selected experiment indices (all by default).
        * ``get_args(experiment=None, enforce=None)``: parses ``arg_list``
          merged with the experiment's values. With ``enforce`` true, the
          grid values are appended last. ``enforce`` defaults to true unless
          the ``GRID_OVERRIDE`` environment variable is a non-zero integer.
          With ``enforce`` false, the grid values are prepended so explicit
          arguments come last.
        * ``get_name``: formats an experiment as aligned ``key=value`` cells.

    Raises:
        ValueError: if a selected index is out of range, or if no grid
            combination passes ``check``.
    """
    # Normalize mapping inputs *before* grid extraction, which indexes and
    # pops from a list (a dict/Namespace would crash there).
    if isinstance(arg_list, Namespace):
        arg_list = vars(arg_list)
    if isinstance(arg_list, dict):
        arg_list = dict_to_arg_list(arg_list)
    grid, indices = grid_from_arg_list(arg_list, prefix)
    exps, get_name = grid_experiments(grid, parser, check)
    exps = list(exps())
    if indices:
        if not 0 <= min(indices) <= max(indices) < len(exps):
            raise ValueError(
                f'grid indices must be in range({len(exps)}); '
                f'got min={min(indices)}, max={max(indices)}'
            )
    else:
        if grid and not exps:
            raise ValueError('no valid grid combination')
        if not exps:
            exps = [Namespace()]
        indices = range(len(exps))
    if arg_list is None:
        # sys.argv has already had its grid arguments removed above
        arg_list = sys.argv[1:]

    def get_args(experiment=None, enforce=None):
        """Parse the experiment's arguments."""
        if experiment is None:
            experiment = Namespace()
        if enforce is None:
            enforce = not int(os.environ.get('GRID_OVERRIDE', '0'))
        overrides = dict_to_arg_list(vars(experiment))
        # For store-style options argparse keeps the last occurrence, so
        # order decides whether grid values or explicit arguments win.
        if enforce:
            new_arg_list = arg_list + overrides
        else:
            new_arg_list = overrides + arg_list
        return parser.parse_args(new_arg_list)

    return grid, exps, indices, get_args, get_name
#!/bin/bash
# Launch the nm_vae experiment grid, either locally (SUBMIT unset/empty) or as
# a SLURM job array (SUBMIT non-empty). The script's own file name, without
# the .sh suffix, names the log directory and the job.
name=$(basename -s .sh "$0") # get the name of the script
# --grid_optimal_sigma has no values, so it is gridded as a True/False flag
method=(
    --grid_optimal_sigma
    --grid_mode MC NM
    --grid_num_samples 1
    --max_epochs 10
)
data=(
    --grid_dataset CIFAR10 CIFAR100 SVHN CelebA MNIST FreyFace
    --batch_size 64
)
model=(
    --grid_model MixedVAE
    --grid_latent_dim 16 32 64 128 256 512
)
grid=( "${method[@]}" "${data[@]}" "${model[@]}" )
job=( sample_run.sh --info --train --log_dir "$name" "${grid[@]}" )
if [ -z "$SUBMIT" ]
then
    # run locally, forwarding any extra command-line arguments verbatim
    source run.sh "${grid[@]}" "$@"
else
    mkdir -p slurm "$name"
    submit=( sbatch --job-name "nm_vae_$name" --export=ALL )
    # 144 tasks = 2 (optimal_sigma) x 2 (mode) x 6 (dataset) x 6 (latent_dim).
    # NOTE(review): keep in sync with the grid above; assumes no combination
    # is filtered out by a check function.
    "${submit[@]}" --time=0-40:00 --array=0-143 "${job[@]}"
fi
#!/bin/bash --login
#SBATCH --job-name nm_vae
#SBATCH --output slurm/%x.%3a.%A.out
#SBATCH --error slurm/%x.%3a.%A.err
#SBATCH --time 10:00:00
#SBATCH --gres gpu:1
#SBATCH --cpus-per-gpu 6
#SBATCH --mem-per-gpu 32G
#SBATCH --account conf-gpu-2020.11.23
#SBATCH --mail-type FAIL,TIME_LIMIT,TIME_LIMIT_90
#SBATCH --mail-user modar.alfadly@kaust.edu.sa
#SBATCH --exclude gpu208-02
# Activate the conda env and run `python -m nm_vae` with all script
# arguments, selecting this task's grid experiment under a job array.
# init conda and activate env (conda should already be in path)
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "${CONDA_ENV:-nm_vae}"
# inside a SLURM job array ($SLURM_ARRAY_TASK_ID is set), run only this
# task's experiment; otherwise let the program pick its default indices
if [ -n "$SLURM_ARRAY_TASK_ID" ]
then
    task=( --grid "$SLURM_ARRAY_TASK_ID" )
else
    task=()
fi
python -m nm_vae "${task[@]}" "$@"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment