@colesbury
Created July 30, 2020 23:58
import random
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# number of rows in the embedding table
N = 200


class MyModel(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        # sparse=True means this parameter receives sparse gradients
        self.embed = nn.Embedding(N, in_ch * out_ch, sparse=True)

    def forward(self, x):
        # look up a random row of the embedding table and use it as the weight matrix
        idx = random.randint(0, self.embed.num_embeddings - 1)
        idx = torch.as_tensor(idx).to(self.embed.weight.device)
        weight = self.embed(idx)
        weight = weight.reshape(self.in_ch, self.out_ch)
        return x @ weight.t()


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_ddp(rank, world_size):
    # set up the process group for this process
    setup(rank, world_size)

    # create the model and move it to the GPU with id `rank`
    model = MyModel(100, 100).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for step in range(1000):
        print(step)
        # random inputs and targets for each step
        x = torch.randn(100, device=rank)
        labels = torch.randn(100, device=rank)

        optimizer.zero_grad()
        outputs = ddp_model(x)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    cleanup()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(demo_ddp,
             args=(world_size,),
             nprocs=world_size,
             join=True)
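
# Usage note (assumption: the script is saved as, e.g., ddp_sparse_demo.py and at
# least two CUDA devices are visible): running `python ddp_sparse_demo.py` spawns
# world_size = 2 worker processes via mp.spawn, each training its own DDP replica
# of MyModel for 1000 steps with the gloo backend.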