@adrn
Last active March 13, 2016 07:45
If you write your code to use map(), this will trivially let you run it in parallel via multiprocessing or MPI.

Examples:

% time python example_script.py
real	0m10.205s
user	0m0.080s
sys	0m0.035s

% time python example_script.py --threads=2
real	0m5.465s
user	0m0.096s
sys	0m0.049s

% time mpiexec -n 4 python example_script.py --mpi
real	0m5.131s
user	0m1.664s
sys	0m4.744s
example_script.py:

# coding: utf-8
""" Example of how I use get_pool() with argparse. """

from __future__ import division, print_function

__author__ = "adrn <adrn@astro.columbia.edu>"

# Standard library
import sys
import time

# Third-party
import numpy as np

from pools import get_pool

def f(x):
    # Sleep to simulate an expensive, embarrassingly parallel computation
    time.sleep(1.)
    return x**2

def main(pool):
    return pool.map(f, np.random.random(size=10))

if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(description="")

    # threading
    parser.add_argument("--mpi", dest="mpi", default=False, action="store_true",
                        help="Run with MPI.")
    parser.add_argument("--threads", dest="threads", default=None, type=int,
                        help="Number of multiprocessing threads to run on.")
    args = parser.parse_args()

    pool = get_pool(mpi=args.mpi, threads=args.threads)
    main(pool)
    pool.close()

    sys.exit(0)
pools.py:

# coding: utf-8
""" Some serial operations can be made trivially parallel by structuring your code to
    use the built-in `map()` function. get_pool() returns a pool object that --
    depending on the arguments -- lets you map in serial, with multiprocessing, or with MPI.
"""

from __future__ import division, print_function

__author__ = "adrn <adrn@astro.columbia.edu>"

# Standard library
import sys
import multiprocessing

__all__ = ['get_pool']

class SerialPool(object):

    def close(self):
        return

    def map(self, *args, **kwargs):
        # Note: under Python 3 the built-in map() returns an iterator,
        # not a list like multiprocessing.Pool.map() does.
        return map(*args, **kwargs)

def get_pool(mpi=False, threads=None):
    """ Always returns a pool object with a `map()` method. By default,
        returns a `SerialPool()` -- `SerialPool.map()` just calls the built-in
        Python function `map()`. If `mpi=True`, it will attempt to import the
        `MPIPool` implementation from `mpipool`. If `threads` is set to a
        number > 1, it will return a Python multiprocessing pool.

        Parameters
        ----------
        mpi : bool (optional)
            Use MPI or not. If specified, ignores the threads kwarg.
        threads : int (optional)
            If mpi is False and threads is specified, use a Python
            multiprocessing pool with the specified number of threads.
    """

    if mpi:
        from mpipool import MPIPool

        # Initialize the MPI pool
        pool = MPIPool()

        # Make sure the process we're running on is the master; worker
        # processes block in wait() and exit when the master is done.
        if not pool.is_master():
            pool.wait()
            sys.exit(0)

    elif threads is not None and threads > 1:
        # Guard against threads=None so the comparison also works on Python 3
        pool = multiprocessing.Pool(threads)

    else:
        pool = SerialPool()

    return pool
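
A minimal usage sketch of the three modes get_pool() supports -- the names f and data are placeholders, not part of the gist:

# Hypothetical usage sketch: `f` and `data` are placeholder names.
from pools import get_pool

def f(x):
    return x**2

data = range(10)

pool = get_pool()                  # default: SerialPool, wraps the built-in map()
# pool = get_pool(threads=4)       # multiprocessing.Pool with 4 processes
# pool = get_pool(mpi=True)        # mpipool.MPIPool; run under mpiexec, workers block in wait()

results = pool.map(f, data)
pool.close()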
@keflavich

I did a somewhat similar thing here, but with a slightly different approach:
https://github.com/keflavich/FITS_tools/blob/master/FITS_tools/cube_regrid.py#L334

You'd think something as fundamental as parallel mapping would be easier to access from core libraries, no?
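
For reference, the closest thing in the standard library is concurrent.futures (Python 3.2+, or the futures backport on Python 2). A minimal sketch with the same placeholder f and data; it covers only the multiprocessing case, not MPI:

# Standard-library parallel map via concurrent.futures (Python >= 3.2).
# `f` and `data` are placeholders, as above.
from concurrent.futures import ProcessPoolExecutor

def f(x):
    return x**2

data = range(10)

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(f, data))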
