https://blog.csdn.net/weixin_39718268/article/details/105021631
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py [--arg1 --arg2 ...]
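Each process spawned by torch.distributed.launch is handed a --local_rank command-line argument (newer PyTorch versions set the LOCAL_RANK environment variable instead when --use_env is passed). A minimal sketch for reading it inside main.py from the command above:

import argparse
import os

parser = argparse.ArgumentParser()
# torch.distributed.launch appends --local_rank to each spawned process;
# fall back to the LOCAL_RANK env var for launchers that set it instead
parser.add_argument("--local_rank", type=int,
                    default=int(os.environ.get("LOCAL_RANK", 0)))
args = parser.parse_args()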
# 1) Initialize the default process group (NCCL backend for GPU training).
# torch.distributed.launch already sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE,
# so the default env:// initialization needs no extra arguments.
torch.distributed.init_process_group(backend="nccl")
# 2) Bind each process to its own GPU.
# Note: get_rank() returns the *global* rank, which coincides with the local
# rank only on a single machine; for multi-node training use the --local_rank
# argument passed in by torch.distributed.launch instead.
local_rank = torch.distributed.get_rank()
print('local_rank: {}'.format(local_rank))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
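The numbering jumps from 2) to 4); the missing step 3) is presumably wrapping the dataset in a DistributedSampler, which is also where the train_sampler used in step 6 below comes from. A minimal sketch, assuming a train_dataset object and a batch size of 32 (both placeholders, not from the original):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# 3) Shard the dataset so each process sees a disjoint subset
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          sampler=train_sampler)  # the sampler replaces shuffle=True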
# 4) Move the model onto its GPU before wrapping it
net = net.to(device)
# Note: the nn.DataParallel wrapping from the single-process version is not
# needed here and will raise errors; do not mix it with DistributedDataParallel:
# if torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#     net = nn.DataParallel(net)  # just this one line must go
# 5) Wrap the model
model = torch.nn.parallel.DistributedDataParallel(net,
                                                  device_ids=[local_rank],
                                                  output_device=local_rank)
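After wrapping, forward and backward passes go through model, while the raw network stays reachable as model.module. One common pattern (a sketch, not from the original; "checkpoint.pth" is a placeholder path) is to checkpoint only on rank 0 so all processes do not write the same file:

if torch.distributed.get_rank() == 0:
    # model.module is the unwrapped net; only one process saves it
    torch.save(model.module.state_dict(), "checkpoint.pth")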
# 6) Reshuffle the data each epoch: set_epoch() reseeds the
# DistributedSampler so every epoch sees a different ordering
for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)
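The original stops at the sampler call. Filling in a complete per-epoch body, a minimal sketch, with criterion and optimizer as assumed names and train_loader from the step-3 sketch above:

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)  # new shuffle order each epoch
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()  # DDP all-reduces gradients across processes here
        optimizer.step()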