@froody · Created September 22, 2020 23:11
Example crash in torch-ucc
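The script below initializes the "ucc" process-group backend, runs an all_reduce on a CUDA tensor, does a CPU send/recv between ranks 0 and 1, synchronizes on a barrier, and then creates three subgroups via dist.new_group.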
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_ucc  # importing this registers the "ucc" backend with torch.distributed


def worker(rank, world_size):
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist.init_process_group("ucc", rank=rank, world_size=world_size)
    print(f"rank is = {dist.get_rank()}")

    # Collective on a CUDA tensor.
    t = torch.ones(10, device="cuda")
    dist.all_reduce(t)

    if rank == 0:
        print(f"t is {t}")
        # Point-to-point send/recv on CPU tensors.
        t2 = torch.rand(10)
        print(f"sending {t2}")
        dist.send(t2, 1)
    elif rank == 1:
        t3 = torch.empty(10)
        dist.recv(t3, src=0)
        print(f"recvd {t3}")
    dist.barrier()

    # Create subgroups (0,1), (0), and (1).
    print("creating group 0,1")
    dist.new_group([0, 1])
    print("creating group 0")
    dist.new_group([0])
    print("creating group 1")
    dist.new_group([1])


if "OMPI_COMM_WORLD_RANK" in os.environ:
    # Launched under Open MPI: take rank and world size from the MPI environment.
    worker(int(os.environ["OMPI_COMM_WORLD_RANK"]), int(os.environ["OMPI_COMM_WORLD_SIZE"]))
elif __name__ == "__main__":
    # Otherwise spawn two local worker processes.
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
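To reproduce (a sketch, assuming the script is saved as example.py, a hypothetical filename, with torch_ucc installed and a CUDA device available to each rank): run "python example.py" to spawn two local workers via mp.spawn, or run "mpirun -np 2 python example.py" so that rank and world size are taken from the OMPI_COMM_WORLD_* environment variables.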