@phuocphn
Created February 29, 2020 12:37
distributed_relu_pytorch.py
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

pid = os.getpid()

# Seed every RNG so both worker processes start from the same state.
seed = 1000
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'

    # Initialize the process group.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Explicitly set the seed so that the models created in the two processes
    # start from the same random weights and biases.
    torch.manual_seed(42)


def cleanup():
    dist.destroy_process_group()

class CustomRELU(nn.Module):
    """ReLU whose activations are summed (all-reduced) across both ranks."""

    def __init__(self, inplace=False):
        super(CustomRELU, self).__init__()
        self.inplace = inplace

    def forward(self, input):
        result = F.relu(input, inplace=self.inplace)
        # Sum the activations across ranks 0 and 1; after the call every rank
        # holds the element-wise sum of all ranks' ReLU outputs.
        group = dist.new_group([0, 1])
        dist.all_reduce(result, op=dist.ReduceOp.SUM, group=group)
        return result

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu1 = CustomRELU()
        self.net2 = nn.Linear(10, 5)
        self.relu2 = CustomRELU()
        self.net3 = nn.Linear(5, 5)

    def forward(self, x):
        x = self.net1(x)
        # Keep a copy that goes through a plain ReLU so the all-reduced
        # activations can be printed and compared against the local ones.
        __bak_x = nn.ReLU()(x)
        x = self.relu1(x)
        print("[{}] RELU1: {}".format(pid, str(x[0])))
        x = self.net2(x)
        __x2 = self.net2(__bak_x)
        x = self.relu2(__x2)
        print("[{}] RELU2: {}".format(pid, str(x[0])))
        x = self.net3(x)
        return x

def demo_basic(rank, world_size):
    setup(rank, world_size)

    # Set up the devices for this process: with 2 processes and 8 GPUs,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    print("*" * 10)
    print("[{}] rank: {}, world_size: {}".format(pid, rank, world_size))
    print("[{}] device_ids: {}".format(pid, device_ids))
    print("*" * 10)

    # Create the model and move it to device_ids[0].
    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0].
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.ones(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    run_demo(demo_basic, 2)
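
Assuming a machine with at least two visible CUDA devices (so each of the two spawned ranks gets a non-empty device_ids list) and a PyTorch build with distributed support, the demo can be launched directly; mp.spawn starts both worker processes:

python distributed_relu_pytorch.py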