Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created April 11, 2022 18:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesr66a/f8f545a987bbc6ac0b904b8b3c8ff2c6 to your computer and use it in GitHub Desktop.
import argparse, socket, os
import torch
import torch.fx
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
def run_main(args):
    """Rank-0 driver: symbolically trace a small module and print it on worker 1.

    Args:
        args: parsed CLI namespace (unused here, kept for a uniform driver signature).
    """

    class _ReluChain(torch.nn.Module):
        # 1000 chained relus make the traced graph large enough to be interesting
        # as an RPC serialization payload.
        def forward(self, x):
            for _ in range(1000):
                x = torch.relu(x)
            return x

    gm = torch.fx.symbolic_trace(_ReluChain())
    # Ship the traced GraphModule to rank 1 twice — exercises repeated
    # RPC serialization of the same object. Returned RRefs are discarded.
    for _ in range(2):
        rpc.remote(to=1, func=print, args=(gm,))
def run_worker(rank, world_size, args):
    """Per-process entry point: initialize RPC, run the driver on rank 0, shut down.

    Args:
        rank: this process's rank in [0, world_size).
        world_size: total number of RPC participants.
        args: parsed CLI namespace carrying master_addr / master_port.
    """
    host, pid = socket.gethostname(), os.getpid()
    print(f"rank = {rank} host/pid = {host}/{pid}")

    # Rendezvous endpoint for init_rpc, communicated via the environment.
    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port

    opts = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
    rpc.init_rpc(
        f"worker{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=opts,
    )

    # Only rank 0 drives; the other ranks just serve RPCs until shutdown.
    if rank == 0:
        run_main(args)
    rpc.shutdown()
if __name__ == "__main__":
    # CLI flags default to the standard torch distributed environment variables,
    # so the script works both under a launcher and standalone.
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int,
                        default=int(os.getenv("WORLD_SIZE", 3)))
    parser.add_argument("--rank", type=int,
                        default=int(os.getenv("RANK", -1)))
    parser.add_argument("--master_addr", type=str,
                        default=os.getenv("MASTER_ADDR", "localhost"))
    parser.add_argument("--master_port", type=str,
                        default=os.getenv("MASTER_PORT", "29500"))
    args = parser.parse_args()

    if args.rank == -1:
        # No explicit rank: spawn the entire world as local subprocesses.
        mp.spawn(run_worker, args=(args.world_size, args),
                 nprocs=args.world_size, join=True)
    elif args.rank < args.world_size:
        # Launched externally with a concrete rank: run just this participant.
        run_worker(args.rank, args.world_size, args)
    else:
        print("I'm unused, exiting")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment