@williamFalcon · Created July 29, 2019 01:45
PTL template: a minimal PyTorch Lightning (ptl) LightningModule for MNIST, with training, validation, and test hooks.
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as ptl


class CoolModel(ptl.LightningModule):

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model... a single linear layer over the flattened image
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten (B, 1, 28, 28) -> (B, 784), then project to 10 class logits
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': self.my_loss(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}

    def validation_end(self, outputs):
        # average the per-batch losses collected by validation_step
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @ptl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()),
                          batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        # note: this template reuses the MNIST training split for validation
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()),
                          batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        # use the held-out test split rather than the training split
        return DataLoader(MNIST(os.getcwd(), train=False, download=True,
                                transform=transforms.ToTensor()),
                          batch_size=32)
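
For completeness, a minimal sketch of how a module like this is driven. It assumes the same 2019-era pytorch_lightning API as the template above; the max_nb_epochs argument is from that era (later releases renamed it max_epochs), and the single-epoch setting is only an illustrative choice:

    from pytorch_lightning import Trainer

    model = CoolModel()

    # most Trainer flags are left at their defaults; max_nb_epochs caps
    # training at one pass over the data for a quick smoke test
    trainer = Trainer(max_nb_epochs=1)
    trainer.fit(model)

The Trainer owns the training loop, so the LightningModule only has to define the steps and dataloaders shown above.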