@rohan-varma
Created September 1, 2020 19:57
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, rank):
        super().__init__()
        self.size = 10 + 3 * rank
        self.idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx == self.size:
            raise StopIteration
        ret = torch.randn(1)
        self.idx += 1
        return ret


def worker(i):
    dist.init_process_group("gloo", rank=i, world_size=2)
    torch.cuda.set_device(i)
    dataset = MyIterableDataset(rank=i)
    model = nn.Linear(1, 1, bias=False).to(i)
    model = DDP(model, device_ids=[i], output_device=i)
    # join() lets ranks that exhaust their data early shadow the collective
    # calls of the ranks still iterating, so uneven inputs don't deadlock.
    with model.join():
        for batch in dataset:
            loss = model(batch).sum()
            loss.backward()
    # Without support for uneven inputs, the loop above would hang
    # due to a mismatched number of allreduces across ranks.
    torch.cuda.synchronize(device=i)


if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    mp.spawn(worker, nprocs=2, args=())
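
For reference, a minimal sketch (not part of the original gist) of why the inputs are uneven: each rank's MyIterableDataset yields 10 + 3 * rank batches, so rank 0 stops after 10 iterations while rank 1 runs for 13, and model.join() is what keeps the shorter rank participating in the longer rank's collectives.

# Standalone sanity check (assumes MyIterableDataset from the script above):
# count the batches each rank would iterate over.
for rank in range(2):
    dataset = MyIterableDataset(rank=rank)
    print(f"rank {rank}: {sum(1 for _ in dataset)} batches")
# rank 0: 10 batches
# rank 1: 13 batches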