@hotbaby
Created August 14, 2023 06:27
torch_nccl_test.py
# encoding: utf8
import logging

import torch
import torch.distributed
from torch.distributed import ReduceOp


def print_rank_0(msg, *args, **kwargs):
    """Log only from rank 0 to avoid duplicated output."""
    if torch.distributed.get_rank() == 0:
        logging.info(msg, *args, **kwargs)


def dist_allgather():
    print_rank_0("all_gather:")
    torch.distributed.barrier()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Each rank contributes its own rank id; every rank receives all of them.
    input_tensor = torch.tensor([rank])
    tensor_list = [torch.zeros_like(input_tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensor_list, input_tensor)
    logging.info(f"all_gather, rank: {rank}, input_tensor: {repr(input_tensor)}, output tensor_list: {tensor_list}")
    torch.distributed.barrier()


def dist_allreduce():
    print_rank_0("all_reduce:")
    torch.distributed.barrier()

    rank = torch.distributed.get_rank()

    # Default op is SUM: every rank ends up with the sum of all rank ids.
    tensor = torch.tensor(rank)
    input_tensor = tensor.clone()
    torch.distributed.all_reduce(tensor)
    logging.info(f"all_reduce, rank: {rank}, before allreduce tensor: {repr(input_tensor)}, after allreduce tensor: {repr(tensor)}")
    torch.distributed.barrier()


def dist_reducescatter():
    print_rank_0("reduce_scatter:")
    torch.distributed.barrier()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Each rank provides world_size chunks; chunk i is reduced (SUM) onto rank i.
    input_list = [torch.tensor([rank]) for _ in range(world_size)]
    output = torch.zeros_like(input_list[0])
    torch.distributed.reduce_scatter(output, input_list, op=ReduceOp.SUM)
    logging.info(f"reduce_scatter, rank: {rank}, input_list: {input_list}, output: {repr(output)}")
    torch.distributed.barrier()


def dist_broadcast():
    print_rank_0("broadcast:")
    torch.distributed.barrier()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Rank `src_rank` sends its tensor to every other rank.
    src_rank = 2
    tensor = torch.tensor([world_size]) if rank == src_rank else torch.zeros(1, dtype=torch.int64)
    before_tensor = tensor.clone()
    torch.distributed.broadcast(tensor, src=src_rank)
    logging.info(f"broadcast, rank: {rank}, before broadcast tensor: {repr(before_tensor)}, after broadcast tensor: {repr(tensor)}")
    torch.distributed.barrier()


def main():
    torch.distributed.init_process_group("nccl")

    rank = torch.distributed.get_rank()
    local_rank = rank % torch.cuda.device_count()
    # Bind this process to one GPU and make it the default device for new tensors.
    torch.cuda.set_device(local_rank)
    torch.set_default_device(f"cuda:{local_rank}")

    dist_reducescatter()
    dist_allreduce()
    dist_allgather()
    dist_broadcast()


if __name__ == "__main__":
    logging.basicConfig(format=logging.BASIC_FORMAT, level=logging.INFO)
    main()
hotbaby commented Aug 14, 2023

Collectives demonstrated: AllReduce, ReduceScatter, AllGather, Broadcast.
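For a single-node, 4-GPU run (world_size = 4), the logged values follow directly from the script: all_reduce sums the rank ids 0 + 1 + 2 + 3 = 6 on every rank; all_gather returns [0, 1, 2, 3] on every rank; reduce_scatter of the per-rank constant inputs also yields 6 on every rank; broadcast copies world_size = 4 from rank 2 to all ranks.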

hotbaby commented Aug 14, 2023

How to run the NCCL test program?

deepspeed --num_gpus 4 --num_nodes 1 torch_nccl_test.py
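
As a sketch of an alternative launch (not in the original gist): the script only relies on the default env:// initialization in init_process_group, so on a single node with 4 GPUs it should also work with PyTorch's own launcher:

torchrun --nproc_per_node=4 torch_nccl_test.py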
