
@froody
Created September 23, 2020 23:03
reduced test case for all_reduce failure
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_ucc  # registers the "ucc" backend with torch.distributed


def worker(rank, world_size):
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist.init_process_group("ucc", rank=rank, world_size=world_size)

    # One GPU per rank; the all_reduce below is where the failure reproduces.
    device = torch.device("cuda", rank)
    t = torch.ones(2, device=device)
    dist.all_reduce(t)
    torch.cuda.synchronize()


if "OMPI_COMM_WORLD_RANK" in os.environ:
    # Launched under mpirun: Open MPI provides the rank and world size.
    print("hello test yay")
    worker(int(os.environ["OMPI_COMM_WORLD_RANK"]), int(os.environ["OMPI_COMM_WORLD_SIZE"]))
elif __name__ == "__main__":
    # Launched as a plain Python script: spawn one process per rank.
    world_size = 8
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
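Usage note (an assumption about intent, not stated in the gist): the script appears to support two launch paths. Under Open MPI's mpirun, the OMPI_COMM_WORLD_RANK and OMPI_COMM_WORLD_SIZE environment variables are set, so each MPI process calls worker() directly with its own rank. When run directly with python, mp.spawn starts the 8 worker processes itself. Either path exercises the same ucc all_reduce on one CUDA device per rank.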