# mpirun -np 2 python p2p-nonblocking.py
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp
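
# NCCL tuning knobs; these environment variables are read when the NCCL
# communicator is initialized, so they must be set before create_nccl_comm()
# runs. NCCL_MAX_NCHANNELS=1 restricts NCCL to a single channel and
# NCCL_DEBUG=INFO enables verbose logging. The buffer/chunk sizes are raised
# to the full message size so the 32 MiB transfer is not split into smaller
# chunks (the *_CHUNKSIZE variables are only recognized by newer NCCL
# releases).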
nbytes = 1024*1024*32
data_type = cp.float32
buffsize = nbytes
os.environ["NCCL_BUFFSIZE"] = str(buffsize)
os.environ["NCCL_P2P_NVL_CHUNKSIZE"] = str(buffsize)
os.environ["NCCL_P2P_NET_CHUNKSIZE"] = str(buffsize)
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_DEBUG"] = "INFO"

def run_benchmark(mpi_comm, nccl_comm):
    if data_type == cp.float32:
        nccl_dtype = nccl.NCCL_FLOAT32
        nbytes_per_elem = 4
    nelem = nbytes // nbytes_per_elem
    memory = cp.zeros(nelem, dtype=data_type)
    stream = cp.cuda.Stream(non_blocking=True)
    stream2 = cp.cuda.Stream(non_blocking=True)

    # Warmup exchange just to establish the NCCL connections.
    nccl.groupStart()
    if mpi_comm.rank == 0:
        nccl_comm.send(memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr)
        nccl_comm.recv(memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr)
    elif mpi_comm.rank == 1:
        nccl_comm.send(memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr)
        nccl_comm.recv(memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr)
    nccl.groupEnd()

    a_cp = cp.ones((1024, 1024))
    b_cp = cp.ones((1024, 1024))
    a_jnp = jnp.ones((1024, 1024))
    b_jnp = jnp.ones((1024, 1024))

    # Warmup the compute paths (CuPy and JAX) as well.
    _ = cp.ones(10) + cp.ones(10)
    _ = jnp.ones(10) + jnp.ones(10)
    cp.cuda.runtime.deviceSynchronize()
    mpi_comm.barrier()

    st = time.time()
    if mpi_comm.rank == 0:
        # The matching send is intentionally left commented out, so the recv
        # posted by rank 1 below has no peer during the timed section.
        pass
        # time.sleep(5)
        # nccl_comm.send(
        #     memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        # )
    elif mpi_comm.rank == 1:
        print("recv begin")
        # nccl.groupStart()
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
        # nccl.groupEnd()
        print("recv end")

    # Launch an unrelated computation (JAX add; the CuPy variant is left
    # commented out) to see whether it can proceed while the recv is
    # outstanding.
    print("compute begin")
    # with stream2:
    #     c = a_cp + b_cp
    with stream2:
        c = a_jnp + b_jnp
    print("compute end", flush=True)

    mpi_comm.barrier()
    cp.cuda.runtime.deviceSynchronize()
    en = time.time()
    print(f"{mpi_comm.rank} took {en - st} seconds")
def create_nccl_comm(mpi_comm):
    root = 0
    if mpi_comm.rank == root:
        uid = nccl.get_unique_id()
    else:
        uid = None
    uid = mpi_comm.bcast(uid, root=root)
    cp.cuda.runtime.deviceSynchronize()
    tic = time.time()
    comm = nccl.NcclCommunicator(mpi_comm.size, uid, mpi_comm.rank)
    cp.cuda.runtime.deviceSynchronize()
    print(f"communicator cost: {time.time() - tic:.2f}s")
    return comm
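
# Each MPI process restricts itself to one physical GPU by setting
# CUDA_VISIBLE_DEVICES to its world rank, so cp.cuda.Device(0) refers to a
# different GPU on each rank (this assumes CUDA has not been initialized
# earlier in the process).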
if __name__ == "__main__":
    world_comm = MPI.COMM_WORLD
    world_rank = world_comm.rank
    world_size = world_comm.size
    nccl_comm = None
    assert world_size == 2
    try:
        # cp.cuda.Device(world_rank).use()
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{world_rank}"
        cp.cuda.Device(0).use()
        nccl_comm = create_nccl_comm(world_comm)
        run_benchmark(world_comm, nccl_comm)
        nccl_comm = None
        MPI.Finalize()
        world_comm = None
    except Exception as e:
        print(f"An error occurred: {e}")
        if nccl_comm:
            nccl_comm.abort()
        world_comm.Abort()