pytorch_lightning_ddp_gradient_checkpointing_bug
# This script reproduces a pytorch_lightning issue
# where gradient checkpointing + DDP results in NaN loss.
#
# * Run with gpus=1 and it works fine.
# * Run with gpus=4 and the loss becomes NaN quickly.
#
# See also https://forums.pytorchlightning.ai/t/gradient-checkpointing-ddp-nan/398
import torch
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
class MergeLayer(torch.nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.layer = torch.nn.Linear(in_size, out_size)

    def apply_forward(self, xs):
        # Concatenate the inputs along the feature dimension and project.
        y = torch.cat(xs, dim=1)
        return self.layer(y)

    def _apply_forward_splat(self, *xs):
        # torch.utils.checkpoint.checkpoint passes the tensors positionally,
        # so re-pack them into the sequence that apply_forward expects.
        return self.apply_forward(xs)

    def forward(self, xs):
        # Only checkpoint when at least one input requires grad; otherwise
        # checkpointing has nothing to recompute during backward.
        if any(x.requires_grad for x in xs):
            return torch.utils.checkpoint.checkpoint(self._apply_forward_splat, *xs)
        return self.apply_forward(xs)
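# A minimal single-process sanity check (not part of the original repro; the
# helper name below is my own). On a single device the checkpointed path is
# expected to produce the same gradients as the plain path, which matches the
# observation that gpus=1 trains fine.
def _check_merge_layer_gradients():
    layer = MergeLayer(64, 32)
    a = torch.randn(8, 32, requires_grad=True)
    b = torch.randn(8, 32, requires_grad=True)

    # Checkpointed path: the inputs require grad, so forward() routes through checkpoint.
    layer([a, b]).sum().backward()
    grad_ckpt = layer.layer.weight.grad.clone()

    # Plain path: call apply_forward directly, bypassing checkpointing.
    layer.zero_grad()
    layer.apply_forward([a, b]).sum().backward()
    grad_plain = layer.layer.weight.grad.clone()

    assert torch.allclose(grad_ckpt, grad_plain)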
class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(32, 32)
        # The merge layer sees cat([x, x1]): 64 features in, 32 out.
        self.merge = MergeLayer(64, 32)

    def forward(self, x):
        x1 = F.leaky_relu(self.layer1(x))
        # x2 = F.leaky_relu(self.layer2(x))
        xs = [x, x1, ]
        return self.merge(xs)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = torch.nn.functional.mse_loss(output, batch)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return optimizer
num_samples = 10000
train = RandomDataset(32, num_samples)
train = DataLoader(train, batch_size=32)
model = BoringModel()

# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=10,
    progress_bar_refresh_rate=20,
    accelerator="ddp",
    gpus=4,  # NaN loss
    # gpus=1,  # works
)

# Train the model ⚡
trainer.fit(model, train)
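# Hedged workaround sketch, not part of the original report: on newer PyTorch
# releases (>= 1.11) the non-reentrant checkpoint variant is generally better
# behaved under DDP than the default reentrant one. The use_reentrant=False
# keyword is an assumption about the installed PyTorch version (older releases
# do not accept it), and this subclass is only a sketch, not a confirmed fix.
class NonReentrantMergeLayer(MergeLayer):
    def forward(self, xs):
        if any(x.requires_grad for x in xs):
            return torch.utils.checkpoint.checkpoint(
                self._apply_forward_splat, *xs, use_reentrant=False
            )
        return self.apply_forward(xs)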