benchmark for all-reduce
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_ucc  # noqa: F401 -- imported for its side effect of registering the "ucc" backend


def worker(rank, world_size):
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    backend = os.environ.get("BACKEND", "nccl")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    device = torch.device("cuda", rank)
    torch.cuda.set_device(rank)

    # Double the message size each iteration and time a single all-reduce.
    size = 2
    while True:
        t = torch.rand([size], device=device)
        torch.cuda.synchronize()
        start = time.monotonic()
        dist.all_reduce(t)
        torch.cuda.synchronize()
        end = time.monotonic()
        elapsed_s = end - start
        # Payload bytes moved per second for this message size.
        rate = (torch.numel(t) * t.element_size()) / elapsed_s
        if rank == 0:
            print(f"rate for {size} = {rate} b/s")
        size *= 2

    # Unreachable as written: the benchmark loop above never exits.
    t = torch.randint(0, 100, (2,), device=device)
    dist.all_reduce(t)
    torch.cuda.synchronize()


if "OMPI_COMM_WORLD_RANK" in os.environ:
    # Launched under Open MPI: take rank and world size from the MPI environment.
    worker(int(os.environ["OMPI_COMM_WORLD_RANK"]), int(os.environ["OMPI_COMM_WORLD_SIZE"]))
elif __name__ == "__main__":
    # Launched directly: spawn one worker process per local GPU.
    world_size = 8
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
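Usage (an assumption based on the two entry paths above, not part of the gist): launching the script under Open MPI, e.g. `mpirun -np 8 python allreduce_bench.py` with a hypothetical file name, takes the OMPI_COMM_WORLD_* branch, while running it directly with `python allreduce_bench.py` spawns 8 local processes via `mp.spawn`. Setting `BACKEND=ucc` or `BACKEND=gloo` should benchmark those backends instead of the default `nccl`; note the loop grows the message size without bound, so the run ends only when interrupted or when the allocation fails.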