Skip to content

Instantly share code, notes, and snippets.

@ResidentMario
Last active August 11, 2021 12:09
Show Gist options
  • Save ResidentMario/dc542fc26a142a9dce85b258835c45ad to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
def run(rank, size):
    """Demonstrate torch.distributed collective ops from one worker process.

    Walks through broadcast, all_reduce, reduce, scatter, gather, and
    all_gather, printing each rank's copy of the data before and after.

    Args:
        rank: this process's rank within the world.
        size: total world size (unused in the body; kept for the spawn
            interface used by init_process).
    """
    # collective ops are performed against groups
    group = dist.new_group([0, 1, 2])
    # For demo purposes, tensor reflects process rank
    tensor = torch.tensor([rank])

    # broadcast example
    #
    # broadcast synchronizes one process's copy of a tensor with all peer copies of that
    # tensor.
    #
    # broadcast automatically barriers, e.g. it's guaranteed it will not return until all peer
    # copies have received the update.
    print(f"Value of rank {rank} copy of tensor before broadcast: {tensor}.\n")
    dist.broadcast(tensor, 0, group)
    print(f"Value of rank {rank} copy of tensor after broadcast: {tensor}.\n")

    # all_reduce example
    #
    # all_reduce performs a global reduce op on tensor: every rank ends up
    # with the reduced value.
    tensor = torch.tensor([rank])
    print(f"Value of rank {rank} copy of tensor before all_reduce: {tensor}.\n")
    dist.all_reduce(tensor, dist.ReduceOp.SUM, group)
    # 0 + 1 + 2 = 3
    print(f"Value of rank {rank} copy of tensor after all_reduce: {tensor}.\n")

    # reduce example
    #
    # reduce performs a local reduce op on a tensor, e.g., one which is only reflected in
    # the process running the reduce. Which process sees the result is controlled by the
    # addtl dest field.
    # NOTE(review): only the dest rank's result is defined — other ranks'
    # tensors may hold intermediate values afterwards (visible in the sample
    # output, where rank 1 also prints 3).
    tensor = torch.tensor([rank])
    dest = 0
    print(f"Value of rank {rank} copy of tensor before reduce: {tensor}.\n")
    dist.reduce(tensor, dest, dist.ReduceOp.SUM, group)
    print(f"Value of rank {rank} copy of tensor after reduce: {tensor}.\n")

    # scatter example
    #
    # scatter distributes a list of values amongst the processes.
    #
    # some gotchas here are that scatter_list must be defined but None on all non-source
    # ranks. that's not mentioned in the docs or docstring ffs!
    tensor = torch.tensor([rank])
    if rank == 0:
        scatter_list = [torch.tensor([4]), torch.tensor([5]), torch.tensor([6])]
    else:
        scatter_list = None
    print(f"Value of rank {rank} copy of tensor before scatter: {tensor}.\n")
    dist.scatter(tensor, scatter_list, 0, group)  # 0 is the src, e.g. the seeder
    print(f"Value of rank {rank} copy of tensor after scatter: {tensor}.\n")

    # gather example
    #
    # similar rules to scatter: gather_list must be provided on the
    # destination rank only, and None everywhere else.
    tensor = torch.tensor([rank])
    if rank == 0:
        gather_list = [torch.tensor([9]), torch.tensor([9]), torch.tensor([9])]
    else:
        gather_list = None
    print(f"Value of rank {rank} copy of gather_list before gather: {gather_list}.\n")
    dist.gather(tensor, gather_list, 0, group)  # 0 is the dst, e.g. the collector
    print(f"Value of rank {rank} copy of gather_list after gather: {gather_list}.\n")

    # last one, all-gather
    #
    # all_gather has no src/dst: every rank supplies its tensor and every
    # rank's gather_list receives the full set of values.
    tensor = torch.tensor([rank])
    gather_list = [torch.tensor([9]), torch.tensor([9]), torch.tensor([9])]
    print(f"Value of rank {rank} copy of gather_list before all_gather: {gather_list}.\n")
    dist.all_gather(gather_list, tensor, group)
    print(f"Value of rank {rank} copy of gather_list after all_gather: {gather_list}.\n")
    print(f"Done with rank {rank} process!\n")
def init_process(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment, then run fn(rank, size).

    Args:
        rank: this process's rank within the world.
        size: total world size.
        fn: callable invoked as fn(rank, size) once the process group exists.
        backend: torch.distributed backend name (default 'gloo', which works
            on CPU-only machines).
    """
    # All ranks rendezvous at rank 0's address/port to form the group.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
if __name__ == "__main__":
    # Spawn one worker process per rank, then wait for all of them to finish.
    size = 3
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
Value of rank 0 copy of tensor before broadcast: tensor([0]).
Value of rank 1 copy of tensor before broadcast: tensor([1]).
Value of rank 2 copy of tensor before broadcast: tensor([2]).
Value of rank 1 copy of tensor after broadcast: tensor([0]).
Value of rank 0 copy of tensor after broadcast: tensor([0]).
Value of rank 1 copy of tensor before all_reduce: tensor([1]).
Value of rank 0 copy of tensor before all_reduce: tensor([0]).
Value of rank 2 copy of tensor after broadcast: tensor([0]).
Value of rank 2 copy of tensor before all_reduce: tensor([2]).
Value of rank 2 copy of tensor after all_reduce: tensor([3]).
Value of rank 0 copy of tensor after all_reduce: tensor([3]).
Value of rank 0 copy of tensor before reduce: tensor([0]).
Value of rank 2 copy of tensor before reduce: tensor([2]).
Value of rank 1 copy of tensor after all_reduce: tensor([3]).
Value of rank 1 copy of tensor before reduce: tensor([1]).
Value of rank 2 copy of tensor after reduce: tensor([2]).
Value of rank 0 copy of tensor after reduce: tensor([3]).
Value of rank 2 copy of tensor before scatter: tensor([2]).
Value of rank 0 copy of tensor before scatter: tensor([0]).
Value of rank 1 copy of tensor after reduce: tensor([3]).
Value of rank 1 copy of tensor before scatter: tensor([1]).
Value of rank 2 copy of tensor after scatter: tensor([6]).
Value of rank 2 copy of gather_list before gather: None.
Value of rank 0 copy of tensor after scatter: tensor([4]).
Value of rank 0 copy of gather_list before gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 1 copy of tensor after scatter: tensor([5]).
Value of rank 2 copy of gather_list after gather: None.
Value of rank 1 copy of gather_list before gather: None.
Value of rank 1 copy of gather_list after gather: None.
Value of rank 1 copy of gather_list before all_gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 2 copy of gather_list before all_gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 0 copy of gather_list after gather: [tensor([0]), tensor([1]), tensor([2])].
Value of rank 0 copy of gather_list before all_gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 0 copy of gather_list after all_gather: [tensor([0]), tensor([1]), tensor([2])].
Done with rank 0 process!
Value of rank 1 copy of gather_list after all_gather: [tensor([0]), tensor([1]), tensor([2])].
Done with rank 1 process!
Value of rank 2 copy of gather_list after all_gather: [tensor([0]), tensor([1]), tensor([2])].
Done with rank 2 process!
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment