Skip to content

Instantly share code, notes, and snippets.

@rohan-varma
Created December 15, 2020 05:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rohan-varma/3906e7f07669f0177801a9f753848550 to your computer and use it in GitHub Desktop.
import copy
import os
from contextlib import nullcontext

import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
class LitLinear(LightningModule):
    """Minimal LightningModule that delegates to a single wrapped layer.

    Args:
        l1: The layer applied to every input. It is stored as ``self.l1``, which
            callers read back (e.g. ``wrapper.module.l1``). When it is an
            ``nn.Module``, this assignment registers it as a child module.
    """

    def __init__(self, l1):
        super().__init__()
        self.l1 = l1

    def forward(self, x):
        """Return ``self.l1(x)``."""
        result = self.l1(x)
        return result

    def training_step(self, batch, batch_idx):
        """Use the summed model output over ``batch`` as the loss.

        ``batch_idx`` is part of the Lightning hook signature and is unused here.
        """
        # Call through self(...) (not self.l1) so module hooks still run.
        predictions = self(batch)
        return predictions.sum()
def worker(rank, world_size=2):
    """Run the DDP ``join()`` uneven-inputs comparison on one process.

    Wraps one copy of a bias-free linear layer in Lightning's
    ``LightningDistributedDataParallel`` and an independent, identical copy in
    plain ``torch.nn.parallel.DistributedDataParallel``. It then runs the same
    uneven per-rank inputs through both inside ``join()`` and checks that the
    two wrappers end with matching parameters and gradients.

    Args:
        rank: Index of this process in the gloo process group. It is supplied
            by ``torch.multiprocessing.spawn`` as the first argument.
        world_size: Number of processes in the group. Defaults to 2, matching
            the ``nprocs=2`` used by the ``__main__`` block.

    Raises:
        AssertionError: If parameters or gradients differ between the Lightning
            wrapper and the plain-PyTorch wrapper.
    """
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print(f"Worker {rank}: done initializing PG.")
    try:
        batch, dim = 20, 10
        linear_layer = nn.Linear(dim, dim, bias=False)
        # Each wrapper must own its OWN copy of the layer. Sharing a single
        # module would make both loops accumulate into the same .grad tensors.
        # It would also make the final comparison check a module against
        # itself, which is trivially true.
        pt_layer = copy.deepcopy(linear_layer)
        lightning_ddp = pl.overrides.data_parallel.LightningDistributedDataParallel(
            LitLinear(linear_layer)
        )
        pt_ddp = torch.nn.parallel.DistributedDataParallel(pt_layer)

        # Create uneven inputs: rank r gets 10 + r batches, so with two ranks,
        # rank 1 gets one more input than rank 0. This would hang without the
        # join() API.
        inputs = [torch.rand(batch, dim).float() for _ in range(10 + rank)]

        # Lightning training loop. The extra positional argument is meant to
        # reach LitLinear.training_step as batch_idx. This assumes the wrapper
        # routes forward() to training_step in training mode; TODO: confirm
        # for the installed Lightning version.
        with lightning_ddp.join():
            for _ in range(5):
                for inp in inputs:
                    loss = lightning_ddp(inp, 1).sum()
                    loss.backward()

        # Plain PyTorch DDP training loop over the same inputs.
        with pt_ddp.join():
            for _ in range(5):
                for inp in inputs:
                    loss = pt_ddp(inp).sum()
                    loss.backward()

        # Validate equivalence. No optimizer step is taken, so the accumulated
        # gradients are what the two loops actually produce. Compare them as
        # well as the parameters.
        lightning_layer = lightning_ddp.module.l1
        pt_module = pt_ddp.module

        lightning_state_dict = lightning_layer.state_dict()
        ddp_state_dict = pt_module.state_dict()
        if lightning_state_dict.keys() != ddp_state_dict.keys():
            raise AssertionError(
                f"Worker {rank}: state_dict keys differ: "
                f"{sorted(lightning_state_dict)} vs {sorted(ddp_state_dict)}"
            )
        for key, value in lightning_state_dict.items():
            if not torch.allclose(value, ddp_state_dict[key]):
                raise AssertionError(
                    f"Worker {rank}: {key!r} differs between Lightning and PyTorch DDP"
                )

        pt_params = dict(pt_module.named_parameters())
        for name, param in lightning_layer.named_parameters():
            grad, expected_grad = param.grad, pt_params[name].grad
            if (grad is None) != (expected_grad is None):
                raise AssertionError(
                    f"Worker {rank}: gradient of {name!r} is missing in one model"
                )
            if grad is not None and not torch.allclose(grad, expected_grad):
                raise AssertionError(
                    f"Worker {rank}: gradient of {name!r} differs between "
                    "Lightning and PyTorch DDP"
                )
    finally:
        # Release the gloo process group even if the comparison fails.
        dist.destroy_process_group()
if __name__ == "__main__":
    print("--- starting join test ---")
    # worker() calls init_process_group without an init_method, so it uses the
    # env:// rendezvous, which reads these variables in each spawned process.
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "29501"}
    for env_key, env_value in rendezvous.items():
        os.environ[env_key] = env_value
    print("spawning workers....")
    # Launch two processes; spawn passes each one its index as `rank`.
    mp.spawn(worker, args=(), nprocs=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment