@froody
Created October 14, 2020 19:00
import torch
from torch import distributed as dist
from torch.distributed import rpc

group = None


def sender():
    # Invoked over RPC on "Test0" (rank 0): send a CUDA tensor to rank 1
    # through the point-to-point process group.
    global group
    t = torch.rand(10, device=torch.device("cuda", group.rank()))
    print(f">>> send {group.rank()}")
    dist.send(t, 1, group=group)
    print("<<< send")


def receiver():
    # Invoked over RPC on "Test1" (rank 1): receive the tensor sent by rank 0.
    global group
    t = torch.empty(10, device=torch.device("cuda", group.rank()))
    print(f">>> recv {group.rank()}")
    dist.recv(t, 0, group=group)
    print("<<< recv")


def main():
    # Initialize the default process group with the MPI backend;
    # rank and world size are taken from the MPI launcher.
    dist.init_process_group("mpi")
    host = "localhost"
    port = "10639"
    rank = dist.get_rank()
    init_method = f"tcp://{host}:{port}"

    # Subgroup used for the send/recv between ranks 0 and 1.
    global group
    group = dist.new_group([0, 1])

    rpc.init_rpc(
        f"Test{rank}",
        rank=rank,
        world_size=dist.get_world_size(),
        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=init_method),
    )

    # Rank 0 asks "Test1" to run receiver(); the other rank asks "Test0" to run
    # sender(), so the send on rank 0 pairs with the recv on rank 1.
    if rank == 0:
        rpc.rpc_async("Test1", receiver)
    else:
        rpc.rpc_async("Test0", sender)

    rpc.shutdown()


if __name__ == "__main__":
    main()
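To try this out, the script has to be started through an MPI launcher, since the "mpi" backend picks up its rank and world size from the MPI environment; something like "mpirun -np 2 python this_script.py" (the filename is a placeholder, and the exact command depends on your MPI installation) on a machine with at least two GPUs and a PyTorch build that includes MPI support.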