@RameshKamath
Created April 26, 2021 06:37
Distributed Data Parallel for Training a Torch Model in Parallel
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    # Every process rendezvouses at the same address/port before joining the group.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(1, 1)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(1, 1)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)

    # Each process owns one GPU (device index == rank); DDP keeps the replicas in sync.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for i in range(100):
        input = torch.randn(2, 1).to(rank)
        outputs = ddp_model(input)
        print("iter:{}, rank:{}, data:{}".format(i, rank, input))
        # Targets must match the output shape (2, 1); a (1, 1) target would only broadcast with a warning.
        labels = torch.randn(2, 1).to(rank)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()   # gradients are all-reduced across processes here
        optimizer.step()

    print("Done.")
    cleanup()


def run_demo(demo_fn, world_size):
    # Launch one worker process per GPU; each worker receives its rank as the first argument.
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    print("Starting.")
    n_gpus = torch.cuda.device_count()
    run_demo(demo_basic, n_gpus)
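Because the script initializes the process group with the nccl backend and moves the model to a CUDA device per rank, it needs a machine with at least one GPU. For a quick smoke test on a CPU-only machine, a minimal sketch (my assumption, not part of the original gist; setup_cpu and demo_basic_cpu are hypothetical names) is to switch to the gloo backend and leave the model on the CPU, reusing ToyModel, cleanup, and run_demo from above:

# Hypothetical CPU-only variant for smoke testing (not in the original gist).
def setup_cpu(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # gloo works on CPU-only machines, unlike nccl
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def demo_basic_cpu(rank, world_size):
    setup_cpu(rank, world_size)
    ddp_model = DDP(ToyModel())   # no device_ids: the model stays on the CPU
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    for i in range(10):
        outputs = ddp_model(torch.randn(2, 1))
        loss = loss_fn(outputs, torch.randn(2, 1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    cleanup()

if __name__ == "__main__":
    run_demo(demo_basic_cpu, world_size=2)   # e.g. two CPU processes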