Created
December 15, 2020 05:02
-
-
Save rohan-varma/3906e7f07669f0177801a9f753848550 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.distributed as dist | |
import os | |
import torch.multiprocessing as mp | |
import torch.nn as nn | |
import pytorch_lightning as pl | |
from pytorch_lightning.core.lightning import LightningModule | |
from contextlib import nullcontext | |
class LitLinear(LightningModule):
    """Minimal LightningModule that delegates its forward pass to one wrapped layer."""

    def __init__(self, l1):
        """Keep a reference to the wrapped layer ``l1``."""
        super().__init__()
        self.l1 = l1

    def forward(self, x):
        """Return the wrapped layer applied to ``x``."""
        output = self.l1(x)
        return output

    def training_step(self, batch, batch_idx):
        """Return the summed module output for ``batch``; ``batch_idx`` is unused."""
        # Go through ``self(...)`` rather than ``self.forward(...)`` so the
        # regular module call path is used, exactly as in the original.
        prediction = self(batch)
        return prediction.sum()
def worker(rank):
    """Run the uneven-input ``join()`` test on one rank of a 2-process gloo group.

    The same linear layer is wrapped twice: once in Lightning's
    ``LightningDistributedDataParallel`` (via ``LitLinear``) and once in
    ``torch.nn.parallel.DistributedDataParallel``. Each wrapper is driven
    through five passes over a per-rank input list inside its ``join()``
    context, and the resulting state dicts are then compared.

    Args:
        rank: Rank of this process; forwarded to ``init_process_group`` and
            used to give each rank a different number of inputs.

    Raises:
        AssertionError: If any state-dict entry differs between the two
            wrapped modules.
    """
    dist.init_process_group("gloo", rank=rank, world_size=2)
    try:
        print(f"Worker {rank}: done initializing PG.")
        batch, dim = 20, 10
        linear_layer = nn.Linear(dim, dim, bias=False)
        lightning_ddp = pl.overrides.data_parallel.LightningDistributedDataParallel(
            LitLinear(linear_layer)
        )
        pt_ddp = torch.nn.parallel.DistributedDataParallel(linear_layer)
        # Create uneven inputs, rank 1 will get one more input than rank 0.
        # This will cause a hang without join() API.
        inputs = [torch.rand(batch, dim).float() for _ in range(10 + rank)]

        # Lightning training loop
        with lightning_ddp.join():
            for _ in range(5):
                for inp in inputs:
                    loss = lightning_ddp.forward(inp, 1).sum()
                    loss.backward()

        # PT DDP train loop
        with pt_ddp.join():
            for _ in range(5):
                for inp in inputs:
                    loss = pt_ddp.forward(inp).sum()
                    loss.backward()

        # Validate equivalence
        # NOTE(review): both wrappers were built around the same `linear_layer`
        # object, so these two state dicts appear to describe the same
        # parameters and this check cannot detect divergence. Confirm whether
        # independent copies of the layer were intended.
        lightning_state_dict = lightning_ddp.module.l1.state_dict()
        ddp_state_dict = pt_ddp.module.state_dict()
        mismatched = [
            key
            for key in lightning_state_dict
            if not torch.allclose(ddp_state_dict[key], lightning_state_dict[key])
        ]
        # Raise explicitly instead of using a bare `assert`, which is removed
        # when Python runs with -O and would make the check silently pass.
        if mismatched:
            raise AssertionError(
                f"Worker {rank}: state dicts differ for keys {mismatched}"
            )
    finally:
        # Release the process group even when the test body fails.
        dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: set the rendezvous address/port in the environment,
    # then launch two worker processes that each run `worker`.
    print("--- starting join test ---")
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29501")
    print("spawning workers....")
    mp.spawn(worker, args=(), nprocs=2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment