@rohan-varma
Created October 13, 2020 16:50
import torch
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch.nn as nn
import contextlib


class enc(nn.Module):
    """Toy encoder: a single bias-free linear layer."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.emb(x)


class dec(nn.Module):
    """Toy decoder: a single bias-free linear layer whose weight is replaced
    by the encoder's weight in worker() below."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        return self.emb(x)


def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    e = enc().cuda(rank)
    d = dec().cuda(rank)
    # Share parameters: the decoder reuses the encoder's weight tensor, so the
    # same nn.Parameter belongs to both modules (and nn.Linear's forward uses
    # this shared 10x10 weight, not the 1x1 one dec originally declared).
    d.emb.weight = e.emb.weight
    # Wrap each module in its own DDP instance; the shared weight now appears
    # in both wrapped modules' parameter lists.
    e = torch.nn.parallel.DistributedDataParallel(e, device_ids=[rank])
    d = torch.nn.parallel.DistributedDataParallel(d, device_ids=[rank])
    inp = torch.randn(1, 10, device=rank)
    for _ in range(6):
        encoded = e(inp)
        decoded = d(encoded)
        loss = decoded.sum()
        loss.backward()
    torch.cuda.synchronize(device=rank)


if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    mp.spawn(worker, nprocs=2, args=())
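A minimal sketch, not part of the original gist, of one way to check after the loop that the shared weight is still identical on every rank. The helper name check_weights_in_sync is hypothetical; it would be called from worker() after torch.cuda.synchronize, passing either DDP-wrapped module.

def check_weights_in_sync(ddp_module, world_size=2):
    # Gather the (detached, CUDA-resident) shared weight from every rank
    # and compare each copy against rank 0's.
    weight = ddp_module.module.emb.weight.detach()
    gathered = [torch.empty_like(weight) for _ in range(world_size)]
    dist.all_gather(gathered, weight)
    for i, w in enumerate(gathered):
        assert torch.allclose(w, gathered[0]), f"weight on rank {i} diverged"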