@myleott
Created April 17, 2020 21:07
#!/usr/bin/env python3
import argparse
import time
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
def main(local_id):
    parser = argparse.ArgumentParser()
    parser.add_argument('--num-iters', type=int, default=100)
    args = parser.parse_args()

    rank = xm.get_ordinal()
    world_size = xm.xrt_world_size()
    print('hello from rank {}/{}'.format(rank, world_size))
    device = xm.xla_device()

    if rank == 0:
        print('bytes\t\treps\ttime\t\t\talgbw (GB/sec)\t\tbusbw (GB/sec)')

    for size in [4194304, 16777216, 67108864, 268435456, 1073741824]:
        # warmup (maybe not necessary)
        tensor = torch.ones([size], device=device)
        xm.all_reduce('sum', [tensor])
        xm.mark_step()

        start = time.time()
        for _ in range(args.num_iters):
            # sum-reduce the tensor in place across all TPU cores
            xm.all_reduce('sum', [tensor])
            # cut the lazy-execution graph so the collective is dispatched
            xm.mark_step()
        end = time.time()
        runtime = end - start

        if rank == 0:
            num_bytes = size * tensor.element_size()
            algbw = (num_bytes * args.num_iters / runtime) / (1024 * 1024 * 1024)
            # bus bandwidth convention for all-reduce, see
            # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
            busbw = algbw * 2.0 * (world_size - 1.0) / float(world_size)
            print('{}\t{}\t{}\t{}\t{}'.format(num_bytes, args.num_iters, runtime, algbw, busbw))

        # barrier so every core finishes this size before starting the next one
        xm.rendezvous('done_' + str(size))


if __name__ == '__main__':
    xmp.spawn(
        fn=main,
        nprocs=8,  # use all 8 TPU cores
    )
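
To run the benchmark on a TPU host with torch_xla installed, invoke the script directly; any extra flags are picked up by argparse inside each spawned process. The filename below is an assumption, since the Gist does not show one:

    python3 tpu_allreduce_bench.py --num-iters 200

The busbw column follows the NCCL all-reduce convention: bus bandwidth = algorithmic bandwidth * 2 * (n - 1) / n, so with n = 8 cores busbw = 1.75 * algbw.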