Skip to content

Instantly share code, notes, and snippets.

@nilansaha
Created November 29, 2020 07:14
Show Gist options
  • Save nilansaha/016242a2724ccab4d21c3d70dc6f825a to your computer and use it in GitHub Desktop.
Save nilansaha/016242a2724ccab4d21c3d70dc6f825a to your computer and use it in GitHub Desktop.
PyTorch Lightning Boilerplate
import torch
import random
import torch.nn as nn
from torch.optim import Adam
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
class SampleDataset(Dataset):
    """Synthetic regression data for ``y = coef_1 * n1 + coef_2 * n2``.

    Each item is a pair ``(feature, output)``:

    * ``feature`` is the 1-D tensor ``[n1, n2]``, where ``n1`` and ``n2`` are
      integers drawn uniformly from ``[low, high]`` (both ends inclusive).
    * ``output`` is the 1-element tensor ``[n1 * coef_1 + n2 * coef_2]``.

    Both tensors use torch's current default floating dtype. This matches what
    the legacy ``torch.Tensor(...)`` constructor produced.

    Args:
        size: Number of items reported by ``len()``.
        coef_1: Coefficient applied to ``n1`` when building the target.
        coef_2: Coefficient applied to ``n2`` when building the target.
        low: Smallest value drawn for ``n1``/``n2`` (default 1).
        high: Largest value drawn for ``n1``/``n2`` (default 1000).
        seed: If ``None`` (the default), values come from the global
            ``random`` module. In that mode ``idx`` is ignored and every
            access yields a fresh sample, which was the original behaviour.
            If set, item ``idx`` is generated by an RNG seeded from
            ``(seed, idx)``, so repeated accesses return identical data.

    Raises:
        ValueError: If ``size`` is negative or ``low > high``.
    """

    def __init__(self, size, coef_1, coef_2, low=1, high=1000, seed=None):
        # Fail at construction time rather than later, inside len()/randint.
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        if low > high:
            raise ValueError(f"low ({low}) must not exceed high ({high})")
        self.size = size
        self.coef_1 = coef_1
        self.coef_2 = coef_2
        self.low = low
        self.high = high
        self.seed = seed

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # A str seed is hashed deterministically by random.Random (no hash
        # randomization involved), so each (seed, idx) pair always maps to
        # the same sample.
        if self.seed is None:
            rng = random
        else:
            rng = random.Random(f"{self.seed}:{idx}")
        n1 = rng.randint(self.low, self.high)
        n2 = rng.randint(self.low, self.high)
        # torch.tensor on a list of Python ints would infer int64. Pin the
        # default float dtype to keep the legacy torch.Tensor(...) semantics.
        dtype = torch.get_default_dtype()
        feature = torch.tensor([n1, n2], dtype=dtype)
        output = torch.tensor([n1 * self.coef_1 + n2 * self.coef_2], dtype=dtype)
        return feature, output
class PLNet(pl.LightningModule):
    """Single linear layer trained to fit ``y = coef_1 * x1 + coef_2 * x2``.

    Training data comes from ``SampleDataset``. Every constructor argument
    defaults to the value the original boilerplate hard-coded, so ``PLNet()``
    trains the same model as before.

    Args:
        lr: Adam learning rate. ``1e-3`` is also Adam's own default, so the
            default optimizer is unchanged.
        batch_size: Mini-batch size for the training DataLoader.
        num_workers: Worker processes for the training DataLoader.
        dataset_size: Number of samples per training epoch.
        coef_1: First coefficient passed to ``SampleDataset``.
        coef_2: Second coefficient passed to ``SampleDataset``.
    """

    def __init__(self, lr=1e-3, batch_size=32, num_workers=4,
                 dataset_size=10000, coef_1=3, coef_2=4):
        super().__init__()
        # Maps a 2-feature input to a single regression output.
        self.fc1 = nn.Linear(2, 1)
        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_size = dataset_size
        self.coef_1 = coef_1
        self.coef_2 = coef_2

    def forward(self, x):
        """Return ``fc1(x)``: one output per row of a ``(..., 2)`` input."""
        return self.fc1(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the mean-squared error for one training batch."""
        features, target = batch
        predicted = self(features)
        # Functional form: no need to build an nn.MSELoss module every step.
        loss = nn.functional.mse_loss(predicted, target)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Use Adam over all model parameters with the configured lr."""
        # NOTE(review): inputs span roughly low..high (1..1000 by default) and
        # are not normalized, so convergence at lr=1e-3 may be slow. Consider
        # scaling the features or raising lr.
        return Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Build the training DataLoader over a fresh ``SampleDataset``."""
        dataset = SampleDataset(
            size=self.dataset_size, coef_1=self.coef_1, coef_2=self.coef_2
        )
        return DataLoader(
            dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )
if __name__ == "__main__":
    # The guard is required because PLNet's DataLoader uses worker processes
    # (num_workers=4 by default). With the "spawn" start method (Windows, and
    # macOS on Python 3.8+), each worker re-imports the main module. Without
    # the guard, every worker would try to start training again and fail
    # during process bootstrap.
    model = PLNet()
    # NOTE: `progress_bar_refresh_rate` is part of the PyTorch Lightning 1.x
    # Trainer API this script targets. Later releases deprecated and then
    # removed it.
    trainer = pl.Trainer(max_epochs=30, progress_bar_refresh_rate=50)
    trainer.fit(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment