@merrymercy
Created May 9, 2024 19:58
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os

import jax.numpy as jnp

nbytes = 1024 * 1024 * 32
data_type = cp.float32

buffsize = nbytes
os.environ["NCCL_BUFFSIZE"] = str(buffsize)
os.environ["NCCL_P2P_NVL_CHUNKSIZE"] = str(buffsize)
os.environ["NCCL_P2P_NET_CHUNKSIZE"] = str(buffsize)
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_DEBUG"] = "INFO"


def run_benchmark(mpi_comm, nccl_comm):
    if data_type == cp.float32:
        nccl_dtype = nccl.NCCL_FLOAT32
        nbytes_per_elem = 4
    nelem = nbytes // nbytes_per_elem

    memory = cp.zeros(nelem, dtype=data_type)
    memory2 = cp.zeros(nelem, dtype=data_type)
    stream = cp.cuda.Stream(non_blocking=True)
    stream2 = cp.cuda.Stream(non_blocking=True)

    # warmup to just make the connections
    nccl.groupStart()
    if mpi_comm.rank == 0:
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
    elif mpi_comm.rank == 1:
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
    nccl.groupEnd()
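
    # Matrices for the compute step: the CuPy pair is used in the timed section
    # on rank 1; the JAX pair is only touched once below as a framework warmup.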
    a_cp = cp.ones((1024, 1024))
    b_cp = cp.ones((1024, 1024))
    a_jnp = jnp.ones((1024, 1024))
    b_jnp = jnp.ones((1024, 1024))

    # warmup
    _ = cp.ones(10) + cp.ones(10)
    c = a_jnp + b_jnp

    cp.cuda.runtime.deviceSynchronize()
    mpi_comm.barrier()
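
    # Timed section. Rank 0 waits 5 seconds and then issues two sends; rank 1
    # posts only one matching recv on `stream`, then immediately launches an
    # elementwise add on the independent `stream2`. Presumably this probes
    # whether pending NCCL point-to-point traffic (one channel, whole-message
    # chunks) stalls unrelated compute on another stream; the "compute
    # begin/end" prints and the per-rank timings show when the add runs.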
    st = time.time()

    if mpi_comm.rank == 0:
        # pass
        time.sleep(5)
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
        nccl_comm.send(
            memory2.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
    elif mpi_comm.rank == 1:
        print("recv begin")
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
        e = stream.record()
        print("recv end")

        print("compute begin")
        with stream2:
            c = a_cp + b_cp
        stream2.synchronize()
        print("compute end", flush=True)

    mpi_comm.barrier()
    cp.cuda.runtime.deviceSynchronize()
    en = time.time()
    print(f"{mpi_comm.rank} took {en - st} seconds")
def create_nccl_comm(mpi_comm):
    root = 0
    if mpi_comm.rank == root:
        uid = nccl.get_unique_id()
    else:
        uid = None
    uid = mpi_comm.bcast(uid, root=root)

    cp.cuda.runtime.deviceSynchronize()
    tic = time.time()
    comm = nccl.NcclCommunicator(mpi_comm.size, uid, mpi_comm.rank)
    cp.cuda.runtime.deviceSynchronize()
    print(f"communicator cost: {time.time() - tic:.2f}s")
    return comm
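

# Each process pins itself to a single GPU by setting CUDA_VISIBLE_DEVICES to
# its MPI rank before touching the device, then uses device 0 within that
# restricted view (the commented-out cp.cuda.Device(world_rank) line is the
# alternative of selecting the device directly).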
if __name__ == "__main__":
    world_comm = MPI.COMM_WORLD
    world_rank = world_comm.rank
    world_size = world_comm.size
    nccl_comm = None

    assert world_size == 2

    try:
        # cp.cuda.Device(world_rank).use()
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{world_rank}"
        cp.cuda.Device(0).use()
        nccl_comm = create_nccl_comm(world_comm)
        run_benchmark(world_comm, nccl_comm)
        nccl_comm = None
        MPI.Finalize()
        world_comm = None
    except Exception as e:
        print(f"An error occurred: {e}")
        if nccl_comm:
            nccl_comm.abort()
        world_comm.Abort()
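
# Usage sketch (an assumption: any MPI launcher that starts exactly two ranks
# on a node with two GPUs should work, since the script asserts world_size == 2):
#   mpirun -np 2 python <this_file>.py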