Skip to content

Instantly share code, notes, and snippets.

Created October 7, 2020 05:01
Show Gist options
  • Save nrupatunga/f9b4d3ad557b79cd48353c715b62ef62 to your computer and use it in GitHub Desktop.
Save nrupatunga/f9b4d3ad557b79cd48353c715b62ef62 to your computer and use it in GitHub Desktop.
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class LitModel(pl.LightningModule):
    """MLP classifier whose input layer is sized from the data dimensions.

    Args:
        channels: Number of image channels.
        width: Image width in pixels.
        height: Image height in pixels.
        num_classes: Number of output classes.
        hidden_size: Width of the two hidden layers.
        learning_rate: Learning rate passed to Adam.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise `self.model = nn.Sequential(...)` raises AttributeError.
        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            # The first Linear expects channels*width*height features. Presumably
            # batches arrive as (B, C, H, W) images (TODO confirm against the
            # data source); Flatten reshapes them to (B, C*H*W) and leaves
            # already-flat (B, features) input unchanged.
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            # Without a non-linearity, stacked Linear layers reduce to a
            # single linear map, so the hidden layers would add nothing.
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        log_probs = self(x)
        # forward() already applies log_softmax, so nll_loss (rather than
        # cross_entropy) is the matching loss here.
        loss = F.nll_loss(log_probs, y)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
class MNISTDataModule(pl.LightningDataModule):
    """Download MNIST, split it into train/val, and serve the training loader.

    Args:
        data_dir: Directory MNIST is downloaded to and loaded from.
        batch_size: Batch size used by ``train_dataloader``.
    """

    def __init__(self, data_dir: str = './', batch_size: int = 32):
        # Let LightningDataModule set up its own internal state first.
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # torchvision's MNIST yields PIL images; Normalize operates on tensors,
        # so ToTensor must come first. The mean/std constants are presumably
        # the usual MNIST training-set statistics.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ('fit', 'test' or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training DataLoader and print the current epoch.

        With ``reload_dataloaders_every_epoch=True`` on the Trainer this hook
        is re-invoked every epoch, so an epoch-dependent loader could be
        returned by branching on ``epoch``.
        """
        # The epoch counter lives on the Trainer; a DataModule has no
        # `current_epoch` attribute of its own.
        epoch = self.trainer.current_epoch
        print(f'Current epoch: {epoch}')
        return DataLoader(self.mnist_train, batch_size=self.batch_size)
# Init DataModule
dm = MNISTDataModule()

# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)

# Init trainer. reload_dataloaders_every_epoch=True makes Lightning call
# train_dataloader() again at every epoch.
trainer = pl.Trainer(
    max_epochs=5,
    progress_bar_refresh_rate=20,
    # Fall back to CPU instead of failing at startup when no GPU is present.
    gpus=1 if torch.cuda.is_available() else None,
    reload_dataloaders_every_epoch=True,
)

# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment