@ruotianluo
Created July 12, 2020 20:50
import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group (gloo keeps this toy example simple;
    # nccl is the usual backend for multi-GPU training)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        # net1 is built without a bias, then handed the module-level bias
        # Parameter, so the same tensor is registered under two names
        # (self.bias and self.net1.bias)
        self.net1 = nn.Linear(10, 5, bias=False)
        self.bias = nn.Parameter(torch.zeros(5))
        self.net1.bias = self.bias
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        # net1 (and the shared bias) is never used in the forward pass,
        # which is why DDP is constructed with find_unused_parameters=True
        return self.net2(x)

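# Added sketch (not part of the original gist): a quick check that the model
# really shares one Parameter under two names. parameters() deduplicates
# shared tensors, so the model exposes 4 unique parameters
# (net1.weight, the shared bias, net2.weight, net2.bias).
def _check_shared_bias():
    m = ToyModel()
    assert m.net1.bias is m.bias
    assert len(list(m.parameters())) == 4
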
def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for i in range(10):
        # zero the gradients every iteration so they do not accumulate
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    cleanup()

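# Optional sketch (an addition, not part of the original gist): verify that
# DDP kept the shared bias identical across ranks after training. It must be
# called collectively by all ranks, e.g. from demo_basic() just before
# cleanup(); tensors are moved to CPU to match the gloo backend.
def check_bias_sync(ddp_model, rank, world_size):
    bias = ddp_model.module.bias.detach().cpu()
    gathered = [torch.zeros_like(bias) for _ in range(world_size)]
    dist.all_gather(gathered, bias)
    if rank == 0:
        in_sync = all(torch.allclose(gathered[0], b) for b in gathered)
        print(f"shared bias identical across {world_size} ranks: {in_sync}")
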
def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, "This demo expects at least 2 GPUs."
    run_demo(demo_basic, 2)
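# To run this sketch (assumptions: the file is saved as e.g. ddp_demo.py, the
# machine has at least 2 CUDA-capable GPUs, and port 12355 is free):
#   python ddp_demo.py
# mp.spawn launches one worker process per rank, each printing the
# "Running basic DDP example on rank N." line from demo_basic.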