PyTorch Lightning CLI is not picking up docstring argument descriptions for its help text
#!/bin/env python3
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CyclicLR
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.cli import LightningCLI
from typing import Optional
import pdb
class MnistDataModule(LightningDataModule):
    """Lightning data module that serves the torchvision MNIST train/test splits.

    NOTE(review): this class was restored from a paste that had lost its
    indentation and several tokens (docstring quotes, ``super().__init__()``,
    the ``DataLoader`` calls). Restored pieces are marked below -- confirm them
    against the original script.
    """

    def __init__(self, batch_size: int = 64, test_batch_size: int = 1000):
        """Remember the batch sizes used by the two data loaders.

        Args:
            batch_size: training batch size
            test_batch_size: testing batch size
        """
        # NOTE(review): the parameter descriptions above were bare text in the
        # paste; they are wrapped in a Google-style ``Args:`` section so that a
        # docstring parser can associate each description with its parameter.
        super().__init__()
        self._batch_size = batch_size
        self._test_batch_size = test_batch_size

    def prepare_data(self):
        """Download MNIST into ``data`` and check that the test split loads."""
        datasets.MNIST('data', train=True, download=True)
        datasets.MNIST('data', train=False)

    def setup(self, stage: Optional[TrainerFn] = None):
        """Create the datasets required for ``stage``.

        Args:
            stage: trainer stage being set up; ``None`` prepares both the
                training and the test dataset.
        """
        print('SETUP', stage)
        transform = transforms.Compose([
            # NOTE(review): ToTensor was not in the paste; it is restored
            # because Normalize operates on tensors, not on PIL images.
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # The unused per-stage ``kwargs`` dict from the paste was dropped: the
        # loader options are assembled in the *_dataloader methods instead.
        if stage is None or stage == TrainerFn.FITTING:
            self._train_dataset = datasets.MNIST(
                'data', train=True, download=True, transform=transform)
        if stage is None or stage == TrainerFn.TESTING:
            self._test_dataset = datasets.MNIST(
                'data', train=False, transform=transform)

    def _make_loader(self, dataset, batch_size: int, shuffle: bool):
        """Wrap ``dataset`` in a DataLoader with the options shared by both splits."""
        # NOTE(review): the loader construction was truncated in the paste
        # ("=, **kwargs)"); torch.utils.data.DataLoader is the assumed call.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=1,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset built in ``setup``."""
        return self._make_loader(self._train_dataset, self._batch_size, shuffle=True)

    def test_dataloader(self):
        """Return an unshuffled loader over the test dataset built in ``setup``."""
        # The paste shuffled the test loader, contradicting ``setup`` (which set
        # shuffle to False for testing); evaluation order is kept fixed here.
        return self._make_loader(self._test_dataset, self._test_batch_size, shuffle=False)
class MnistModule(LightningModule):
    """Two-convolution, two-linear-layer MNIST classifier.

    NOTE(review): this class was restored from a paste that had lost its
    indentation and several tokens; restored pieces are marked below --
    confirm them against the original script.
    """

    def __init__(self, lr: float):
        """Test MNIST model

        Args:
            lr: min learning rate
        """
        super().__init__()
        # NOTE(review): not visible in the paste; added because cli_main prints
        # ``cli.model.hparams`` -- confirm the original called this.
        self.save_hyperparameters()
        # Zeros instead of the uninitialised ``torch.Tensor(3, 1, 28, 28)`` so
        # the example input does not contain arbitrary memory contents.
        self.example_input_array = torch.zeros(3, 1, 28, 28)
        self.lr = lr
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: a 28x28 input shrinks to 24x24 after
        # the two 3x3 convolutions and to 12x12 after the 2x2 max-pool.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one training batch."""
        x, y = batch
        return F.nll_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        """Log per-batch metrics and return the raw counts for epoch aggregation.

        Returns:
            dict with ``'n'`` (batch size), ``'l'`` (summed loss) and ``'c'``
            (number of correct predictions).
        """
        x, y = batch
        num_samples = y.size(0)
        output = self(x)
        loss = F.nll_loss(output, y, reduction='sum')
        pred = output.argmax(dim=1, keepdim=False)  # index of the max log-probability
        correct = pred.eq(y).sum()
        # NOTE(review): the log_dict call was left unclosed in the paste; any
        # extra keyword arguments it carried are unknown, so defaults are used.
        self.log_dict({'val_loss_sum1': loss,
                       'val_loss1': loss / num_samples,
                       'val_accuracy1': correct / num_samples})
        return {'n': num_samples, 'l': loss, 'c': correct}

    def validation_epoch_end(self, outputs) -> None:
        """Log sample-weighted loss and accuracy over all batches of the epoch."""
        if not outputs:
            # Nothing was evaluated; avoid dividing by a zero sample count.
            return
        total_samples = sum(out['n'] for out in outputs)
        total_loss = sum(out['l'] for out in outputs)
        total_correct = sum(out['c'] for out in outputs)
        self.log_dict({'val_loss2': total_loss / total_samples,
                       'val_accuracy2': total_correct / total_samples,
                       'val_n': float(total_samples)})

    def test_step(self, batch, batch_idx):
        """Evaluate a test batch exactly like a validation batch."""
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        """Aggregate test results exactly like validation results."""
        return self.validation_epoch_end(outputs)

    def configure_optimizers(self):
        """Return SGD with momentum 0.9 and a cyclic learning-rate schedule."""
        # NOTE(review): the lr-related arguments were stripped from the paste
        # ("self.parameters(),, momentum=0.9" and "optimizer,, * 10"). They are
        # restored as lr=self.lr, base_lr=self.lr and max_lr=self.lr * 10,
        # matching the "min learning rate" wording of the docstring -- confirm.
        optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0.9)
        scheduler = CyclicLR(optimizer, base_lr=self.lr, max_lr=self.lr * 10)
        return [optimizer], [scheduler]
def cli_main():
    """Build the model, data module and trainer from the CLI, then fit and test.

    NOTE(review): restored from a paste in which the ``LightningCLI(...)`` call
    was left unclosed and the line after the ``print`` was truncated; the
    restored pieces are marked below -- confirm against the original script.
    """
    cli = LightningCLI(
        # NOTE(review): the class arguments were missing from the paste; the
        # two classes defined in this file are the assumed targets.
        MnistModule,
        MnistDataModule,
        run=False,  # Disable automatic fitting.
        trainer_defaults={'max_epochs': 3, 'accelerator': 'gpu', 'devices': 1},
    )
    print('HPARAMS', cli.model.hparams)
    # NOTE(review): only ", datamodule=cli.datamodule)" survived of this call;
    # a fit() preceding the test() below is the assumed original.
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
if __name__ == '__main__':
    # NOTE(review): the guard body held only a commented-out ``main()`` call,
    # which is a syntax error; calling cli_main() is the assumed intent.
    cli_main()
