PyTorch distributed pump
"""Measure point-to-point bandwidth of torch.distributed (gloo backend).

Odd-ranked processes repeatedly send a fixed-size buffer to the even rank
below them; even-ranked processes receive and print the observed rate.
"""

import time

import torch
import torch.distributed
import torch.multiprocessing


def _send(buffer_size, dest):
    """Send a zero-filled buffer to rank `dest` in a tight loop."""
    x = torch.zeros((buffer_size,))
    while True:
        torch.distributed.send(x, dest)


def _receive(buffer_size):
    """Receive buffers in a loop and print the throughput every 0.5 s."""
    received_since_last = 0
    time_since_last = time.perf_counter()

    x = torch.empty((buffer_size,))
    x_item_size = (torch.finfo(x.dtype).bits // 8
                   if x.dtype.is_floating_point
                   else torch.iinfo(x.dtype).bits // 8)
    x_total_size_bytes = buffer_size * x_item_size

    while True:
        torch.distributed.recv(x)
        received_since_last += 1

        current_time = time.perf_counter()
        elapsed = current_time - time_since_last
        if elapsed > 0.5:
            print('Current rate: {0} MB / s'.format(
                received_since_last * x_total_size_bytes / elapsed / (1024 * 1024)))
            time_since_last = current_time
            received_since_last = 0


def _run(rank, world_size, buffer_size):
    torch.distributed.init_process_group(
        'gloo', world_size=world_size, rank=rank,
        init_method='tcp://pralexa1.cs.princeton.edu:16847')

    rank = torch.distributed.get_rank()

    # Odd ranks send to the even rank directly below them; even ranks receive.
    if rank % 2 == 1:
        _send(buffer_size, rank - 1)
    else:
        _receive(buffer_size)


def spawn(buffer_size):
    """Launch a receiver (rank 0) and sender (rank 1) pair on the local machine."""
    torch.multiprocessing.spawn(
        _run, nprocs=2,
        args=(2, buffer_size),
        join=True)


def main():
    if not torch.distributed.is_available():
        raise ValueError('Distributed operation is not available on this system!')

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--buffer_size', type=int, default=1024 * 1024)
    parser.add_argument('--rank', type=int, default=None)

    args = parser.parse_args()

    if args.rank is None:
        spawn(args.buffer_size)
    else:
        _run(args.rank, 2, args.buffer_size)


if __name__ == '__main__':
    main()
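
A brief usage sketch (the filename pump.py and the buffer size below are assumptions; the gist itself is unnamed): run the script without --rank to spawn a local sender/receiver pair, or start it once per machine with an explicit rank, making sure both hosts can reach the TCP rendezvous address hard-coded in _run.

    # Hypothetical programmatic usage, assuming the gist is saved as pump.py.
    import pump

    # Single machine: spawn a local receiver (rank 0) and sender (rank 1).
    pump.spawn(buffer_size=4 * 1024 * 1024)

    # Alternatively, on two machines that can both reach the init_method host,
    # call _run directly with an explicit rank (receiver on one, sender on the other):
    #   host A: pump._run(0, 2, 4 * 1024 * 1024)   # receiver, prints the rate
    #   host B: pump._run(1, 2, 4 * 1024 * 1024)   # sender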