@hotbaby
Created January 10, 2023 05:56
PyTorch collective communication
# encoding: utf8

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def cm_broadcast_object_demo(rank: int, world_size: int):
    """Broadcast a list of picklable Python objects from rank 0 to all ranks."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    if rank == 0:
        objects = ["foo", 12, {"key": "value"}]
    else:
        objects = [None, None, None]
    dist.broadcast_object_list(objects, src=0, device=torch.device("cpu"))
    print(f"rank: {rank}, objects: {objects}")


def cm_broadcast_demo(rank: int, world_size: int):
    """Broadcast a tensor from rank 0 to all ranks."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    if dist.get_rank() == 0:
        tensor = torch.arange(10)
    else:
        tensor = torch.zeros(10, dtype=torch.int64)
    dist.broadcast(tensor, src=0)
    print(tensor)


def cm_allreduce_demo(rank: int, world_size: int):
    """Ring all-reduce: every rank ends up with the element-wise SUM."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    rank = dist.get_rank()
    tensor = torch.arange(2) + 2 * rank
    print(f"before allreduce, rank: {rank}, tensor: {tensor}")
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    print(f"after allreduce, rank: {rank}, tensor: {tensor}")


def cm_scatter_demo(rank: int, world_size: int):
    """Scatter one tensor from rank 0's list to each rank."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    scatter_list = [torch.ones(2), torch.ones(2) * 3]
    output_tensor = torch.zeros_like(scatter_list[0])
    dist.scatter(
        output_tensor,
        scatter_list if rank == 0 else None,  # scatter_list only on the source rank
        src=0,
    )
    print(f"scatter_list: {scatter_list}, rank: {rank}, output: {output_tensor}")


def cm_gather_demo(rank: int, world_size: int):
    """Gather one tensor from every rank onto rank 0."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    rank = dist.get_rank()
    tensor = torch.tensor([rank])
    print(f"before gather, rank: {rank}, tensor: {tensor}")
    output = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.gather(
        tensor,
        output if rank == 0 else None,  # gather_list must NOT be specified on non-destination ranks
        dst=0)
    if rank == 0:
        concat_output = torch.concat(output)
        print(f"after gather, rank: {rank}, output: {output}, concat output: {concat_output}")


collection_methods = [
    cm_broadcast_object_demo,
    cm_broadcast_demo,
    cm_allreduce_demo,
    cm_gather_demo,
    cm_scatter_demo,
]


def main():
    world_size = 2
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    for method in collection_methods:
        print(f"collective communication {method.__name__}")
        mp.spawn(method, (world_size,), nprocs=world_size, join=True)
        print("")


if __name__ == "__main__":
    main()
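
A quick local sanity check (not from the original gist) of what cm_allreduce_demo should print with world_size=2: rank 0 starts from tensor([0, 1]) and rank 1 from tensor([2, 3]), so after the SUM all-reduce both ranks hold tensor([2, 4]). The snippet below reproduces that arithmetic without spawning processes.

# Expected cm_allreduce_demo result for world_size=2, checked locally (no process group needed).
import torch

rank_tensors = [torch.arange(2) + 2 * r for r in range(2)]  # rank 0: [0, 1], rank 1: [2, 3]
reduced = torch.stack(rank_tensors).sum(dim=0)              # element-wise SUM across ranks
assert torch.equal(reduced, torch.tensor([2, 4]))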

hotbaby commented Jan 10, 2023

PyTorch collective communication methods (an all_gather sketch follows below):

  • all_reduce()
  • broadcast()
  • gather()
  • scatter()
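
For completeness, a minimal all_gather() sketch in the same style as the demos above (a sketch only, assuming the same gloo backend, world_size=2, and mp.spawn launcher; it is not part of the original script):

import torch
import torch.distributed as dist


def cm_allgather_demo(rank: int, world_size: int):
    """all_gather: every rank receives the tensor contributed by every rank."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    tensor = torch.tensor([rank])
    output = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(output, tensor)  # unlike gather(), all ranks get the full list
    print(f"rank: {rank}, all_gather output: {output}")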

hotbaby commented Jan 10, 2023

[Diagrams of the allreduce, broadcast, gather, and scatter operations]
