import torch
import torch.distributed as dist
import os
import torch.multiprocessing as mp
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from contextlib import nullcontext

class LitLinear(LightningModule):
    def __init__(self, l1):
        super().__init__()
        self.l1 = l1

    def forward(self, x):
        return self.l1(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()
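
# Each spawned process runs this worker: it initializes a 2-rank gloo process
# group, wraps the same nn.Linear in both Lightning's and PyTorch's DDP
# wrappers, and trains on deliberately uneven per-rank input counts to
# exercise the join() API.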
def worker(rank):
    dist.init_process_group("gloo", rank=rank, world_size=2)
    print(f"Worker {rank}: done initializing PG.")
    batch, dim = 20, 10
    linear_layer = nn.Linear(dim, dim, bias=False)
    lightning_ddp = pl.overrides.data_parallel.LightningDistributedDataParallel(LitLinear(linear_layer))
    pt_ddp = torch.nn.parallel.DistributedDataParallel(linear_layer)
    # Create uneven inputs: rank 1 gets one more input than rank 0, which would cause a hang without the join() API.
    inputs = [torch.rand(batch, dim).float() for _ in range(10 + rank)]
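    # join() lets the rank that exhausts its inputs first shadow the collective
    # operations of the rank that is still training, so the uneven input counts
    # above do not deadlock the gradient allreduce.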
    # Lightning training loop
    with lightning_ddp.join():
        for i in range(5):
            for inp in inputs:
                loss = lightning_ddp.forward(inp, 1).sum()
                loss.backward()
    # PT DDP train loop
    with pt_ddp.join():
        for i in range(5):
            for inp in inputs:
                loss = pt_ddp.forward(inp).sum()
                loss.backward()
    # Validate equivalence
    lightning_state_dict = lightning_ddp.module.l1.state_dict()
    ddp_state_dict = pt_ddp.module.state_dict()
    assert all(torch.allclose(ddp_state_dict[k], lightning_state_dict[k]) for k in lightning_state_dict.keys())

if __name__ == '__main__':
    print("--- starting join test ---")
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    print("spawning workers....")
    mp.spawn(worker, nprocs=2, args=())
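# Run this file directly: mp.spawn launches both ranks on localhost:29501 with
# the gloo backend, so no GPUs are needed. Without the join() context managers,
# the mismatched input counts would cause the allreduce to hang.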