Skip to content

Instantly share code, notes, and snippets.

@aphearin
Created January 19, 2023 20:25
Show Gist options
  • Save aphearin/fe749e757c8b8dbd3e27d5c9f439eab9 to your computer and use it in GitHub Desktop.
Save aphearin/fe749e757c8b8dbd3e27d5c9f439eab9 to your computer and use it in GitHub Desktop.
Example pattern for a simple script parallelized with mpi4py, using JAX-generated random numbers
"""mpiexec -n 2 python demo_mpi4py.py
"""
from mpi4py import MPI
import argparse
from jax import random as jran
import numpy as np
# Default filename template for each rank's output file; "{0}" is filled
# with the MPI rank via str.format.
OUTPAT: str = "results_rank_{0}.dat"
def compute_something(z):
    """Return the per-iteration result computed from *z*.

    Placeholder: the input is handed back unchanged (the very same object),
    so the script's saved results are exactly the random draws. Replace the
    body with the real per-iteration computation.
    """
    result = z  # identity placeholder
    return result
if __name__ == "__main__":
    # Every MPI process runs this same script; its rank tells them apart.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    parser = argparse.ArgumentParser()
    parser.add_argument("-outpat", help="Pattern of output files", default=OUTPAT)
    # type=int is required: values given on the command line otherwise arrive
    # as strings, which break range() and the array shape used below.
    parser.add_argument(
        "-n_iter_per_rank",
        help="Number of iterations per rank",
        type=int,
        default=5,
    )
    parser.add_argument(
        "-npts_fake_data",
        help="Number of points of fake data per iteration",
        type=int,
        default=3,
    )
    args = parser.parse_args()
    # All ranks parse the same argv, so they all reject bad input consistently.
    if args.n_iter_per_rank < 0 or args.npts_fake_data < 0:
        parser.error("-n_iter_per_rank and -npts_fake_data must be non-negative")

    # Seed from the rank: each rank gets a distinct, reproducible random stream.
    ran_key_for_rank = jran.PRNGKey(rank)

    # Each rank loops over the requested number of iterations.
    collector_for_rank = []
    for _ in range(args.n_iter_per_rank):
        # Split off a fresh subkey every iteration; a JAX key must not be reused.
        ran_key_for_rank, i_key = jran.split(ran_key_for_rank, 2)
        # Generate this iteration's fake data: npts uniform draws in [0, 1).
        random_data_i = jran.uniform(
            i_key, minval=0, maxval=1, shape=(args.npts_fake_data,)
        )
        collector_for_rank.append(compute_something(random_data_i))

    # Pack the results into a matrix with one row per iteration
    # (assumes compute_something returns equal-length 1-D arrays).
    final_results_for_rank = np.array(collector_for_rank)

    # Each rank writes its own file, named from the output pattern and its rank.
    outname = args.outpat.format(rank)
    np.savetxt(outname, final_results_for_rank)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment