@merrymercy
Created May 9, 2024 14:06

import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp

nbytes = 1024 * 1024 * 32
data_type = cp.float32
buffsize = nbytes

# NCCL tuning knobs; these must be set before the NCCL communicator is created.
os.environ["NCCL_BUFFSIZE"] = str(buffsize)           # per-connection buffer size
os.environ["NCCL_P2P_NVL_CHUNKSIZE"] = str(buffsize)  # point-to-point chunk size over NVLink
os.environ["NCCL_P2P_NET_CHUNKSIZE"] = str(buffsize)  # point-to-point chunk size over the network
os.environ["NCCL_MAX_NCHANNELS"] = "1"                # cap the number of NCCL channels
os.environ["NCCL_DEBUG"] = "INFO"                     # verbose NCCL logging

def run_benchmark(mpi_comm, nccl_comm):
    if data_type == cp.float32:
        nccl_dtype = nccl.NCCL_FLOAT32
        nbytes_per_elem = 4
    nelem = nbytes // nbytes_per_elem

    # Buffers and dedicated streams for the NCCL point-to-point transfers.
    memory = cp.zeros(nelem, dtype=data_type)
    memory2 = cp.zeros(nelem, dtype=data_type)
    stream = cp.cuda.Stream(non_blocking=True)
    stream2 = cp.cuda.Stream(non_blocking=True)

    # Warmup exchange to just make the connections.
    nccl.groupStart()
    if mpi_comm.rank == 0:
        nccl_comm.send(memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr)
        nccl_comm.recv(memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr)
    elif mpi_comm.rank == 1:
        nccl_comm.send(memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr)
        nccl_comm.recv(memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr)
    nccl.groupEnd()

    # Inputs for the overlap compute (only the JAX arrays are used in the timed section).
    a_cp = cp.ones((1024, 1024))
    b_cp = cp.ones((1024, 1024))
    a_jnp = jnp.ones((1024, 1024))
    b_jnp = jnp.ones((1024, 1024))

    # Warmup the compute paths (CuPy and JAX) so the timing below excludes JIT/init cost.
    _ = cp.ones(10) + cp.ones(10)
    c = a_jnp + b_jnp

    cp.cuda.runtime.deviceSynchronize()
    mpi_comm.barrier()

    # Timed region: rank 0 currently just sleeps (its sends are commented out),
    # while rank 1 posts two recvs on a side stream and then runs a JAX add.
    st = time.time()
    if mpi_comm.rank == 0:
        time.sleep(5)
        # nccl_comm.send(
        #     memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        # )
        # nccl_comm.send(
        #     memory2.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        # )
    elif mpi_comm.rank == 1:
        print("recv begin")
        nccl_comm.recv(memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr)
        nccl_comm.recv(memory2.data.ptr, nelem, nccl_dtype, 0, stream.ptr)
        print("recv end")

        print("compute begin")
        c = a_jnp + b_jnp
        c.block_until_ready()
        print("compute end", flush=True)

    mpi_comm.barrier()
    cp.cuda.runtime.deviceSynchronize()
    en = time.time()
    print(f"{mpi_comm.rank} took {en - st} seconds")


def create_nccl_comm(mpi_comm):
    # Broadcast the NCCL unique id from rank 0 over MPI, then build the communicator.
    root = 0
    if mpi_comm.rank == root:
        uid = nccl.get_unique_id()
    else:
        uid = None
    uid = mpi_comm.bcast(uid, root=root)

    cp.cuda.runtime.deviceSynchronize()
    tic = time.time()
    comm = nccl.NcclCommunicator(mpi_comm.size, uid, mpi_comm.rank)
    cp.cuda.runtime.deviceSynchronize()
    print(f"communicator cost: {time.time() - tic:.2f}s")
    return comm
if __name__ == "__main__":
world_comm = MPI.COMM_WORLD
world_rank = world_comm.rank
world_size = world_comm.size
nccl_comm = None
assert world_size == 2
try:
#cp.cuda.Device(world_rank).use()
os.environ["CUDA_VISIBLE_DEVICES"] = f"{world_rank}"
cp.cuda.Device(0).use()
nccl_comm = create_nccl_comm(world_comm)
run_benchmark(world_comm, nccl_comm)
nccl_comm = None
MPI.Finalize()
world_comm = None
except Exception as e:
print(f"An error occurred: {e}")
if nccl_comm:
nccl_comm.abort()
world_comm.Abort()
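
# Launch note (added; assumed setup, not part of the original gist): the script
# expects exactly two MPI ranks on a machine with at least two GPUs, e.g.
#   mpirun -np 2 python <this_script>.py
# where <this_script>.py is a placeholder for whatever name this file is saved under.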