# mpirun -np 2 python p2p-nonblocking.py
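#
# Two MPI ranks, one GPU each: rank 0 delays its NCCL point-to-point sends
# while rank 1 posts the matching recvs interleaved with jitted JAX compute
# (the "overlap" section below).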
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp
nbytes = 1024*1024*32
data_type = cp.float32
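
# NCCL knobs: limit communication to a single channel and make the internal
# buffer / P2P chunk sizes match the full message size.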
buffsize = nbytes
os.environ["NCCL_BUFFSIZE"] = str(buffsize)
os.environ["NCCL_P2P_NVL_CHUNKSIZE"] = str(buffsize)
os.environ["NCCL_P2P_NET_CHUNKSIZE"] = str(buffsize)
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_DEBUG"] = "INFO"
import jax
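
# Jitted dummy workload; used below to overlap compute with the NCCL recvs.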
@jax.jit
def compute(a, b):
    ds = []
    for i in range(10):
        a = a + 1
        b = b + 1
        c = a @ b
        d = c + 1
        ds.append(d)
    return sum(ds)

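# Warm up the NCCL connection, then time rank 0's delayed sends against
# rank 1's recvs overlapped with compute().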
def run_benchmark(mpi_comm, nccl_comm):
    if data_type == cp.float32:
        nccl_dtype = nccl.NCCL_FLOAT32
        nbytes_per_elem = 4

    nelem = nbytes // nbytes_per_elem
    # single 32 MiB buffer reused for every send/recv
    memory = cp.zeros(nelem, dtype=data_type)
    stream = cp.cuda.Stream(non_blocking=True)
    stream2 = cp.cuda.Stream(non_blocking=True)  # unused in this benchmark

    # warmup to just make the connections
    nccl.groupStart()
    if mpi_comm.rank == 0:
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
    elif mpi_comm.rank == 1:
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
    nccl.groupEnd()

    # operands for the jitted workload (the CuPy pair is unused below)
    a_cp = cp.ones((1024, 1024))
    b_cp = cp.ones((1024, 1024))
    a_jnp = jnp.ones((1024, 1024))
    b_jnp = jnp.ones((1024, 1024))
    aa_jnp = jnp.ones((1024, 1024))
    bb_jnp = jnp.ones((1024, 1024))

    # warmup
    _ = cp.ones(10) + cp.ones(10)
    compute(aa_jnp, bb_jnp)
    cp.cuda.runtime.deviceSynchronize()
    mpi_comm.barrier()

    st = time.time()
    if mpi_comm.rank == 0:
        # delay the sends so the matching recvs on rank 1 have to wait
        time.sleep(10)
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
        nccl_comm.send(
            memory.data.ptr, nelem, nccl_dtype, 1, stream.ptr
        )
    elif mpi_comm.rank == 1:
        print("recv begin")
        nccl.groupStart()
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
        nccl.groupEnd()
        print("recv end")

        # overlap section: interleave jitted compute with the second recv
        print("overlap begin", flush=True)
        c = compute(a_jnp, b_jnp)
        nccl.groupStart()
        nccl_comm.recv(
            memory.data.ptr, nelem, nccl_dtype, 0, stream.ptr
        )
        nccl.groupEnd()
        d = compute(a_jnp, b_jnp)
        c.block_until_ready()
        d.block_until_ready()
        print("overlap end", flush=True)

    mpi_comm.barrier()
    cp.cuda.runtime.deviceSynchronize()
    en = time.time()
    print(f"{mpi_comm.rank} took {en-st} seconds")

def create_nccl_comm(mpi_comm):
    root = 0
    if mpi_comm.rank == root:
        uid = nccl.get_unique_id()
    else:
        uid = None
    uid = mpi_comm.bcast(uid, root=root)

    cp.cuda.runtime.deviceSynchronize()
    tic = time.time()
    comm = nccl.NcclCommunicator(mpi_comm.size, uid, mpi_comm.rank)
    cp.cuda.runtime.deviceSynchronize()
    print(f"communicator cost: {time.time() - tic:.2f}s")
    return comm

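# Pin each rank to its own GPU via CUDA_VISIBLE_DEVICES, then run the benchmark.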
if __name__ == "__main__":
    world_comm = MPI.COMM_WORLD
    world_rank = world_comm.rank
    world_size = world_comm.size
    nccl_comm = None

    assert world_size == 2

    try:
        # cp.cuda.Device(world_rank).use()
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{world_rank}"
        cp.cuda.Device(0).use()

        nccl_comm = create_nccl_comm(world_comm)
        run_benchmark(world_comm, nccl_comm)

        nccl_comm = None
        MPI.Finalize()
        world_comm = None
    except Exception as e:
        print(f"An error occurred: {e}")
        if nccl_comm:
            nccl_comm.abort()
        world_comm.Abort()