Skip to content

Instantly share code, notes, and snippets.

@anj-s
Created April 28, 2021 15:25
Show Gist options
  • Save anj-s/b2a68d7985dc7ebd8a6d95d7e7094170 to your computer and use it in GitHub Desktop.
Save anj-s/b2a68d7985dc7ebd8a6d95d7e7094170 to your computer and use it in GitHub Desktop.
Example demonstrating torch.jit.script + rpc_async/rpc_sync + RRefs
# Example repro for failing to profile a callback.
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import time
import argparse
# Fixed localhost TCP port every worker uses for the TensorPipe rendezvous
# (it is interpolated into the ``tcp://localhost:<port>`` init_method).
RPC_PORT: int = 25_001
class Server(object):
    """Toy object that lives on a remote worker and is reached through an RRef.

    Its ``value`` attribute is read back by :meth:`val`, which is meant to run
    on the worker that owns the RRef.
    """

    def __init__(self):
        # Payload that ``val`` hands back to the RPC caller.
        self.value = 32

    @staticmethod
    @rpc.functions.async_execution
    def val(myobj_rref):
        """Return a completed Future holding ``value`` of the object behind the RRef.

        Intended to be invoked on the RRef's owner, e.g.
        ``rpc.rpc_async(rref.owner(), Server.val, args=(rref,))``. Per the RRef
        docs, ``local_value()`` raises when called on a non-owner.

        ``@rpc.functions.async_execution`` declares that this function returns a
        Future. The RPC callee then replies with that Future's result instead of
        the Future object. It sits inside ``@staticmethod``, as the RPC docs
        require.

        Deliberately NOT decorated with ``@torch.jit.script``. TorchScript would
        type the un-annotated ``myobj_rref`` as ``Tensor``, which has no
        ``local_value``, and ``Server`` is not a TorchScript class. Compilation
        would fail as soon as the class body executes.

        Args:
            myobj_rref: RRef whose local value is a ``Server`` instance.

        Returns:
            ``torch.futures.Future`` already completed with ``server.value``.
        """
        server = myobj_rref.local_value()
        fut = torch.futures.Future()
        fut.set_result(server.value)
        return fut
def rpc_worker(rank, world_size, func, args):
    """Run one RPC worker process; rank 0 additionally executes ``func(args)``.

    Meant to be launched once per rank by ``torch.multiprocessing.spawn``, which
    supplies ``rank`` as the first argument. Each process joins the same
    TensorPipe RPC group under the name ``"worker<rank>"``. Ranks other than 0
    go straight to ``rpc.shutdown()``, which (graceful by default) blocks until
    every worker reaches shutdown, so they keep serving requests until then.

    Args:
        rank: This process's rank, expected in ``[0, world_size)``.
        world_size: Total number of RPC workers in the group.
        func: Callable invoked on rank 0 after RPC is initialised.
        args: Single positional argument forwarded to ``func``.
    """
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        init_method=f"tcp://localhost:{RPC_PORT}",
    )
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=rpc_backend_options,
    )
    try:
        if rank == 0:
            func(args)
    finally:
        # Leave the RPC group even if ``func`` raised. Otherwise rank 0 would
        # exit without joining the graceful shutdown the other ranks are
        # blocked in.
        rpc.shutdown()
def test_func(args):
    """Create a ``Server`` on worker1 and print the ``value`` fetched over RPC.

    The ``__main__`` block passes this to ``rpc_worker`` as ``func``, so it runs
    on rank 0. ``Server.val`` is dispatched to the RRef's owner, the one worker
    where ``local_value()`` is usable.

    Args:
        args: Parsed command-line namespace; not read by this function.
    """
    server_rref = rpc.remote("worker1", Server)
    owner = server_rref.owner()
    result = rpc.rpc_async(owner, Server.val, args=(server_rref,)).wait()
    print(f"{result}")
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # NOTE(review): this flag is parsed and forwarded to test_func, but nothing
    # in this file reads it yet.
    cli.add_argument("--use_torch_script", action="store_true", default=False)
    cli_args = cli.parse_args()

    # Rank 0 drives test_func; rank 1 is "worker1", which hosts the Server.
    num_workers = 2
    mp.spawn(
        rpc_worker,
        args=(num_workers, test_func, cli_args),
        nprocs=num_workers,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment