@anj-s
Created April 28, 2021 16:53
Monotonically increasing bucket RTTs in parameter servers.
# Repro increasing bucket RTTs.
import argparse
import os
import socket
import threading
import subprocess
import time
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions
RPC_PORT = 25001

def get_init_urls(args):
    if args.slurm_job:
        node_list = os.environ.get("SLURM_JOB_NODELIST")
        hostnames = subprocess.check_output(
            ["scontrol", "show", "hostnames", node_list]
        )
        master_host = hostnames.split()[0].decode("utf-8")
        # Each worker has its own process group, hence we create
        # a local URL for PG connections.
        url_rpc = f"tcp://{master_host}:{RPC_PORT}"
    else:
        url_rpc = f"tcp://localhost:{RPC_PORT}"
    return url_rpc

class LocalParameterServer(object):
    def __init__(self, num_local_workers, args):
        self.lock = threading.Lock()
        self.curr_update_size = 0
        self.num_local_workers = num_local_workers
        self.grads = None
        # One future per bucket, completed once all local workers have
        # contributed that bucket's gradients.
        self.fut_list = [None for i in range(args.num_buckets)]
        self.use_cuda_tensors = args.use_cuda_tensors

    @staticmethod
    @rpc.functions.async_execution
    def agg_grads(ps_rref, grads, local_worker_id, bucket_id):
        self = ps_rref.local_value()
        with self.lock:
            self.curr_update_size += 1
            self.fut_list[bucket_id] = torch.futures.Future()
            if self.grads is None:
                self.grads = grads
            else:
                self.grads += grads
            if self.curr_update_size == self.num_local_workers:
                if self.use_cuda_tensors:
                    index_tensor = torch.tensor([1.], device="cuda:0") * bucket_id
                else:
                    index_tensor = torch.tensor([1.]) * bucket_id
                self.fut_list[bucket_id].set_result(torch.cat((index_tensor, self.grads)))
                self.curr_update_size = 0
                self.grads = None
        return self.fut_list[bucket_id]
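
# Note added for clarity: agg_grads uses the @rpc.functions.async_execution
# pattern -- each RPC returns a torch.futures.Future right away, and the
# future for a bucket is only completed once num_local_workers trainers have
# contributed gradients for that bucket. The bucket id is prepended to the
# aggregated tensor so the trainer-side callback can match the reply to the
# timer it started for that bucket.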

class Trainer(object):
    def __init__(self, local_ps_rref, num_buckets=3, bucket_size=1, use_cuda_tensors=False):
        self.local_ps_rref = local_ps_rref
        self.num_buckets = num_buckets
        self.timer = [None for i in range(self.num_buckets)]
        self.fut_list = []
        self.bucket_size = bucket_size
        self.use_cuda_tensors = use_cuda_tensors
        self.sent = False

    def comm_hook(self):
        # name = rpc.get_worker_info().name
        id = rpc.get_worker_info().id
        # bucket_size is in MB; each float32 element takes 4 bytes.
        num_elems = self.bucket_size * 2**20 // 4
        if self.use_cuda_tensors:
            bucket_val = torch.ones(num_elems, device="cuda:0")
        else:
            bucket_val = torch.ones(num_elems)

        def callback(fut):
            inter_fut = fut.wait()
            bucket_index = int(inter_fut[0])
            # result_list.append(inter_fut[1:])
            print(f"Callback triggered in {(time.perf_counter_ns() - self.timer[bucket_index]) / 1e6} ms")

        for i in range(self.num_buckets):
            fut = rpc.rpc_async(
                self.local_ps_rref.owner(),
                LocalParameterServer.agg_grads,
                args=(self.local_ps_rref, bucket_val, id, i),
            )
            self.timer[i] = time.perf_counter_ns()
            fut.then(callback)
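
# Note added for clarity: comm_hook approximates DDP-style gradient bucketing.
# It issues num_buckets asynchronous agg_grads RPCs, each carrying a
# bucket_size-MB tensor, records a send timestamp per bucket, and the callback
# prints the round-trip time once the parameter server's reply arrives. The
# monotonically increasing per-bucket RTTs are what this gist reproduces.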

def run_trainer(local_ps_rref, args):
    trainer = Trainer(local_ps_rref, args.num_buckets, args.bucket_size, use_cuda_tensors=args.use_cuda_tensors)
    print("---Warm Up-----")
    trainer.comm_hook()
    time.sleep(10)
    print("-----Run-------")
    trainer.comm_hook()

def run_ps(rank, world_size, args):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """
    url_rpc = get_init_urls(args)
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = url_rpc
    rpc_backend_options.set_device_map("worker1", {0: 0})
    print(f"run_ps {rank} with world size {world_size}")
    rpc.init_rpc(
        f"ps{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )
    local_ps_rref = rpc.RRef(LocalParameterServer(1, args))
    futs = []
    futs.append(
        rpc.rpc_async("worker1", run_trainer, args=(local_ps_rref, args,))
    )
    torch.futures.wait_all(futs)
    # Block until all RPCs finish.
    rpc.shutdown()

def run_worker(rank, world_size, args):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """
    url_rpc = get_init_urls(args)
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = url_rpc
    rpc_backend_options.set_device_map("ps0", {0: 0})
    print(f"run_worker {rank} with world size {world_size}")
    rpc.init_rpc(
        "worker1",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )
    # Block until all RPCs finish.
    rpc.shutdown()

def run_local(local_rank, world_size, args):
    if local_rank == 0:
        run_ps(local_rank, world_size, args)
    else:
        run_worker(local_rank, world_size, args)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Reproduce bucket send/recv latency."
    )
    parser.add_argument(
        "--slurm_job",
        action="store_true",
        default=False,
        help="""Local or SLURM job. By default we assume a local job.""",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="""The communication backend to use for the collective operations.""",
    )
    parser.add_argument(
        "--num_devices",
        type=int,
        default=1,
        help="""Number of GPUs on a given machine.""",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=100,
        help="""Number of buckets.""",
    )
    parser.add_argument(
        "--bucket_size",
        type=int,
        default=1,
        help="""Size of the bucket in MB.""",
    )
    parser.add_argument(
        "--use_cuda_tensors",
        action="store_true",
        default=False,
        help="""Use CUDA tensors.""",
    )
    args = parser.parse_args()
if "SLURM_NODEID" in os.environ:
args.slurm_job = True
os.environ["TP_SOCKET_IFNAME"] = "front0"
os.environ["GLOO_SOCKET_IFNAME"] = "front0"
os.environ["NCCL_SOCKET_IFNAME"] = "front0"
os.environ["NCCL_DEBUG"] = "INFO"
# Enabling torch.cuda before calling spawn raises an error. Hence setting
# this with a flag for now.
gpus_per_node = args.num_devices
if args.slurm_job:
node_id = int(os.environ.get("SLURM_NODEID"))
num_nodes = int(os.environ.get("SLURM_NNODES"))
else:
args.num_devices = 2
num_nodes = 1
args.num_nodes = num_nodes
args.world_size = 2
if args.slurm_job and node_id == 0:
run_ps(node_id, args.world_size, args)
elif args.slurm_job and node_id == 1:
run_worker(node_id, args.world_size, args)
else:
torch.multiprocessing.spawn(
run_local, args=(args.world_size, args,), nprocs=2, join=True)
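
For reference, one possible way to launch the repro is sketched below. The filename repro_bucket_rtts.py and the srun flags are illustrative (the actual gist filename and cluster setup may differ), and --use_cuda_tensors assumes CUDA device 0 is available in every process:

    # Local run: spawns a parameter-server process and a trainer process on one machine.
    python repro_bucket_rtts.py --num_buckets 100 --bucket_size 1 --use_cuda_tensors

    # SLURM run: one task per node on two nodes; the script picks the ps/worker
    # role from SLURM_NODEID and reads the master host from SLURM_JOB_NODELIST.
    srun --nodes=2 --ntasks-per-node=1 python repro_bucket_rtts.py --use_cuda_tensors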