Skip to content

Instantly share code, notes, and snippets.

@anj-s
Created April 28, 2021 18:03
Show Gist options
  • Save anj-s/8b03b351414e1a59db151f1bda3c6436 to your computer and use it in GitHub Desktop.
Save anj-s/8b03b351414e1a59db151f1bda3c6436 to your computer and use it in GitHub Desktop.
Repro failing to print when using a profiler within a callback.
# Example repro for failing to profile a callback.
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import time
import argparse
RPC_PORT = 25001
def rpc_worker(rank, world_size, func, args):
rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
rpc_backend_options.init_method = f"tcp://localhost:{RPC_PORT}"
rpc.init_rpc(
"worker" + str(rank),
rank=rank,
world_size=world_size,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=rpc_backend_options
)
if rank == 0:
func(args)
rpc.shutdown()
def remote_func():
time.sleep(5)
return 32
def callback(fut):
print(f"Future value is {fut.wait()}")
def callback_with_profiler(fut):
with torch.autograd.profiler(use_cuda=True) as prof:
# ERROR: Returns without printing this.
print(f"Future value is {fut.wait()}")
# No file present and no error thrown.
prof.export_chrome_trace("repro_rpc_profiler")
def test_func(args):
return_future = rpc.rpc_async("worker1", remote_func)
if args.use_profiler:
return_future.then(callback_with_profiler)
else:
return_future.then(callback)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--use_profiler", action="store_true", default=False)
args = parser.parse_args()
world_size = 2
mp.spawn(rpc_worker, args=(world_size, test_func, args), nprocs=world_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment