@ajmazurie
Created December 10, 2015 20:05
ssbatch: dead-simple parameter exploration on SLURM-powered computer clusters

ssbatch

The ssbatch command-line utility makes it easy to use a SLURM-powered computer cluster to run the same command many times with different parameters or parameter values. There is no need to write your own sbatch submission scripts; ssbatch generates and submits them for you.
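For reference, the submission script that ssbatch generates is a SLURM job array: every command is written as one line of a job list file, and each array task selects its own line through the SLURM_ARRAY_TASK_ID environment variable. The sketch below renders such a script from the same template as the ssbatch source; the function name render_array_script and the /tmp/ssbatch_jobs path are hypothetical.

```python
import textwrap


def render_array_script(n_jobs, jobs_file, timestamp):
    """Return the text of a job-array script for sbatch in which array
    task N runs line N of `jobs_file` (hypothetical helper)."""
    template = textwrap.dedent("""\
        #!/bin/bash
        #SBATCH --array 1-{n_jobs}
        #SBATCH --output ssbatch_{timestamp}_%A_%a.out
        #SBATCH --error ssbatch_{timestamp}_%A_%a.err
        JOB_ARRAY="{jobs_file}"
        JOB="$(awk "NR==${{SLURM_ARRAY_TASK_ID}}" ${{JOB_ARRAY}})"
        eval ${{JOB}}
        """)
    # %A and %a are sbatch filename patterns (array job ID, task index);
    # doubled braces survive str.format() as literal shell braces
    return template.format(n_jobs=n_jobs, jobs_file=jobs_file,
                           timestamp=timestamp)


script = render_array_script(3, "/tmp/ssbatch_jobs", "2015-12-10T20:05:00")
print(script)
```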

Using tags

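ssbatch offers a simple tag system to describe the values a parameter should take. A tag, written between curly brackets, can hold a single value, a comma-separated list of values, a glob pattern that expands to a list of files, or a range written as start:end or start:end:step (both ends included). ssbatch then iterates through every combination of these values and submits the resulting commands to SLURM as a job array.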
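As a rough illustration of how a tag body can be turned into a list of values, here is a simplified sketch that assumes the same precedence as the ssbatch script (range first, then glob pattern, then comma-separated list, then single value). The name expand_tag is hypothetical. Unlike the script, this version accepts only positive steps and plain decimal numbers, sorts glob matches, and computes each range value as start + i*step rather than by repeated addition, which limits floating-point drift.

```python
import csv
import glob
import re

# start:end or start:end:step with integer or decimal bounds (simplified)
_NUM = r"[+-]?\d+(?:\.\d+)?"
_RANGE = re.compile(r"^(%s):(%s)(?::(%s))?$" % (_NUM, _NUM, _NUM))


def _number(text):
    """Parse an integer if possible, a float otherwise."""
    try:
        return int(text)
    except ValueError:
        return float(text)


def expand_tag(body):
    """Return the values described by a tag body, i.e. the text between
    the curly brackets (hypothetical helper)."""
    m = _RANGE.match(body)
    if m:
        start, end = _number(m.group(1)), _number(m.group(2))
        step = _number(m.group(3)) if m.group(3) else 1
        if step <= 0:
            raise ValueError("this sketch only supports positive steps")
        values, i = [], 0
        while start + i * step <= end:    # the end of the range is included
            values.append(start + i * step)
            i += 1
        return values
    if "*" in body or "?" in body:        # glob pattern: matching file names
        return sorted(glob.glob(body))
    if "," in body:                       # comma-separated list (CSV quoting)
        return next(csv.reader([body]))
    return [body]                         # single value


print(expand_tag("1:5"))        # [1, 2, 3, 4, 5]
print(expand_tag("2,5,10"))     # ['2', '5', '10']
print(expand_tag("0:1:0.25"))   # [0.0, 0.25, 0.5, 0.75, 1.0]
```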

Here are some examples:

$ ssbatch echo "I have {2,5,10} bottles of beer"
$ ssbatch wc {*.*}
$ ssbatch head -n {1:5} {*.txt}
  • the first example runs echo 3 times concurrently, passing it the string "I have 2 bottles of beer", "I have 5 bottles of beer", or "I have 10 bottles of beer".
  • the second example runs wc concurrently on each file in the current directory whose name contains a dot (the glob pattern *.*).
  • the third example runs head five times on each .txt file in the current directory, with values 1 to 5 for its -n option; again, all concurrently.
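The number of jobs is the product of the number of values of each tag. The sketch below enumerates every combination with itertools.product, much as the script's exhaustive_searcher does; the %(name)s placeholders follow the script's internal convention, while the helper name expand_command and the files a.txt and b.txt are assumptions for illustration.

```python
import itertools


def expand_command(template, values):
    """Yield one command per combination of tag values (hypothetical helper).

    `template` uses %(name)s placeholders and `values` maps each
    placeholder name to its list of values."""
    names = sorted(values)
    for combination in itertools.product(*(values[name] for name in names)):
        yield template % dict(zip(names, combination))


# the third example, assuming the current directory holds a.txt and b.txt
jobs = list(expand_command(
    "head -n %(n)s %(file)s",
    {"n": list(range(1, 6)), "file": ["a.txt", "b.txt"]}))
print(len(jobs))   # 10 jobs: 5 values of -n times 2 files
print(jobs[0])     # head -n 1 a.txt
```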

ssbatch makes parameter exploration effortless by making it easy to test various combinations of values for a command's parameters. This is especially useful in data science, where complex algorithms are used to analyze datasets and it is not always obvious how a given parameter of an algorithm will affect the results. With ssbatch, finding out empirically becomes trivial.

Reusing tags

Tags can be used either to define a set of values for a parameter or to refer to a named one. This makes it easy, for example, to give output files unique names:

$ ssbatch my_command --sensitivity {v=1:3} --output result_for_{=v}.txt

Here my_command will run three times concurrently, with {v=1:3} replaced by the values 1 to 3. Since this tag was given a name (v), the reference {=v} can reuse its value elsewhere, here in the --output option of my_command. The resulting jobs submitted to SLURM will be:

$ my_command --sensitivity 1 --output result_for_1.txt
$ my_command --sensitivity 2 --output result_for_2.txt
$ my_command --sensitivity 3 --output result_for_3.txt
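Internally, the script turns each tag into a Python %(name)s placeholder, so a definition and all of its references receive the same value in a given job. The following simplified sketch of that step ignores escaped brackets; compile_command is a hypothetical name, not part of ssbatch. Like the script, it rejects duplicate names and references to undefined tags, and it doubles literal '%' characters so they survive the % operator.

```python
import re

# simplified tag pattern: any text between curly brackets, no escaping
TAG = re.compile(r"\{([^{}]*)\}")


def compile_command(command):
    """Turn tags into %(name)s placeholders (hypothetical helper).

    {name=definition} defines a named tag, {=name} refers to it and any
    other {definition} is an anonymous tag. Returns the template and a
    dict mapping each tag name to its (unexpanded) definition."""
    definitions, used, parts, cursor = {}, [], [], 0
    for match in TAG.finditer(command):
        body = match.group(1)
        if body.startswith("="):             # reference to a named tag
            name = body[1:]
        elif "=" in body:                    # definition of a named tag
            name, definition = body.split("=", 1)
            if name in definitions:
                raise ValueError("duplicated tag name %r" % name)
            definitions[name] = definition
        else:                                # anonymous tag
            name = "unnamed_%d" % (len(definitions) + 1)
            definitions[name] = body
        used.append(name)
        # literal '%' characters must be doubled for the % operator
        parts.append(command[cursor:match.start()].replace("%", "%%"))
        parts.append("%(" + name + ")s")
        cursor = match.end()
    parts.append(command[cursor:].replace("%", "%%"))
    for name in used:                        # references may precede definitions
        if name not in definitions:
            raise ValueError("reference to unknown tag %r" % name)
    return "".join(parts), definitions


template, definitions = compile_command(
    "my_command --sensitivity {v=1:3} --output result_for_{=v}.txt")
for v in (1, 2, 3):
    print(template % {"v": v})
```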

Next steps

Currently ssbatch only performs an exhaustive exploration of the parameter sets you provide. This works well, but the number of combinations is multiplied with every tag you add, which quickly leads to a combinatorial explosion. We are working on an additional approach in which you randomly sample the space of all possible combinations under a budget. For example, given your tags, you could tell ssbatch to explore only a thousand randomly selected combinations.
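The budgeted approach described above is not part of the script below. As a sketch of one way it could work, assuming uniform sampling without replacement is what is wanted, the function below draws distinct positions in the Cartesian product and decodes each position as a mixed-radix number (one digit per parameter), so the full product is never built. The name sample_combinations is hypothetical.

```python
import random


def sample_combinations(value_lists, budget, seed=None):
    """Yield up to `budget` distinct combinations drawn uniformly at
    random from the Cartesian product of `value_lists` (hypothetical)."""
    sizes = [len(values) for values in value_lists]
    total = 1
    for size in sizes:
        total *= size
    rng = random.Random(seed)
    # random.sample() accepts a range without materializing it
    for index in rng.sample(range(total), min(budget, total)):
        combination = []
        # the last parameter is the least significant digit, which matches
        # the ordering of itertools.product()
        for values, size in zip(reversed(value_lists), reversed(sizes)):
            index, digit = divmod(index, size)
            combination.append(values[digit])
        yield tuple(reversed(combination))


grid = [list(range(1, 101)), ["a", "b", "c"], [0.1, 0.5, 0.9]]  # 900 combinations
picked = list(sample_combinations(grid, budget=10, seed=42))
print(len(picked))   # 10
```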

#!/usr/bin/env python
""" ssbatch: a simple way to perform parameter exploration on SLURM
"""
from __future__ import print_function
import argparse
import csv
import datetime
import glob
import io
import itertools
import logging
import logging.config
import operator
import os
import re
import subprocess
import sys
import tempfile
import textwrap

parser = argparse.ArgumentParser(
    usage = "ssbatch [ssbatch_option... | command [command_option...]] [-- [sbatch_option...]]")

parser.add_argument("-c", "--command", metavar = "STRING",
    dest = "command",
    help = "(optional) alternative way to provide a user command, as a string")

parser.add_argument("-n", "--dry-run",
    dest = "dry_run", action = "store_true", default = False,
    help = """(optional) if set, do nothing but print the list of jobs that
    would have been submitted to the SLURM scheduler""")

parser.add_argument("-m", "--max-array-size", type = int, metavar = "INTEGER",
    dest = "max_jobs_per_array", default = 1000,
    help = """(optional) maximum number of jobs per sbatch array; default:
    %(default)d""")

parser.add_argument("-v", "--verbose",
    dest = "verbose", action = "store_true", default = False,
    help = "(optional) if set, will show additional debug information")

ssbatch_argv, command_argv, sbatch_argv = [], [], []

if (len(sys.argv) == 1):
    parser.print_help()
    sys.exit(1)

# if the first argument is an option, then all arguments (up
# to the first '--') are expected to be targeted at ssbatch
if (sys.argv[1].startswith('-')):
    argv = ssbatch_argv
# if not, all arguments (up to the first '--')
# are for the command to execute through SLURM
else:
    argv = command_argv

if (sys.argv.count("--") > 0):
    i = sys.argv.index("--")
    argv.extend(sys.argv[1:i])
    sbatch_argv = sys.argv[i+1:]
else:
    argv.extend(sys.argv[1:])

ssbatch_options = parser.parse_args(ssbatch_argv)

logger = logging.getLogger(os.path.basename(__file__))

if (ssbatch_options.verbose):
    logging_level = logging.DEBUG
else:
    logging_level = logging.INFO

logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"default": {
        "format": "[%(asctime)s] %(levelname)s: %(message)s"
    }},
    "handlers": {"default": {
        "class": "logging.StreamHandler",
        "formatter": "default",
    }},
    "loggers": {"": {
        "handlers": ["default"],
        "level": logging_level,
        "propagate": True
    }}
})

def error (msg, is_exception = False):
    if (is_exception):
        logger.exception(msg)
    else:
        logger.error(msg)
    sys.exit(1)

command = ' '.join(command_argv)
if (ssbatch_options.command is not None):
    command = ssbatch_options.command.strip()

if (command == ''):
    error("no command provided")

logger.debug("raw command: %s" % command)

sbatch_command = ("sbatch %s" % ' '.join(sbatch_argv)).rstrip()

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
# Step 1: identification and extraction of the various hyperparameters

# match any text surrounded by unescaped curly
# brackets; see https://regex101.com/r/tJ8aE8/2
_UNESCAPED_BRACKETED_TEXT_PATTERN = re.compile(
    r"(?=[^\\\{]?)\{((?:[^\{\}]|\\\{|\\\})*[^\\\}])\}")

_HPARAM_DEFINITION_PATTERN = re.compile("^([a-z0-9]+)=(.+)$", re.I)
_HPARAM_REFERENCE_PATTERN = re.compile("^=([a-z0-9]+)$", re.I)

hparams, hparams_pos = {}, []

# for each hyperparameter found in the original command,
for hparam_match in _UNESCAPED_BRACKETED_TEXT_PATTERN.finditer(command):
    hparam = hparam_match.group(1)
    hparam_start = hparam_match.start(1)
    hparam_end = hparam_match.end(1)

    m1 = _HPARAM_DEFINITION_PATTERN.match(hparam)
    m2 = _HPARAM_REFERENCE_PATTERN.match(hparam)

    # is this a reference to a named hyperparameter?
    if (m2 is not None):
        hparam_name, hparam_definition = m2.group(1), None
    # is this the definition of a named hyperparameter?
    elif (m1 is not None):
        hparam_name, hparam_definition = m1.groups()
    elif ('=' in hparam):
        error("invalid hyperparameter '%s'" % hparam)
    # then it is the definition of an anonymous hyperparameter
    else:
        hparam_name = "unnamed_%d" % (len(hparams) + 1)
        hparam_definition = hparam

    if (hparam_definition is None):
        logger.debug("found: reference to hyperparameter '%s'" % hparam_name)
        if (not hparam_name in hparams):
            hparams[hparam_name] = None
    else:
        logger.debug("found: hyperparameter '%s' defined as {%s}" % (
            hparam_name, hparam_definition))
        # we do not allow two or more hyperparameters with the same name
        if (hparam_name in hparams) and (hparams[hparam_name] is not None):
            error("duplicated named hyperparameter '%s'" % hparam_name)
        hparams[hparam_name] = hparam_definition

    hparams_pos.append((hparam_name, hparam_start, hparam_end))

for (hparam_name, hparam_definition) in hparams.items():
    if (hparam_definition is None):
        error("reference to unknown hyperparameter '%s'" % hparam_name)

# replace each occurrence of a hyperparameter (brackets included) with a
# %(name)s tag; literal '%' characters are escaped segment by segment so
# that the positions recorded above remain valid
segments, cursor = [], 0
for (hparam_name, hparam_start, hparam_end) in hparams_pos:
    segments.append(command[cursor:hparam_start-1].replace('%', "%%"))
    segments.append("%(" + hparam_name + ")s")
    cursor = hparam_end + 1
segments.append(command[cursor:].replace('%', "%%"))
command = ''.join(segments)

logger.debug("processed command: %s" % command)

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
# Step 2: creation of generator for each hyperparameter

_NUMBER_PATTERN = re.compile(
    r"[+-]?\s*(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")

_ARRAY_SLICE_PATTERN = re.compile(
    "^(%(number)s):(%(number)s)(?::(%(number)s))?$" % {
        "number": _NUMBER_PATTERN.pattern})

def frange (x, y, step):
    if (step < 0):
        assert (x >= y)
        x, y, step = y, x, -step
    while (x <= y):
        yield x
        x += step

for (hparam_name, hparam_definition) in hparams.items():
    # is this hyperparameter a range of values?
    m = _ARRAY_SLICE_PATTERN.match(hparam_definition)
    if (m is not None):
        elements = []
        for element in filter(lambda x: x is not None, m.groups()):
            try:
                element = int(element)
            except ValueError:
                element = float(element)
            elements.append(element)
        if (len(elements) == 2):
            elements.append(1)
        # a zero step would make frange() loop forever
        if (elements[2] == 0):
            error("invalid range {%s}: the step cannot be zero" % hparam_definition)
        hparams[hparam_name] = frange(*elements)
        continue

    # is this hyperparameter a glob pattern?
    if ('*' in hparam_definition) or ('?' in hparam_definition):
        hparams[hparam_name] = glob.glob(hparam_definition)
        continue

    # is this hyperparameter a list of items? (parsed as one CSV row,
    # which works with both Python 2 and Python 3 strings)
    if (hparam_definition.count(',') > 0):
        hparams[hparam_name] = next(csv.reader([hparam_definition]))
        continue

    # if none of the above, the generator is a singleton
    hparams[hparam_name] = [hparam_definition]

if (len(hparams) > 0):
    hparam_names, hparam_definitions = zip(*sorted(hparams.items()))
else:
    hparam_names, hparam_definitions = [], []

def exhaustive_searcher():
    for hparam_values in itertools.product(*hparam_definitions):
        yield dict(zip(hparam_names, hparam_values))

hparam_values = exhaustive_searcher

#:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
# Step 3: generation of a batch script for SLURM sbatch command

def chunks (iterable, size):
    it = iter(iterable)
    chunk = tuple(itertools.islice(it, size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it, size))

timestamp = datetime.datetime.now().isoformat()
n_jobs_total, n_chunks = 0, 0

for psets in chunks(hparam_values(), ssbatch_options.max_jobs_per_array):
    # mkstemp() creates each file atomically, avoiding the race condition
    # of the deprecated mktemp(); note that array tasks read the job list
    # at run time, so its directory must be visible from the compute nodes
    batch_array_fd, batch_array_fn = tempfile.mkstemp()
    batch_array_fh = os.fdopen(batch_array_fd, "w")
    batch_script_fd, batch_script_fn = tempfile.mkstemp()
    batch_script_fh = os.fdopen(batch_script_fd, "w")

    # one line per job; array task N runs line N
    for pset in psets:
        print(command % pset, file = batch_array_fh)
        if (ssbatch_options.dry_run):
            print(command % pset)

    n_jobs = len(psets)
    n_jobs_total += n_jobs
    n_chunks += 1

    batch_script_fh.write(textwrap.dedent("""\
        #!/bin/bash
        #SBATCH --array 1-%(n_jobs)d
        #SBATCH --output ssbatch_%(timestamp)s_%%A_%%a.out
        #SBATCH --error ssbatch_%(timestamp)s_%%A_%%a.err
        JOB_ARRAY="%(batch_array_fn)s"
        JOB="$(awk "NR==${SLURM_ARRAY_TASK_ID}" ${JOB_ARRAY})"
        eval ${JOB}
        """ % locals()))

    batch_array_fh.close()
    batch_script_fh.close()

    sbatch_command_ = sbatch_command + ' ' + batch_script_fn
    logger.debug("sbatch command: %s" % sbatch_command_)

    if (not ssbatch_options.dry_run):
        try:
            subprocess.check_call(sbatch_command_, shell = True)
        except subprocess.CalledProcessError as exception:
            error(str(exception))

msg = "Submitted %d job%s total in %d chunk%s" % (
    n_jobs_total, 's' if (n_jobs_total != 1) else '',
    n_chunks, 's' if (n_chunks != 1) else '')

if (ssbatch_options.dry_run):
    msg += " (dry run)"

print(msg)
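When there are more jobs than --max-array-size (1000 by default), the script submits them as several job arrays, presumably to stay under SLURM's limit on array size (the MaxArraySize parameter in slurm.conf). The sketch below runs the script's chunks() helper on its own to show how 2,500 jobs would be split; the job count and job names are arbitrary examples.

```python
import itertools


def chunks(iterable, size):
    """Yield successive tuples of at most `size` items from `iterable`."""
    it = iter(iterable)
    chunk = tuple(itertools.islice(it, size))
    while chunk:
        yield chunk
        chunk = tuple(itertools.islice(it, size))


jobs = ("job_%d" % i for i in range(1, 2501))   # 2,500 hypothetical jobs
sizes = [len(chunk) for chunk in chunks(jobs, 1000)]
print(sizes)   # [1000, 1000, 500]
```

Because each chunk becomes its own sbatch submission with --array 1-N, array task IDs restart at 1 in every chunk, and each chunk gets its own job list file.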