pytorch_lightning_ddp_gradient_checkpointing_bug
# This reproduces a pytorch_lightning issue
# where gradient checkpointing + ddp results in nan loss
#
# * Run with gpus=1 and it works fine.
# * Run with gpus=4 and the loss becomes nan quickly.
#
# See also https://forums.pytorchlightning.ai/t/gradient-checkpointing-ddp-nan/398
import torch
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class MergeLayer(torch.nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.layer = torch.nn.Linear(in_size, out_size)

    def apply_forward(self, xs):
        # Concatenate the inputs along the feature dimension and project.
        y = torch.cat(xs, dim=1)
        return self.layer(y)

    def _apply_forward_splat(self, *xs):
        # checkpoint() passes the inputs as positional arguments, so collect
        # them back into a tuple for apply_forward.
        return self.apply_forward(xs)

    def forward(self, xs):
        # Only use gradient checkpointing when at least one input requires
        # grad; checkpoint() warns and saves nothing when none of them do.
        requires_grad = False
        for x in xs:
            if x.requires_grad:
                requires_grad = True
        if requires_grad:
            return torch.utils.checkpoint.checkpoint(self._apply_forward_splat, *xs)
        else:
            return self.apply_forward(xs)

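# Optional sanity check for MergeLayer (illustrative only, not part of the bug
# repro and never called by it; the function name is made up here): with an
# input that has requires_grad=True, forward() routes through
# torch.utils.checkpoint.checkpoint, which recomputes apply_forward during the
# backward pass in a single, non-DDP process.
def check_merge_layer():
    merge = MergeLayer(64, 32)
    a = torch.randn(4, 32)                      # does not require grad
    b = torch.randn(4, 32, requires_grad=True)  # forces the checkpointed path
    out = merge([a, b])                         # output shape (4, 32)
    out.sum().backward()                        # forward is recomputed here
    assert b.grad is not None and b.grad.shape == (4, 32)
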
class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(32, 32)
        self.merge = MergeLayer(64, 32)

    def forward(self, x):
        x1 = F.leaky_relu(self.layer1(x))
        # x2 = F.leaky_relu(self.layer2(x))
        xs = [x, x1]
        return self.merge(xs)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = torch.nn.functional.mse_loss(output, batch)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return optimizer

num_samples = 10000
train = RandomDataset(32, num_samples)
train = DataLoader(train, batch_size=32)
model = BoringModel()
# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=10,
    progress_bar_refresh_rate=20,
    accelerator="ddp",
    gpus=4,  # nan loss
    # gpus=1,  # works
)
# Train the model ⚡
trainer.fit(model, train)
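
# To reproduce, launch this file as a plain script on a machine with 4 GPUs
# (the filename below is just an example). With gpus=4 the training loss
# becomes nan quickly; with gpus=1 it decreases normally:
#
#   python pytorch_lightning_ddp_gradient_checkpointing_bug.py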