Created
April 28, 2021 16:53
-
-
Save anj-s/87fcd2b504b29642e0fc00639e776d8e to your computer and use it in GitHub Desktop.
Monotonically increasing bucket RTTs in parameter servers.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Repro increasing bucket RTTs. | |
import argparse | |
import os | |
import socket | |
import threading | |
import subprocess | |
import time | |
import torch | |
import torch.distributed as dist | |
import torch.distributed.rpc as rpc | |
import torch.multiprocessing as mp | |
from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions | |
RPC_PORT = 25001 | |
def get_init_urls(args):
    """Build the TCP init URL for the RPC transport.

    For SLURM jobs the first host of the job's node list acts as the
    rendezvous master; otherwise everything rendezvouses on localhost.
    """
    if not args.slurm_job:
        return f"tcp://localhost:{RPC_PORT}"
    node_list = os.environ.get("SLURM_JOB_NODELIST")
    # `scontrol show hostnames` expands SLURM's compact node-list syntax
    # into one hostname per line (as bytes).
    host_output = subprocess.check_output(
        ["scontrol", "show", "hostnames", node_list]
    )
    master_host = host_output.split()[0].decode("utf-8")
    # Each worker has its own process group, hence a dedicated URL for
    # PG connections.
    return f"tcp://{master_host}:{RPC_PORT}"
class LocalParameterServer(object):
    """Aggregates per-bucket gradients from local workers.

    Each call to :meth:`agg_grads` contributes one worker's gradients; once
    ``num_local_workers`` contributions have arrived for the current round,
    the shared future for that bucket resolves with the summed gradients,
    prefixed by the bucket id so callers can match the reply.
    """

    def __init__(self, num_local_workers, args):
        # Serializes agg_grads calls arriving on concurrent RPC threads.
        self.lock = threading.Lock()
        # Contributions received so far in the round currently in flight.
        self.curr_update_size = 0
        self.num_local_workers = num_local_workers
        # Running gradient sum; None until the round's first contribution.
        self.grads = None
        # One pending future per bucket id (None when no round is active).
        self.fut_list = [None for _ in range(args.num_buckets)]
        self.use_cuda_tensors = args.use_cuda_tensors

    @staticmethod
    @rpc.functions.async_execution
    def agg_grads(ps_rref, grads, local_worker_id, bucket_id):
        """Accumulate ``grads`` for ``bucket_id`` and return the round's future.

        The future resolves to ``torch.cat((index_tensor, summed_grads))``
        where ``index_tensor`` encodes ``bucket_id``.
        """
        self = ps_rref.local_value()
        with self.lock:
            # BUGFIX: previously every worker overwrote the bucket's future,
            # so with more than one worker only the last caller's future ever
            # resolved. Create the round's future once and hand it to everyone.
            if self.fut_list[bucket_id] is None:
                self.fut_list[bucket_id] = torch.futures.Future()
            fut = self.fut_list[bucket_id]
            self.curr_update_size += 1
            # BUGFIX: `if not self.grads` raises ("Boolean value of Tensor
            # with more than one element is ambiguous") for multi-element
            # gradient tensors; compare against None instead.
            if self.grads is None:
                self.grads = grads
            else:
                self.grads += grads
            if self.curr_update_size == self.num_local_workers:
                # BUGFIX: this read the module-global `args`, which only
                # exists when the file runs as a script; use the value
                # stored on the instance.
                if self.use_cuda_tensors:
                    index_tensor = torch.tensor([1.], device="cuda:0") * bucket_id
                else:
                    index_tensor = torch.tensor([1.]) * bucket_id
                fut.set_result(torch.cat((index_tensor, self.grads)))
                # Reset round state so the next round starts clean.
                self.fut_list[bucket_id] = None
                self.curr_update_size = 0
                self.grads = None
        return fut
class Trainer(object):
    """Issues one async agg_grads RPC per bucket against the local
    parameter server and prints the round-trip time of each reply.
    """

    def __init__(self, local_ps_rref, num_buckets=3, bucket_size=1, use_cuda_tensors=False):
        self.local_ps_rref = local_ps_rref
        self.num_buckets = num_buckets
        # Per-bucket send timestamp (ns), filled in by comm_hook.
        self.timer = [None] * num_buckets
        self.fut_list = []
        self.bucket_size = bucket_size
        self.use_cuda_tensors = use_cuda_tensors
        self.sent = False

    def comm_hook(self):
        """Send every bucket to the PS and attach a timing callback."""
        worker_id = rpc.get_worker_info().id
        # bucket_size is in MB; each float32 element occupies 4 bytes.
        num_elems = self.bucket_size * 2**20 // 4
        if self.use_cuda_tensors:
            bucket_val = torch.ones(num_elems, device="cuda:0")
        else:
            bucket_val = torch.ones(num_elems)

        def callback(fut):
            inter_fut = fut.wait()
            # The PS prefixes its reply with the bucket id.
            bucket_index = int(inter_fut[0])
            print(f"Callback triggered in {(time.perf_counter_ns() - self.timer[bucket_index])/1e6} ms")

        for bucket_idx in range(self.num_buckets):
            pending = rpc.rpc_async(
                self.local_ps_rref.owner(),
                LocalParameterServer.agg_grads,
                args=(self.local_ps_rref, bucket_val, worker_id, bucket_idx),
            )
            self.timer[bucket_idx] = time.perf_counter_ns()
            pending.then(callback)
def run_trainer(local_ps_rref, args):
    """Run one warm-up round of bucket RPCs followed by one measured round."""
    trainer = Trainer(
        local_ps_rref,
        args.num_buckets,
        args.bucket_size,
        use_cuda_tensors=args.use_cuda_tensors,
    )
    print("---Warm Up-----")
    trainer.comm_hook()
    # Give the asynchronous warm-up RPCs time to drain before measuring.
    time.sleep(10)
    print("-----Run-------")
    trainer.comm_hook()
def run_ps(rank, world_size, args):
    r"""
    Initialize RPC for the parameter-server role, kick off the remote
    trainer, wait for it to finish, and shut RPC down.
    """
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = get_init_urls(args)
    # Map local GPU 0 onto the trainer's GPU 0 for CUDA tensors.
    rpc_backend_options.set_device_map("worker1", {0: 0})
    print(f"run_ps {rank} with world size {world_size}")
    rpc.init_rpc(
        f"ps{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )
    ps_rref = rpc.RRef(LocalParameterServer(1, args))
    trainer_fut = rpc.rpc_async("worker1", run_trainer, args=(ps_rref, args,))
    # Block until the remote trainer finishes.
    torch.futures.wait_all([trainer_fut])
    rpc.shutdown()
def run_worker(rank, world_size, args):
    r"""
    Initialize RPC for the trainer role and block until all RPCs finish.

    The parameter server drives the actual work by invoking ``run_trainer``
    on this process via RPC; this function only sets up and tears down RPC.
    """
    url_rpc = get_init_urls(args)
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = url_rpc
    # Map local GPU 0 onto the parameter server's GPU 0 for CUDA tensors.
    rpc_backend_options.set_device_map("ps0", {0: 0})
    print(f"run_worker {rank} with world size {world_size}")
    rpc.init_rpc(
        "worker1",  # fixed: was an f-string with no placeholders (F541)
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options,
    )
    # block until all rpcs finish
    rpc.shutdown()
def run_local(local_rank, world_size, args):
    """Spawn target: rank 0 becomes the parameter server, other ranks workers."""
    role = run_ps if local_rank == 0 else run_worker
    role(local_rank, world_size, args)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Reproduce bucket send/recv latency."
    )
    parser.add_argument(
        "--slurm_job",
        action="store_true",
        default=False,
        # fixed typo: "SLUM" -> "SLURM"
        help="""Local or SLURM job. By default we assume local job.""",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="""The communication backend to use for the collectives operation.""",
    )
    parser.add_argument(
        "--num_devices",
        type=int,
        default=1,
        help="""Number of GPUs on a given machine.""",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=100,
        help="""Number of buckets.""",
    )
    parser.add_argument(
        "--bucket_size",
        type=int,
        default=1,
        help="""Size of the bucket in MB.""",
    )
    parser.add_argument(
        "--use_cuda_tensors",
        action="store_true",
        default=False,
        help="""Use CUDA tensors.""",
    )
    args = parser.parse_args()

    # Auto-detect SLURM and pin every transport to the cluster's
    # front-end network interface.
    if "SLURM_NODEID" in os.environ:
        args.slurm_job = True
        os.environ["TP_SOCKET_IFNAME"] = "front0"
        os.environ["GLOO_SOCKET_IFNAME"] = "front0"
        os.environ["NCCL_SOCKET_IFNAME"] = "front0"
        os.environ["NCCL_DEBUG"] = "INFO"

    # NOTE: enabling torch.cuda before calling spawn raises an error, hence
    # CUDA usage is controlled via the --use_cuda_tensors flag instead.
    if args.slurm_job:
        node_id = int(os.environ.get("SLURM_NODEID"))
        num_nodes = int(os.environ.get("SLURM_NNODES"))
    else:
        args.num_devices = 2
        num_nodes = 1
    args.num_nodes = num_nodes
    # Exactly one parameter-server process and one trainer process.
    args.world_size = 2

    if args.slurm_job and node_id == 0:
        run_ps(node_id, args.world_size, args)
    elif args.slurm_job and node_id == 1:
        run_worker(node_id, args.world_size, args)
    else:
        torch.multiprocessing.spawn(
            run_local, args=(args.world_size, args,), nprocs=2, join=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment