Skip to content

Instantly share code, notes, and snippets.

@alecmerdler
Created April 5, 2023 11:06
Show Gist options
  • Save alecmerdler/32fdb95d2a1d803248245482d0d7d386 to your computer and use it in GitHub Desktop.
Save alecmerdler/32fdb95d2a1d803248245482d0d7d386 to your computer and use it in GitHub Desktop.
Lightning Autoencoder
import os
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision as tv
import torch.nn.functional as F
import lightning as L
# Command-line options. Only the Adam learning rate is exposed; the parsed
# value is read through the module-level `args` inside
# LitAutoEncoder.configure_optimizers.
parser = ArgumentParser()
parser.add_argument('--lr', type=float, default=0.02)
args = parser.parse_args()
# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (i.e., an LLM, diffusion model, autoencoder, or simple image classifier).
class LitAutoEncoder(L.LightningModule):
    """Autoencoder that compresses 28x28 single-channel images to 3 numbers.

    The encoder maps a flattened 784-value image to a 3-dimensional
    embedding and the decoder maps that embedding back to 784 values.
    Training and validation both use the mean-squared reconstruction error.
    The learning rate comes from the module-level ``args.lr`` CLI option.
    """

    def __init__(self):
        super().__init__()
        # 784 -> 128 -> 3 bottleneck.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )
        # Mirror of the encoder: 3 -> 128 -> 784.
        self.decoder = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
        )

    def forward(self, x):
        """Return the 3-D embedding of ``x`` (inference path; no decoding).

        Unlike the training/validation steps, no reshape happens here, so the
        last dimension of ``x`` must already be 28 * 28 = 784 features.
        """
        return self.encoder(x)

    def _reconstruction_loss(self, batch):
        """Flatten the images in ``batch`` and return their reconstruction MSE."""
        # Labels are ignored: an autoencoder's target is its own input.
        x, _ = batch
        # Flatten everything after the batch dimension, e.g. (B, 1, 28, 28) -> (B, 784).
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Compute, log ("train_loss") and return the loss for one training batch."""
        # training_step defines the train loop. It is independent of forward.
        loss = self._reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ("val_loss") the reconstruction loss for one validation batch.

        Without this hook, Lightning skips the validation loop even when a
        validation DataLoader is passed to ``Trainer.fit``.
        """
        loss = self._reconstruction_loss(batch)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``args.lr``."""
        print('lr: ', args.lr)
        optimizer = torch.optim.Adam(self.parameters(), lr=args.lr)
        return optimizer
# -------------------
# Step 2: Define data
# -------------------
# MNIST images converted to tensors by ToTensor(); download=True fetches the
# files into the current working directory if they are not already there.
dataset = tv.datasets.MNIST(os.getcwd(), download=True, transform=tv.transforms.ToTensor())
# Hold out 5000 images for validation. A fixed-seed generator makes the split
# identical on every run, so validation losses from runs with different
# --lr values are measured on the same images.
train, val = data.random_split(
    dataset, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
# -------------------
# Step 3: Train
# -------------------
# Build the model and trainer, then wrap both splits in DataLoaders.
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
# NOTE(review): DataLoader defaults apply here (batch_size=1, no shuffling);
# consider a larger batch size and shuffle=True for the training loader.
train_loader = data.DataLoader(train)
val_loader = data.DataLoader(val)
trainer.fit(autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment