@donkirkby
Last active February 18, 2023
Worker processes with mpi4py
multitasking.py:

#! /usr/bin/env python
import argparse
import csv
import logging
import time

from mpi4py import MPI


def parseOptions(comm_world):
    parser = argparse.ArgumentParser(
        description="Test of coordinating tasks through MPI")
    parser.add_argument("definitions", type=argparse.FileType())
    return parser.parse_args()


def polling_receive(comm, source):
    # Set this to 0 for maximum responsiveness, but that will peg CPU to 100%
    sleep_seconds = 0.1
    if sleep_seconds > 0:
        while not comm.Iprobe(source=source):
            time.sleep(sleep_seconds)
    result = comm.recv(source=source)
    return result


def record_result(waiting_workers, ready_workers):
    """ Receive a result from any worker process.

    Put the source process, along with any other processes that were waiting
    for it, into ready_workers.
    """
    source_host, worker_rank, result = polling_receive(MPI.COMM_WORLD,
                                                       source=MPI.ANY_SOURCE)
    if result is not None:
        logging.info('Received {!r}'.format(result))
    host_workers = ready_workers.get(source_host)
    if host_workers is None:
        ready_workers[source_host] = host_workers = set()
    host_workers.add(worker_rank)
    host_workers.update(waiting_workers.pop(worker_rank, []))


def main():
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s[%(levelname)s]%(message)s')
    comm = MPI.COMM_WORLD
    process_rank = comm.Get_rank()
    process_count = comm.Get_size()
    process_host = MPI.Get_processor_name()
    if process_rank == 0:
        # Root process: read task definitions and hand them out.
        args = parseOptions(comm)
        waiting_workers = {}  # {active_rank: [waiting_rank]}
        ready_workers = {}  # {host_name: set([rank])}
        with args.definitions:
            reader = csv.DictReader(args.definitions)
            for row in reader:
                thread_count = int(row['threads'])
                # Wait until one host has enough idle workers for this task.
                chosen_workers = None
                while chosen_workers is None:
                    for ranks in ready_workers.values():
                        if len(ranks) >= thread_count:
                            chosen_workers = ranks
                            break
                    # TODO: error if thread_count can never be satisfied.
                    if chosen_workers is None:
                        record_result(waiting_workers, ready_workers)
                worker_rank = chosen_workers.pop()
                if thread_count > 1:
                    # Hold extra workers on the same host until this task
                    # finishes, so it doesn't compete for their processors.
                    waiting_workers[worker_rank] = [
                        chosen_workers.pop()
                        for _ in range(thread_count - 1)]
                comm.send(row, dest=worker_rank)
        # All tasks sent; tell each worker to shut down as it reports in.
        active_workers = set(range(1, process_count))
        while active_workers:
            record_result(waiting_workers, ready_workers)
            for ranks in ready_workers.values():
                for worker_rank in ranks:
                    comm.send(None, dest=worker_rank)
                    active_workers.remove(worker_rank)
                ranks.clear()
        logging.info("Done on root.")
    else:
        # Worker process: report for duty, then run tasks until told to stop.
        result = None
        while True:
            comm.send((process_host, process_rank, result), dest=0)
            request = polling_receive(comm, source=0)
            if request is None:
                logging.info('Done on rank %d, host %s.',
                             process_rank,
                             process_host)
                break
            logging.info('Start task %s on rank %d.',
                         request['name'],
                         process_rank)
            time.sleep(int(request['duration']))
            logging.info('Finish task %s on rank %d.',
                         request['name'],
                         process_rank)
            result = {'name': request['name'], 'hash': hash(request['name'])}


main()
A launcher script that finds Open MPI, loading the openmpi/gnu module if necessary, and runs multitasking.py under mpirun:

import subprocess
import sys
import traceback


def check_mpi_version(prefix):
    """ Return the output of mpirun -V, or the error message if it failed. """
    try:
        return subprocess.check_output(prefix + 'mpirun -V',
                                       shell=True,
                                       stderr=subprocess.STDOUT,
                                       universal_newlines=True)
    except Exception:
        etype, value, _tb = sys.exc_info()
        return ''.join(traceback.format_exception_only(etype, value))


def main():
    # Try mpirun directly; fall back to loading the Open MPI module first.
    prefix = ''
    expected_version = 'Open MPI'
    version = check_mpi_version(prefix)
    if expected_version not in version:
        prefix = 'module load openmpi/gnu && '
        version = check_mpi_version(prefix)
        if expected_version not in version:
            sys.exit("Couldn't find Open MPI:\n{}".format(version))
    mapping_args = ['mpirun',
                    '-np',
                    '4',
                    '--hostfile',
                    'hostfile',
                    'multitasking.py',
                    'taskdefs/multithread.csv']
    mapping_command = prefix + ' '.join(mapping_args)
    subprocess.check_call(mapping_command, shell=True)


main()
taskdefs/multithread.csv:

name,duration,threads
a,10,1
b,15,2
c,20,1
d,10,1
e,10,1
A second task list with only single-threaded tasks:

name,duration,threads
a,10,1
b,15,1
c,20,1
@donkirkby (Author)

This is similar to jbornschein's mpi4py task pull example, with the addition that each task can be allocated a different number of processors. If a task requires three processors, one MPI process runs it in three threads while two other MPI processes on the same host block until it finishes. The blocked processes ensure that the three-threaded process doesn't compete with another process for processors. The root tracks host names so that you don't end up with all the active processes on one host and all the blocked processes on another. For example, task b above declares threads=2, so the root pops two ready ranks from the same host, sends the task to one of them, and holds the other in waiting_workers until the result comes back.

See the jbornschein example for how to use a Status object to get source and tag information.
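For reference, receiving with a Status object in mpi4py looks roughly like this (a minimal sketch; the variable names are illustrative):

status = MPI.Status()
result = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
source = status.Get_source()  # rank that sent the message
tag = status.Get_tag()        # tag the sender attached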

See the Open MPI documentation for more information about using a hostfile to specify where to run the worker processes. To run them all locally, just create hostfile with a single line:

localhost
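To spread the workers across machines instead, list each host with its slot count, one per line (the host names below are placeholders):

node1 slots=4
node2 slots=4

Either way, the workers are launched with the same command the launcher script above builds:

mpirun -np 4 --hostfile hostfile multitasking.py taskdefs/multithread.csv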
