Minimal example for a PyTorch Lightning bug report
import argparse
from pathlib import Path
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TestTubeLogger
def init_torch(seed: int) -> None:
    """ Initialise torch with a seed """
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(240, 128)
        self.layer_2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        # Flatten the (10, 3, 8) sample to a 240-dim vector per batch element
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Log a scripted val_loss, one fixed value per epoch (indexed by
        # current_epoch), instead of computing a loss from the batch. This
        # makes the checkpointing / early-stopping behaviour deterministic.
        losses = [40, 20, 30, 10, 1, 0.9, 1, 1, 90, 100]
        loss = torch.tensor(float(losses[self.current_epoch])).to(batch[0].device)
        logs = {"val_loss": loss}
        self.log_dict(logs)
class LitDataset(Dataset):
    """ Synthetic dataset of random tensors with a very large reported length """

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        return torch.rand(10, 3, 8), torch.rand(10)

    def __len__(self) -> int:
        return 50000000
class LitDataModule(pl.LightningDataModule):

    def __init__(self, dataset: Dataset, batch_size: int):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)
def main() -> None:
    """ Run the main experiment """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-s", "--seed", type=int, default=1234,
                        help="seed for initializing pytorch")
    # required=True so the script fails fast instead of crashing on Path(None)
    parser.add_argument("--out_dir", type=str, required=True,
                        help="root output directory")
    # Add Trainer args
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # Initialize pytorch
    init_torch(args.seed)
    seed_everything(args.seed)

    # Initialize data module
    dataset = LitDataset()
    dm = LitDataModule(dataset, batch_size=16)

    # Create experiment and fit
    outroot = Path(args.out_dir)
    logger = TestTubeLogger(save_dir=str(outroot / "logs"))
    ckpt_filepath = "{}/{{epoch}}-{{val_loss:.2f}}".format(
        str(outroot / "logs" / "checkpoints"))
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_filepath, save_top_k=1, monitor="val_loss",
        verbose=True
    )
    early_stop = EarlyStopping(monitor="val_loss", verbose=True)
    trainer = Trainer.from_argparse_args(
        args, logger=logger, checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stop
    )
    model = LitModel()
    trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    main()
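
One possible invocation, assuming the script is saved as minimal_example.py (the filename is arbitrary). Flags other than --seed and --out_dir are injected by Trainer.add_argparse_args, so the exact set depends on the installed Lightning version; --max_epochs and the --limit_*_batches flags are used here to keep the deliberately huge synthetic dataset tractable:

    python minimal_example.py --out_dir ./output --max_epochs 10 --limit_train_batches 100 --limit_val_batches 10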