|
#!/usr/bin/env python |
|
""" ssbatch: a simple way to perform parameter exploration on SLURM |
|
""" |
|
|
|
from __future__ import print_function |
|
|
|
import argparse |
|
import csv |
|
import datetime |
|
import glob |
|
import io |
|
import itertools |
|
import logging |
|
import logging.config |
|
import operator |
|
import os |
|
import re |
|
import subprocess |
|
import sys |
|
import tempfile |
|
import textwrap |
|
|
|
def _positive_int (text):
    """argparse ``type`` callback: parse *text* as an integer >= 1.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is lower
            than 1 (argparse reports it as a usage error and exits).
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError("invalid integer value: %r" % text)

    if (value < 1):
        raise argparse.ArgumentTypeError(
            "must be a positive integer, got %d" % value)

    return value


parser = argparse.ArgumentParser(
    usage = "ssbatch [ssbatch_option... | command [command_option...]] [-- [sbatch_option...]]")

parser.add_argument("-c", "--command", metavar = "STRING",
    dest = "command",
    help = "(optional) alternative way to provide a user command, as a string")

parser.add_argument("-n", "--dry-run",
    dest = "dry_run", action = "store_true", default = False,
    help = "(optional) if set, do nothing but print the list of jobs that "
           "would have been submitted to the SLURM scheduler")

# The value is used as the chunk size when splitting jobs into sbatch arrays:
# a size of 0 would produce no chunk at all (nothing would ever be submitted)
# and a negative size makes itertools.islice() raise, hence the validation.
parser.add_argument("-m", "--max-array-size", type = _positive_int,
    metavar = "INTEGER",
    dest = "max_jobs_per_array", default = 1000,
    help = "(optional) maximum number of jobs per sbatch array; default: "
           "%(default)d")

parser.add_argument("-v", "--verbose",
    dest = "verbose", action = "store_true", default = False,
    help = "(optional) if set, will show additional debug information")
|
|
|
# Split the command line into three groups: options for ssbatch itself, the
# user command to explore, and options forwarded as-is to sbatch.
ssbatch_argv, command_argv, sbatch_argv = [], [], []

# without any argument there is nothing to do: show the help and fail
if len(sys.argv) == 1:
    parser.print_help()
    sys.exit(1)

# everything after the first '--' is meant for sbatch
if "--" in sys.argv:
    separator = sys.argv.index("--")
    leading_args = sys.argv[1:separator]
    sbatch_argv = sys.argv[separator + 1:]
else:
    leading_args = sys.argv[1:]

# what comes before the first '--' is made of ssbatch options when the very
# first argument is an option, and is the user command otherwise
argv = ssbatch_argv if sys.argv[1].startswith('-') else command_argv
argv.extend(leading_args)

ssbatch_options = parser.parse_args(ssbatch_argv)
|
|
|
logger = logging.getLogger(os.path.basename(__file__))

# -v/--verbose lowers the threshold so that debug messages are shown too
logging_level = logging.DEBUG if ssbatch_options.verbose else logging.INFO

# Every record goes through the root logger ("" key) to a single
# StreamHandler using a timestamped format.
_log_formatters = {
    "default": {"format": "[%(asctime)s] %(levelname)s: %(message)s"},
}
_log_handlers = {
    "default": {"class": "logging.StreamHandler", "formatter": "default"},
}
_log_loggers = {
    "": {"handlers": ["default"], "level": logging_level, "propagate": True},
}

logging.config.dictConfig(dict(
    version=1,
    disable_existing_loggers=False,
    formatters=_log_formatters,
    handlers=_log_handlers,
    loggers=_log_loggers,
))
|
|
|
def error(msg, is_exception=False):
    """Log a fatal problem and stop the script with exit status 1.

    Args:
        msg: the message to log.
        is_exception: when true, log through ``logger.exception`` so that the
            traceback of the exception being handled is included (that call
            is intended for use from inside an ``except`` clause); otherwise
            log through ``logger.error``.
    """
    report = logger.exception if is_exception else logger.error
    report(msg)
    sys.exit(1)
|
|
|
# The command to explore: the -c/--command string when given (surrounding
# blanks removed), otherwise the positional command words joined by spaces.
if ssbatch_options.command is None:
    command = ' '.join(command_argv)
else:
    command = ssbatch_options.command.strip()

if not command:
    error("no command provided")

logger.debug("raw command: %s", command)

# base sbatch invocation, followed by the options given after '--'
sbatch_command = ' '.join(["sbatch"] + sbatch_argv).rstrip()
|
|
|
#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
# Step 1: identification and extraction of the various hyperparameters

# Match any text surrounded by curly brackets unless the opening bracket is
# escaped with a backslash ('\{'); escaped brackets may appear between the
# brackets (e.g. '{a\{b}'). The original (lookahead-based) version of this
# pattern is at https://regex101.com/r/tJ8aE8/2; its lookahead always
# succeeded, so '\{...}' was wrongly taken as a hyperparameter.
_UNESCAPED_BRACKETED_TEXT_PATTERN = re.compile(
    r"(?<!\\)\{((?:[^\{\}]|\\\{|\\\})*[^\\\}])\}")

# {name=definition}: definition of a named hyperparameter
_HPARAM_DEFINITION_PATTERN = re.compile("^([a-z0-9]+)=(.+)$", re.I)

# {=name}: reference to a named hyperparameter defined elsewhere
_HPARAM_REFERENCE_PATTERN = re.compile("^=([a-z0-9]+)$", re.I)

# hparams: hyperparameter name -> raw definition (None while only references
#          to that name have been seen)
# hparams_pos: (name, start, end) for each bracketed occurrence, in order of
#              appearance; [start, end) is the text between the brackets
hparams, hparams_pos = {}, []

# for each hyperparameter found in the original command,
for hparam_match in _UNESCAPED_BRACKETED_TEXT_PATTERN.finditer(command):
    hparam = hparam_match.group(1)
    hparam_start, hparam_end = hparam_match.span(1)

    m1 = _HPARAM_DEFINITION_PATTERN.match(hparam)
    m2 = _HPARAM_REFERENCE_PATTERN.match(hparam)

    if (m2 is not None):
        # a reference to a named hyperparameter
        hparam_name, hparam_definition = m2.group(1), None

    elif (m1 is not None):
        # the definition of a named hyperparameter
        hparam_name, hparam_definition = m1.groups()

    elif ('=' in hparam):
        error("invalid hyperparameter '%s'" % hparam)

    else:
        # the definition of an anonymous hyperparameter; user-given names
        # cannot contain '_' (see the patterns above), so no clash is possible
        hparam_name = "unnamed_%d" % (len(hparams) + 1)
        hparam_definition = hparam

    if (hparam_definition is None):
        logger.debug("found: reference to hyperparameter '%s'", hparam_name)
        hparams.setdefault(hparam_name, None)

    else:
        logger.debug("found: hyperparameter '%s' defined as {%s}",
            hparam_name, hparam_definition)

        # we do not allow two or more hyperparameters with the same name
        if (hparams.get(hparam_name) is not None):
            error("duplicated named hyperparameter '%s'" % hparam_name)

        hparams[hparam_name] = hparam_definition

    hparams_pos.append((hparam_name, hparam_start, hparam_end))

# every referenced name must have been defined somewhere in the command
for hparam_name in sorted(hparams):
    if (hparams[hparam_name] is None):
        error("reference to unknown hyperparameter '%s'" % hparam_name)

# Turn the command into a %-format template: each bracketed occurrence
# becomes a '%(name)s' placeholder and every literal '%' is doubled so that
# the later "command % values" leaves it untouched. The literal pieces are
# escaped one by one; escaping the whole command first would shift the
# offsets stored in hparams_pos whenever the command contains a '%'.
# Backslash-escaped brackets are kept verbatim (presumably left to the shell
# to unescape when the job runs).
pieces, cursor = [], 0
for (hparam_name, hparam_start, hparam_end) in hparams_pos:
    # hparam_start - 1 and hparam_end are the positions of '{' and '}'
    pieces.append(command[cursor:hparam_start - 1].replace('%', "%%"))
    pieces.append("%(" + hparam_name + ")s")
    cursor = hparam_end + 1
pieces.append(command[cursor:].replace('%', "%%"))
command = ''.join(pieces)

logger.debug("processed command: %s", command)
|
|
|
#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
# Step 2: creation of generator for each hyperparameter

# an optionally signed integer or decimal number, with an optional exponent
_NUMBER_PATTERN = re.compile(
    r"[+-]?\s*(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")

# 'start:stop' or 'start:stop:step', each part being a number as above;
# groups 1 to 3 capture start, stop and the optional step
_number = _NUMBER_PATTERN.pattern
_ARRAY_SLICE_PATTERN = re.compile(
    "^(" + _number + "):(" + _number + ")(?::(" + _number + "))?$")
del _number
|
|
|
def frange (x, y, step):
    """Yield x, x + step, x + 2*step, ... up to and including y.

    Works with int and float values. With a negative step the values
    decrease from x down to y (e.g. frange(10, 2, -3) yields 10, 7, 4). When
    y cannot be reached from x in the direction of step, nothing is yielded.

    Raises:
        ValueError: if step is zero (raised when iteration starts).
    """
    if (step == 0):
        raise ValueError("frange() step must not be zero")

    # Each value is computed as x + i * step instead of by repeatedly adding
    # step, so float rounding errors do not accumulate. The stop bound is
    # widened by a tiny fraction of the step so that a float stop value (e.g.
    # 0.3 in frange(0.1, 0.3, 0.1)) is not lost to rounding; for ints this
    # tolerance admits no extra value.
    tolerance = abs(step) * 1e-9
    i = 0
    while True:
        value = x + i * step
        if ((step > 0) and (value > y + tolerance)) or \
           ((step < 0) and (value < y - tolerance)):
            return
        yield value
        i += 1
|
|
|
# Replace each hyperparameter definition with the list of values it stands
# for. A sorted snapshot of the items is iterated because the values of
# `hparams` are rebound inside the loop.
for hparam_name, hparam_definition in sorted(hparams.items()):
    # is this hyperparameter a range of values? (start:stop[:step], with the
    # stop value included)
    m = _ARRAY_SLICE_PATTERN.match(hparam_definition)
    if (m is not None):
        bounds = []
        for bound in m.groups():
            if (bound is None):  # the optional step was omitted
                continue

            # _NUMBER_PATTERN accepts blanks between the sign and the digits
            # (e.g. '- 5'), which int() and float() reject
            bound = re.sub(r"\s+", "", bound)
            try:
                bound = int(bound)
            except ValueError:
                bound = float(bound)
            bounds.append(bound)

        if (len(bounds) == 2):
            bounds.append(1)

        # reject a zero step up front with a clear message instead of
        # relying on frange() to cope with it
        if (bounds[2] == 0):
            error("invalid hyperparameter '%s': range step cannot be zero"
                % hparam_name)

        hparams[hparam_name] = list(frange(*bounds))
        continue

    # is this hyperparameter a glob pattern? (sorted, since glob() returns
    # matches in arbitrary order)
    if ('*' in hparam_definition) or ('?' in hparam_definition):
        hparams[hparam_name] = sorted(glob.glob(hparam_definition))
        continue

    # is this hyperparameter a list of items? (parsed as a single CSV record
    # so that quoted items may themselves contain commas)
    if (',' in hparam_definition):
        hparams[hparam_name] = next(csv.reader([hparam_definition]))
        continue

    # if none of the above, the generator is a singleton
    hparams[hparam_name] = [hparam_definition]

# a hyperparameter without any value empties the whole Cartesian product
for hparam_name in sorted(hparams):
    if (not hparams[hparam_name]):
        logger.warning(
            "hyperparameter '%s' has no value; no job will be generated",
            hparam_name)

# hyperparameter names in sorted order, and their value lists in that order
hparam_names = sorted(hparams)
hparam_definitions = [hparams[name] for name in hparam_names]
|
|
|
def exhaustive_searcher():
    """Enumerate every combination of hyperparameter values.

    Walks the Cartesian product of ``hparam_definitions`` (one value list per
    hyperparameter, aligned with ``hparam_names``) and yields each
    combination as a ``{name: value}`` dict.
    """
    for combination in itertools.product(*hparam_definitions):
        pset = {}
        for name, value in zip(hparam_names, combination):
            pset[name] = value
        yield pset


# the search strategy that step 3 draws parameter sets from
hparam_values = exhaustive_searcher
|
|
|
#::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: |
|
# Step 3: generation of a batch script for SLURM sbatch command |
|
|
|
def chunks(iterable, size):
    """Split *iterable* into consecutive tuples of at most *size* items.

    Only the last tuple may be shorter; nothing is yielded for an empty
    iterable.
    """
    source = iter(iterable)
    while True:
        batch = tuple(itertools.islice(source, size))
        if not batch:
            return
        yield batch
|
|
|
timestamp = datetime.datetime.now().isoformat()

# Batch script submitted once per chunk of jobs: array task N runs the N-th
# line of the job array file. '%%A' / '%%a' become literal '%A' / '%a' once
# the template is %-formatted. "${JOB}" is quoted so that the job line is
# handed to eval exactly as written (unquoted, it would first be word-split
# and glob-expanded).
_BATCH_SCRIPT_TEMPLATE = textwrap.dedent("""\
    #!/bin/bash
    #SBATCH --array 1-%(n_jobs)d
    #SBATCH --output ssbatch_%(timestamp)s_%%A_%%a.out
    #SBATCH --error ssbatch_%(timestamp)s_%%A_%%a.err

    JOB_ARRAY="%(batch_array_fn)s"
    JOB="$(awk "NR==${SLURM_ARRAY_TASK_ID}" "${JOB_ARRAY}")"

    eval "${JOB}"
    """)


def _write_temp_file (content, suffix):
    """Write *content* to a new temporary file and return the file's path.

    The file is created with tempfile.mkstemp(), which, unlike the race-prone
    tempfile.mktemp(), atomically creates a file nobody else can have
    claimed. The file is not deleted afterwards.
    """
    fd, path = tempfile.mkstemp(prefix = "ssbatch_", suffix = suffix)
    with os.fdopen(fd, "w") as fh:
        fh.write(content)
    return path


n_jobs_total, n_chunks = 0, 0
for psets in chunks(hparam_values(), ssbatch_options.max_jobs_per_array):
    jobs = []
    for pset in psets:
        job = command % pset

        # the job array file holds one job per line; an embedded newline
        # would split this job in two and shift every following job
        if ('\n' in job):
            error("job command contains a newline, which is not "
                  "supported: %r" % job)

        jobs.append(job)
        if (ssbatch_options.dry_run):
            print(job)

    n_jobs = len(psets)
    n_jobs_total += n_jobs
    n_chunks += 1

    # in dry-run mode nothing is written to disk nor submitted
    if (ssbatch_options.dry_run):
        logger.debug("sbatch command: %s",
            ' '.join(["sbatch"] + sbatch_argv + ["<batch script>"]))
        continue

    # NOTE(review): the job array file is read again when each array task
    # starts; if the default temporary directory is node-local, the tasks
    # may not see it - confirm, or pass dir= pointing to a shared filesystem
    batch_array_fn = _write_temp_file(
        ''.join(job + '\n' for job in jobs), ".jobs")

    batch_script_fn = _write_temp_file(_BATCH_SCRIPT_TEMPLATE % {
        "n_jobs": n_jobs,
        "timestamp": timestamp,
        "batch_array_fn": batch_array_fn,
    }, ".sh")

    # arguments are passed as a list (no shell) so that sbatch options
    # containing spaces or shell metacharacters reach sbatch unchanged
    sbatch_call = ["sbatch"] + sbatch_argv + [batch_script_fn]
    logger.debug("sbatch command: %s", ' '.join(sbatch_call))

    try:
        subprocess.check_call(sbatch_call)

    except subprocess.CalledProcessError as exception:
        error(str(exception))

    except OSError as exception:  # e.g. sbatch not found in PATH
        error("unable to run sbatch: %s" % exception)
|
|
|
# Final report of how many jobs and chunks were (or, in dry-run mode, would
# have been) submitted.
plural_suffixes = ['' if count == 1 else 's'
                   for count in (n_jobs_total, n_chunks)]

msg = "Submitted %d job%s total in %d chunk%s" % (
    n_jobs_total, plural_suffixes[0], n_chunks, plural_suffixes[1])

if ssbatch_options.dry_run:
    msg += " (dry run)"

print(msg)