Benchmark for all-reduce: measures the achieved bytes/second of dist.all_reduce for doubling tensor sizes.
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_ucc  # importing the plugin registers the UCC backend with torch.distributed


def worker(rank, world_size):
    # Rendezvous configuration for init_process_group.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"

    # Backend is selectable via the BACKEND env var; defaults to nccl.
    backend = os.environ.get("BACKEND", "nccl")
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    device = torch.device("cuda", rank)
    torch.cuda.set_device(rank)

    # Time one all_reduce per iteration, doubling the tensor size each time.
    size = 2
    while True:
        t = torch.rand([size], device=device)
        torch.cuda.synchronize()
        start = time.monotonic()
        dist.all_reduce(t)
        torch.cuda.synchronize()
        end = time.monotonic()
        elapsed_time_s = end - start
        rate = (torch.numel(t) * t.element_size()) / elapsed_time_s
        if rank == 0:
            print(f"rate for {size} = {rate} b/s")
        size *= 2

    # Note: unreachable, since the loop above never exits.
    t = torch.randint(0, 100, (2,), device=device)
    dist.all_reduce(t)
    torch.cuda.synchronize()


if "OMPI_COMM_WORLD_RANK" in os.environ:
    # Launched under mpirun/mpiexec: take rank and world size from Open MPI.
    worker(int(os.environ["OMPI_COMM_WORLD_RANK"]), int(os.environ["OMPI_COMM_WORLD_SIZE"]))
elif __name__ == "__main__":
    # Launched directly: spawn 8 local worker processes, one per GPU.
    world_size = 8
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
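
The script times each collective with host-side time.monotonic() calls bracketed by torch.cuda.synchronize(). As a minimal alternative sketch (not part of the gist; the helper name is hypothetical), CUDA events can time the device-side duration of the same all_reduce, which keeps host launch overhead out of the measured rate:

# Alternative timing sketch: measure one all_reduce with CUDA events instead of
# host clocks. Assumes the default process group is already initialized and that
# t lives on the current CUDA device.
import torch
import torch.distributed as dist

def allreduce_rate_cuda_events(t: torch.Tensor) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    dist.all_reduce(t)
    end.record()
    end.synchronize()  # wait for the collective to finish on the device
    elapsed_s = start.elapsed_time(end) / 1000.0  # elapsed_time() returns milliseconds
    return t.numel() * t.element_size() / elapsed_s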