#!/usr/bin/env python
import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Process


def run(rank, size):
    """Demo the collective ops in torch.distributed."""
    # Collective ops are performed against groups.
    group = dist.new_group([0, 1, 2])

    # For demo purposes, the tensor reflects the process rank.
    tensor = torch.tensor([rank])

    # broadcast example
    #
    # broadcast synchronizes one process's copy of a tensor with all peer copies of that
    # tensor.
    #
    # broadcast automatically barriers, i.e. it is guaranteed not to return until all peer
    # copies have received the update.
    print(f"Value of rank {rank} copy of tensor before broadcast: {tensor}.\n")
    dist.broadcast(tensor, 0, group)
    print(f"Value of rank {rank} copy of tensor after broadcast: {tensor}.\n")

    # all_reduce example
    #
    # all_reduce performs a global reduce op on the tensor: every rank sees the reduced
    # result.
    tensor = torch.tensor([rank])
    print(f"Value of rank {rank} copy of tensor before all_reduce: {tensor}.\n")
    dist.all_reduce(tensor, dist.ReduceOp.SUM, group)
    # 0 + 1 + 2 = 3
    print(f"Value of rank {rank} copy of tensor after all_reduce: {tensor}.\n")

    # reduce example
    #
    # reduce performs a reduce op whose result is only guaranteed on one process. Which
    # process sees the result is controlled by the additional dst argument; the other
    # ranks may be left holding partial sums, as the output below shows.
    tensor = torch.tensor([rank])
    dest = 0
    print(f"Value of rank {rank} copy of tensor before reduce: {tensor}.\n")
    dist.reduce(tensor, dest, dist.ReduceOp.SUM, group)
    print(f"Value of rank {rank} copy of tensor after reduce: {tensor}.\n")

    # scatter example
    #
    # scatter distributes a list of values amongst the processes.
    #
    # One gotcha here is that scatter_list must be defined, but set to None, on all
    # non-source ranks. That is not mentioned in the docs or the docstring!
    tensor = torch.tensor([rank])
    if rank == 0:
        scatter_list = [torch.tensor([4]), torch.tensor([5]), torch.tensor([6])]
    else:
        scatter_list = None
    print(f"Value of rank {rank} copy of tensor before scatter: {tensor}.\n")
    dist.scatter(tensor, scatter_list, 0, group)  # 0 is the src, i.e. the seeder
    print(f"Value of rank {rank} copy of tensor after scatter: {tensor}.\n")

    # gather example
    #
    # gather collects every rank's tensor into a list on a single rank; similar rules to
    # scatter apply (gather_list must be None on all non-destination ranks).
    tensor = torch.tensor([rank])
    if rank == 0:
        gather_list = [torch.tensor([9]), torch.tensor([9]), torch.tensor([9])]
    else:
        gather_list = None
    print(f"Value of rank {rank} copy of gather_list before gather: {gather_list}.\n")
    dist.gather(tensor, gather_list, 0, group)  # 0 is the dst, i.e. the collector
    print(f"Value of rank {rank} copy of gather_list after gather: {gather_list}.\n")

    # last one, all_gather
    #
    # all_gather is gather whose result is visible on every rank, so gather_list must be
    # defined on every rank.
    tensor = torch.tensor([rank])
    gather_list = [torch.tensor([9]), torch.tensor([9]), torch.tensor([9])]
    print(f"Value of rank {rank} copy of gather_list before all_gather: {gather_list}.\n")
    dist.all_gather(gather_list, tensor, group)
    print(f"Value of rank {rank} copy of gather_list after all_gather: {gather_list}.\n")

    print(f"Done with rank {rank} process!\n")


def init_process(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 3
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
Output (the prints from the three ranks interleave):
Value of rank 0 copy of tensor before broadcast: tensor([0]).
Value of rank 1 copy of tensor before broadcast: tensor([1]).
Value of rank 2 copy of tensor before broadcast: tensor([2]).
Value of rank 1 copy of tensor after broadcast: tensor([0]).
Value of rank 0 copy of tensor after broadcast: tensor([0]).
Value of rank 1 copy of tensor before all_reduce: tensor([1]).
Value of rank 0 copy of tensor before all_reduce: tensor([0]).
Value of rank 2 copy of tensor after broadcast: tensor([0]).
Value of rank 2 copy of tensor before all_reduce: tensor([2]).
Value of rank 2 copy of tensor after all_reduce: tensor([3]).
Value of rank 0 copy of tensor after all_reduce: tensor([3]).
Value of rank 0 copy of tensor before reduce: tensor([0]).
Value of rank 2 copy of tensor before reduce: tensor([2]).
Value of rank 1 copy of tensor after all_reduce: tensor([3]).
Value of rank 1 copy of tensor before reduce: tensor([1]).
Value of rank 2 copy of tensor after reduce: tensor([2]).
Value of rank 0 copy of tensor after reduce: tensor([3]).
Value of rank 2 copy of tensor before scatter: tensor([2]).
Value of rank 0 copy of tensor before scatter: tensor([0]).
Value of rank 1 copy of tensor after reduce: tensor([3]).
Value of rank 1 copy of tensor before scatter: tensor([1]).
Value of rank 2 copy of tensor after scatter: tensor([6]).
Value of rank 2 copy of gather_list before gather: None.
Value of rank 0 copy of tensor after scatter: tensor([4]).
Value of rank 0 copy of gather_list before gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 1 copy of tensor after scatter: tensor([5]).
Value of rank 2 copy of gather_list after gather: None.
Value of rank 1 copy of gather_list before gather: None.
Value of rank 1 copy of gather_list after gather: None.
Value of rank 1 copy of gather_list before all_gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 2 copy of gather_list before all_gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 0 copy of gather_list after gather: [tensor([0]), tensor([1]), tensor([2])].
Value of rank 0 copy of gather_list before all_gather: [tensor([9]), tensor([9]), tensor([9])].
Value of rank 0 copy of gather_list after all_gather: [tensor([0]), tensor([1]), tensor([2])].
Done with rank 0 process!
Value of rank 1 copy of gather_list after all_gather: [tensor([0]), tensor([1]), tensor([2])].
Done with rank 1 process!
Value of rank 2 copy of gather_list after all_gather: [tensor([0]), tensor([1]), tensor([2])].
Done with rank 2 process!
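The script above launches its workers by constructing multiprocessing.Process objects by hand. A minimal alternative sketch, assuming the same init_process and run defined above: swap the __main__ block for a torch.multiprocessing.spawn call, which passes the worker index as the first argument to the target and raises in the parent if any worker fails.

# Hypothetical variant of the __main__ block above, using the spawn helper instead of
# managing Process objects manually. init_process and run are the functions defined above.
import torch.multiprocessing as mp

if __name__ == "__main__":
    size = 3
    # spawn calls init_process(i, size, run) for each worker index i in [0, size),
    # and join=True blocks until every worker has exited.
    mp.spawn(init_process, args=(size, run), nprocs=size, join=True)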