Skip to content

Instantly share code, notes, and snippets.

@adrn
Created February 20, 2024 22:09
Show Gist options
  • Save adrn/d983bc02dcf2b57ef28163213d38cc5c to your computer and use it in GitHub Desktop.
Save adrn/d983bc02dcf2b57ef28163213d38cc5c to your computer and use it in GitHub Desktop.
MPI pool example
import pathlib
import h5py
import numpy as np
def worker(task):
    """Compute the output for a single task.

    ``task`` is an ``(index, value, cache_file)`` tuple. The index and cache
    path are handed back unchanged alongside the computed result so that the
    result handler knows which slot of which file to fill.

    Returns ``(index, value**2 + 2, cache_file)``.
    """
    index, value, cache_path = task
    # Placeholder computation -- replace with the real per-task work.
    result = value**2 + 2
    return index, result, cache_path
def callback(result):
    """Store one worker result in the HDF5 cache file.

    ``result`` is the ``(index, value, cache_file)`` tuple produced by
    ``worker``. The file is opened in read/write mode and ``value`` is written
    into position ``index`` of its ``"data"`` dataset.
    """
    index, value, path = result
    with h5py.File(path, "r+") as cache:
        dataset = cache["data"]
        dataset[index] = value
def main(pool, overwrite, seed=42):
    """Run every not-yet-completed task through ``pool``, caching results in HDF5.

    Results live in the ``"data"`` dataset of ``test-cache-file.hdf5``
    (relative to the current working directory). Entries that are still
    non-finite (NaN) are treated as pending, so an interrupted run can be
    resumed by calling ``main`` again with ``overwrite=False``.

    Parameters
    ----------
    pool
        Pool object providing ``map(func, iterable, callback=...)``, e.g. a
        schwimmbad pool.
    overwrite : bool
        If true, recreate the cache file (all entries NaN) even if it exists.
    seed : int, optional
        Seed for the fake input data. A fixed seed makes the inputs identical
        across runs, so results computed after a resume are consistent with
        the ones already stored in the cache.

    Raises
    ------
    ValueError
        If an existing cache file does not hold exactly ``N_tasks_total``
        entries (e.g. after changing the task count without overwriting).
    """
    cache_file = pathlib.Path("test-cache-file.hdf5")

    # One thing you need to know off the bat is how many total tasks there will be:
    N_tasks_total = 10_000

    # Set up the cache file to make sure it exists. You'll have to modify this so it
    # creates the data structure you will be using:
    if overwrite or not cache_file.exists():
        if cache_file.exists():
            print("Overwriting existing cache file")
        else:
            print("Cache file doesn't exist")
        with h5py.File(cache_file, "w") as f:
            # Fill the cached data with nan values so every task starts as pending:
            f["data"] = np.full(N_tasks_total, np.nan)

    with h5py.File(cache_file, "r") as f:
        cached = f["data"][:]

    # A cache built for a different number of tasks cannot be indexed safely
    # with this run's inputs.
    if cached.shape != (N_tasks_total,):
        raise ValueError(
            f"Cache file {cache_file} has shape {cached.shape}, expected "
            f"({N_tasks_total},); rerun with overwrite=True to rebuild it."
        )

    # Indices of tasks that have not yet been completed:
    idx = np.flatnonzero(~np.isfinite(cached))

    # Make some fake input data from random numbers. The generator is seeded
    # so that a resumed run sees the same inputs as the run that filled the
    # cache; otherwise the cache would mix results from two different inputs.
    rng = np.random.default_rng(seed)
    input_data = rng.uniform(0, 100, size=N_tasks_total)

    tasks = [(i, x, cache_file) for i, x in zip(idx, input_data[idx])]

    # Consume whatever pool.map returns, in case the pool yields results
    # lazily; each result is persisted by `callback`, so nothing is kept here.
    for _ in pool.map(worker, tasks, callback=callback):
        pass

    with h5py.File(cache_file, "r") as f:
        print(f["data"][:])
if __name__ == "__main__":
    import sys
    from argparse import ArgumentParser

    from schwimmbad.mpi import MPIPool

    # Command-line interface: a single flag that forces the cache to be rebuilt.
    cli = ArgumentParser()
    cli.add_argument("--overwrite", default=False, action="store_true")
    options = cli.parse_args()

    # NOTE(review): presumably only the master MPI rank gets to run main();
    # worker ranks are expected to be handled inside MPIPool -- confirm
    # against the schwimmbad MPIPool documentation.
    with MPIPool() as pool:
        main(pool=pool, overwrite=options.overwrite)

    sys.exit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment