Skip to content

Instantly share code, notes, and snippets.

@jreniel
Forked from donkirkby/multitasking.py
Created February 18, 2023 15:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jreniel/13b794f0918bbe89991b7c670aecf716 to your computer and use it in GitHub Desktop.
Save jreniel/13b794f0918bbe89991b7c670aecf716 to your computer and use it in GitHub Desktop.
Worker processes with mpi4py
#! /usr/bin/env python
import argparse
import csv
from mpi4py import MPI
import logging
import time
def parseOptions(comm_world, argv=None):
    """Parse the coordinator's command line.

    :param comm_world: MPI communicator; accepted for interface compatibility
        but not used by this function.
    :param argv: argument list to parse instead of the real command line;
        None (the default) means ``sys.argv[1:]``.
    :return: ``argparse.Namespace`` whose ``definitions`` attribute is the
        task-definition file, already opened for reading in text mode.
    """
    parser = argparse.ArgumentParser(
        description="Test of coordinating tasks through MPI")
    # 'rU' was deprecated in Python 3 and rejected by open() since 3.11;
    # plain 'r' text mode already translates universal newlines.
    parser.add_argument("definitions", type=argparse.FileType(mode='r'))
    return parser.parse_args(argv)
def polling_receive(comm, source):
    """Wait for the next message from *source* on *comm* and return it.

    Instead of going straight into the blocking ``comm.recv``, this checks
    ``comm.Iprobe`` repeatedly and sleeps between checks. When the poll
    interval is 0 the probing loop is skipped and ``recv`` is called at once.
    """
    # Set this to 0 for maximum responsiveness, but that will peg CPU to 100%
    poll_interval = 0.1
    if poll_interval > 0:
        while True:
            if comm.Iprobe(source=source):
                break
            time.sleep(poll_interval)
    return comm.recv(source=source)
def record_result(waiting_workers, ready_workers):
    """ Receive a result from any worker process.

    Each worker message is a ``(host_name, rank, result)`` tuple; ``result``
    is None in a worker's first message. The reporting worker, plus any
    workers that were reserved alongside it for a multi-thread task, are
    added to the ready set for the reporting worker's host.

    :param waiting_workers: {active_rank: [reserved_rank, ...]}; the entry
        for the reporting worker is removed.
    :param ready_workers: {host_name: set([rank])}; updated in place.
    """
    source_host, worker_rank, result = polling_receive(MPI.COMM_WORLD,
                                                       source=MPI.ANY_SOURCE)
    if result is not None:
        # Lazy %-style args: formatting only happens if INFO is enabled.
        logging.info('Received %r', result)
    host_workers = ready_workers.setdefault(source_host, set())
    host_workers.add(worker_rank)
    # Workers held back for this worker's multi-thread task are free again.
    host_workers.update(waiting_workers.pop(worker_rank, []))
def main():
    """Run this process's role in the MPI task farm.

    Rank 0 reads task rows from the CSV file named on the command line and
    sends each row to an idle worker (reserving extra workers on the same
    host when a row asks for more than one thread). Every other rank runs
    whatever task rank 0 sends it until it receives None.
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s[%(levelname)s]%(message)s')
    comm = MPI.COMM_WORLD
    process_rank = comm.Get_rank()
    process_count = comm.Get_size()
    process_host = MPI.Get_processor_name()
    if process_rank == 0:
        _run_root(comm, process_count)
    else:
        _run_worker(comm, process_rank, process_host)


def _abort(comm, message):
    # Log why the run cannot continue, then terminate every rank in the job.
    logging.error(message)
    comm.Abort(1)


def _run_root(comm, process_count):
    # Dispatch every CSV row to idle workers on one host, then stop all workers.
    args = parseOptions(comm)
    worker_count = process_count - 1
    waiting_workers = {}  # {active_rank: [waiting_rank]}
    ready_workers = {}  # {host_name: set([rank])}
    with args.definitions:
        reader = csv.DictReader(args.definitions)
        for row in reader:
            thread_count = int(row['threads'])
            if thread_count < 1:
                # Would otherwise match an empty host set and crash on pop().
                _abort(comm, 'Task {!r} must use at least 1 thread, not {}.'
                       .format(row['name'], thread_count))
            chosen_workers = None
            while chosen_workers is None:
                # .values() works on Python 2 and 3; .itervalues() is Py2-only.
                for ranks in ready_workers.values():
                    if len(ranks) >= thread_count:
                        chosen_workers = ranks
                        break
                if chosen_workers is None:
                    idle_count = sum(len(ranks)
                                     for ranks in ready_workers.values())
                    if idle_count == worker_count:
                        # Every worker is idle, so no message will ever
                        # arrive and no host can ever offer enough workers:
                        # waiting in record_result would hang forever.
                        _abort(comm, 'Task {!r} needs {} threads, but no host '
                               'has that many workers.'
                               .format(row['name'], thread_count))
                    record_result(waiting_workers, ready_workers)
            worker_rank = chosen_workers.pop()
            if thread_count > 1:
                # Hold the extra same-host workers idle; record_result frees
                # them when worker_rank reports its result.
                waiting_workers[worker_rank] = [
                    chosen_workers.pop()
                    for _ in range(thread_count - 1)]
            comm.send(row, dest=worker_rank)
    active_workers = set(range(1, process_count))
    while active_workers:
        record_result(waiting_workers, ready_workers)
        for ranks in ready_workers.values():
            for worker_rank in ranks:
                comm.send(None, dest=worker_rank)
                active_workers.remove(worker_rank)
            ranks.clear()
    logging.info("Done on root.")


def _run_worker(comm, process_rank, process_host):
    # Report to rank 0, run each task it sends back, and stop on None.
    result = None
    while True:
        comm.send((process_host, process_rank, result), dest=0)
        request = polling_receive(comm, source=0)
        if request is None:
            logging.info('Done on rank %d, host %s.',
                         process_rank,
                         process_host)
            break
        logging.info('Start task %s on rank %d.',
                     request['name'],
                     process_rank)
        # The task's only work is sleeping for its 'duration' in seconds.
        time.sleep(int(request['duration']))
        logging.info('Finish task %s on rank %d.',
                     request['name'],
                     process_rank)
        result = {'name': request['name'], 'hash': hash(request['name'])}
# Only start the task farm when run as a script (e.g. by mpirun), not on import.
if __name__ == '__main__':
    main()
import subprocess
import traceback
import sys
def check_mpi_version(prefix):
    """Return the output of ``mpirun -V`` as text, or a failure description.

    :param prefix: shell text placed before ``mpirun -V`` (for example a
        ``module load ... && `` step); the command runs through the shell.
    :return: always a ``str``. On success it holds the command's combined
        stdout/stderr; if the command exits non-zero or cannot be started,
        it holds the formatted exception line instead.
    """
    command = prefix + 'mpirun -V'
    try:
        # universal_newlines=True decodes the output so callers can test
        # substrings with str (on Python 3 it would otherwise be bytes).
        return subprocess.check_output(command,
                                       shell=True,
                                       stderr=subprocess.STDOUT,
                                       universal_newlines=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # Join into one string: format_exception_only returns a list.
        return ''.join(traceback.format_exception_only(type(err), err))
def main():
    """Launch multitasking.py under Open MPI's mpirun.

    If ``mpirun -V`` run directly does not mention Open MPI, retry with a
    ``module load openmpi/gnu && `` prefix; exit with an error message if it
    still does not. Otherwise run multitasking.py on 4 processes using the
    ``hostfile`` host list and the ``taskdefs/multithread.csv`` task file.
    """
    wanted = 'Open MPI'
    shell_prefix = ''
    version = check_mpi_version(shell_prefix)
    if wanted not in version:
        shell_prefix = 'module load openmpi/gnu && '
        version = check_mpi_version(shell_prefix)
        if wanted not in version:
            sys.exit("Couldn't find Open MPI:\n{}".format(version))
    launch_words = [
        'mpirun',
        '-np', '4',
        '--hostfile', 'hostfile',
        'multitasking.py',
        'taskdefs/multithread.csv',
    ]
    command = shell_prefix + ' '.join(launch_words)
    subprocess.check_call(command, shell=True)
# Only launch mpirun when run as a script, not when imported.
if __name__ == '__main__':
    main()
name duration threads
a 10 1
b 15 2
c 20 1
d 10 1
e 10 1
name duration threads
a 10 1
b 15 1
c 20 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment