PyTorch Distributed Training (DDP) Demo
# encoding: utf8
import os
import time
import random
import contextlib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def transform(tensors: list):
    # Collate function: each TensorDataset sample is a 1-tuple, so unpack and stack into one batch tensor.
    return torch.stack([t[0] for t in tensors])


class Model(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 100)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 20)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)


def ddp_demo(rank, world_size, accum_grad=4):
    assert dist.is_gloo_available(), "Gloo is not available!"
    print(f"world_size: {world_size}, rank: {rank}, is_gloo_available: {dist.is_gloo_available()}")

    # 1. Initialize the process group.
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    # model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 20))
    model = Model()

    # 2. Wrap the model with DistributedDataParallel.
    ddp_model = DistributedDataParallel(model)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)

    dataset = TensorDataset(torch.randn(1000, 10))
    # 3. Distributed data loading (the sampler internally selects this rank's shard).
    sampler = DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset=dataset, batch_size=24, sampler=sampler, collate_fn=transform)

    for epoch in range(1):
        # For multi-epoch training, sampler.set_epoch(epoch) would be called here to reshuffle each epoch.
        for step, batch in enumerate(dataloader):
            output = ddp_model(batch)
            label = torch.rand_like(output)

            if step % accum_grad == 0:
                # Synchronize gradients during backward on this step.
                context = contextlib.nullcontext
            else:
                # 4. Gradient accumulation: skip gradient synchronization.
                context = ddp_model.no_sync

            with context():
                time.sleep(random.random())
                loss = criterion(output, label)
                loss.backward()

            if step % accum_grad == 0:
                optimizer.step()
                optimizer.zero_grad()
                print(f"epoch: {epoch}, step: {step}, rank: {rank} update parameters.")

    # 5. Destroy the process group (cleans up some global state).
    dist.destroy_process_group()


def main():
    world_size = 8
    # mp.spawn passes the process index (rank) as the first argument to ddp_demo.
    mp.spawn(ddp_demo, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    main()
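
A side note on launching: the gist starts workers with mp.spawn and hard-codes MASTER_ADDR/MASTER_PORT. An alternative is torchrun, which exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT itself, so init_process_group can read them from the environment. A minimal sketch, not part of the original gist (the file name and the reuse of ddp_demo's body are assumptions):

# torchrun_variant.py  (hypothetical file name)
# Launch with: torchrun --nproc_per_node=8 torchrun_variant.py
import torch.distributed as dist

def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so the default env:// init method needs no explicit arguments.
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # ... build the model, DistributedDataParallel wrapper, DistributedSampler
    #     and training loop exactly as in ddp_demo above ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()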
hotbaby commented Jan 6, 2023

def transform(tensors: list):
    return torch.stack([t[0] for t in tensors])

# Build the dataset and dataloader
dataset = TensorDataset(torch.randn(1000, 10))
# Sample according to the process rank
sampler = DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
dataloader = DataLoader(dataset=dataset, batch_size=24, sampler=sampler, collate_fn=transform)

The collate_fn argument is used to merge individual samples into a mini-batch.
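
For intuition about what collate_fn receives here: each TensorDataset sample is a 1-tuple containing one tensor, and transform unpacks and stacks those tuples into a single batch tensor. A small illustrative check (shapes follow the gist's 1000×10 dataset):

import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.randn(1000, 10))
samples = [dataset[i] for i in range(24)]      # each sample is a tuple: (tensor of shape (10,),)
batch = torch.stack([t[0] for t in samples])   # same logic as transform()
print(batch.shape)                             # torch.Size([24, 10])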

hotbaby commented Jan 6, 2023

Steps to convert a model to distributed training (a minimal skeleton follows the list):

  1. Initialize the process group: dist.init_process_group
  2. Wrap the model for distributed data parallelism: DistributedDataParallel(model)
  3. Distribute the data: split the dataset into world_size shards and sample by rank with DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
  4. During training, accumulate gradients to reduce how often gradients are synchronized between training processes and improve communication efficiency [optional]
  5. Destroy the process group: dist.destroy_process_group()
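
A minimal skeleton of these five steps, with the demo-specific details stripped out (the function and variable names below are placeholders, not part of the gist):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(rank, world_size, dataset, model):
    dist.init_process_group("gloo", world_size=world_size, rank=rank)  # 1. initialize the process group
    ddp_model = DistributedDataParallel(model)                         # 2. wrap the model
    sampler = DistributedSampler(dataset, num_replicas=world_size)     # 3. shard the data by rank
    loader = DataLoader(dataset, batch_size=24, sampler=sampler)
    for batch in loader:
        ...                                                            # 4. forward/backward, optionally inside ddp_model.no_sync()
    dist.destroy_process_group()                                       # 5. destroy the process group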
