A quick way to distribute embarrassingly parallel jobs across multiple GPUs (or other resources) with ipyparallel
import os
import socket
import time
import logging

# hostname -> number of GPUs available on that host
WS_N_GPUS = {
    'turagas-ws1': 2,
    'turagas-ws2': 2,
    'turagas-ws3': 2,
    'turagas-ws4': 2,
    'c04u01': 8,
    'c04u07': 8,
    'c04u12': 8,
    'c04u17': 8,
}
def gpu_job_runner(job_fnc, job_args, ipp_profile='ssh_gpu_py2', log_name=None,
                   status_interval=600, allow_engine_overlap=True):
    """ Distribute a set of jobs across an IPyParallel 'GPU cluster'

    Requires that the cluster has already been started with
    `ipcluster start --profile={}`.format(ipp_profile).
    Checks on the jobs every status_interval seconds, logging status.

    Args:
      job_fnc: the function to distribute
        must accept `device` as a kwarg, as this function is wrapped so that
        device is bound within the engine namespace
        returned values are ignored
      job_args: list of args passed to job_fnc - list
      ipp_profile: name of the GPU IPyParallel profile - str
      log_name: (optional) name for the log - str
      status_interval: (optional) the amount of time, in seconds, to wait before querying
        the AsyncResult object for the status of the jobs - int
      allow_engine_overlap: (optional) if False, halt when a host has more engines
        than GPUs - bool
    """
    from ipyparallel import Client, RemoteError, Reference
    import inspect

    # setup logging
    log_path = os.path.expanduser("~/logs/deepspike/job_runner.log")
    log_name = log_name or 'job_runner'
    logger = setup_logging(log_name, log_path)

    # TODO: this isn't strictly necessary
    try:
        # check that job_fnc accepts a device kwarg
        args = inspect.getargspec(job_fnc)[0]
        assert 'device' in args
    except AssertionError:
        logger.critical("job_fnc does not accept a device kwarg. Halting.")
        raise

    client = Client(profile=ipp_profile)
    logger.info("Successfully initialized client on %s with %s engines", ipp_profile, len(client))
    # assign each engine to a GPU, round-robin per host
    engines_per_host = {}
    device_assignments = []
    engine_hosts = client[:].apply(socket.gethostname).get()
    for host in engine_hosts:
        if host in engines_per_host:
            device_assignments.append('/gpu:{}'.format(engines_per_host[host]))
            engines_per_host[host] += 1
        else:
            device_assignments.append('/gpu:0')
            engines_per_host[host] = 1

    logger.info("Engines per host:")
    for host, n_engines in engines_per_host.items():
        logger.info("%s: %s", host, n_engines)

    if not allow_engine_overlap:
        try:
            # check that we haven't over-provisioned GPUs
            for host, n_engines in engines_per_host.items():
                assert n_engines <= WS_N_GPUS[host]
        except AssertionError:
            logger.critical("Host has more engines than GPUs. Halting.")
            raise
    while True:
        try:
            # NOTE: could also be accomplished with process environment variables
            # broadcast device assignments and job_fnc
            for engine_id, engine_device in enumerate(device_assignments):
                print("Pushing to engine {}: device: {}".format(engine_id, engine_device))
                client[engine_id].push({'device': engine_device,
                                        'job_fnc': job_fnc}, block=True)
            # verify that each engine holds the device it was assigned
            for engine_id, (host, assigned_device) in enumerate(zip(engine_hosts, device_assignments)):
                remote_device = client[engine_id].pull('device').get()
                logger.info("Engine %s: host = %s; device = %s, remote device = %s",
                            engine_id, host, assigned_device, remote_device)
            break
        except RemoteError as remote_err:
            logger.warning("Caught remote error: %s. Sleeping for 10s before retry", remote_err)
            time.sleep(10)

    logger.info("Dispatching jobs: %s", job_args)
    # dispatch jobs: the Reference resolves to each engine's bound `device`
    async_result = client[:].map(job_fnc, job_args, [Reference('device')] * len(job_args))
    start_time = time.time()
    while not async_result.ready():
        time.sleep(status_interval)
        n_finished = async_result.progress
        n_jobs = len(job_args)
        wall_time = time.time() - start_time
        logger.info("%s seconds elapsed. %s of %s jobs finished",
                    wall_time, n_finished, n_jobs)
    logger.info("All jobs finished in %s seconds!", async_result.wall_time)
def setup_logging(log_name, log_path):
    """ Sets up module level logging
    """
    # define module level logger
    logger = logging.getLogger(log_name)
    logger.setLevel(logging.DEBUG)
    log_path = os.path.expanduser(log_path)

    # define file handler for module
    fh = logging.FileHandler(log_path)
    fh.setLevel(logging.DEBUG)

    # create formatter and add to handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)

    # add handler to logger
    logger.addHandler(fh)
    return logger
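
A minimal sketch of a caller, for reference. The job function and its config paths are hypothetical, not part of the gist; the '/gpu:N' device strings above suggest TensorFlow, so this sketch assumes `tf.device`, but any framework that accepts a device string would work. It assumes the cluster is already up via `ipcluster start --profile=ssh_gpu_py2`.

# Hypothetical example job. `device` must be accepted as an argument, since
# gpu_job_runner passes each engine's assigned device string via a Reference
# in the map call (positionally, as the second argument).
def train_model(config_path, device='/gpu:0'):
    # imports go inside the function body: it executes on remote engines
    import tensorflow as tf  # assumption: TensorFlow-style '/gpu:N' strings
    with tf.device(device):
        pass  # build and train the model here; return values are ignored

job_args = ['config_{}.yaml'.format(idx) for idx in range(16)]
gpu_job_runner(train_model, job_args, ipp_profile='ssh_gpu_py2')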