@betatim
Last active September 22, 2023 07:35
Working out how to `mpirun` dask with cuda
from dask_mpi import initialize
from dask import distributed


def dask_info(client):
    # Report basic cluster details through the given client
    distributed.print("woah i'm running!")
    distributed.print("ncores:", client.ncores())
    distributed.print()
    distributed.print(client.scheduler_info())


def square(x):
    return x ** 2


def neg(x):
    return -x


if __name__ == "__main__":
    # With dask-mpi's defaults, MPI rank 0 runs the Dask scheduler and
    # ranks 2-n run the workers; those ranks do not progress beyond this
    # initialization call
    initialize(
        worker_class="dask_cuda.CUDAWorker",
        worker_options={
            "enable_tcp_over_ucx": False,
            "enable_infiniband": False,
            "enable_nvlink": False,
        },
    )

    # MPI rank 1 continues executing the script once the scheduler has started
    from dask.distributed import Client

    client = Client()  # The scheduler address is found automatically via MPI
    client.wait_for_workers(2)
    dask_info(client)

    A = client.map(square, range(10))
    B = client.map(neg, A)
    total = client.submit(sum, B)
    distributed.print("total:", total.result())
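
To tell a cluster problem apart from a logic problem, it may help to know what the script should print for `total`. The sketch below evaluates the same `square` -> `neg` -> `sum` pipeline in plain Python, with no MPI, Dask, or GPU involved. `expected_total` is a helper name introduced here for illustration; it is not part of the gist.

```python
def square(x):
    return x ** 2


def neg(x):
    return -x


def expected_total(n=10):
    # Same computation as the client.map / client.submit chain,
    # evaluated eagerly in a single process
    return sum(neg(square(x)) for x in range(n))


if __name__ == "__main__":
    print("expected total:", expected_total())  # -> expected total: -285
```

If the `mpirun` job reports a different total, the task functions are probably not what differs; the cluster setup is the more likely place to look.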