# Source: GitHub gist by @richardliaw, created April 5, 2023
# (gist: richardliaw/a8f5250597a6642f271f31e642896119).
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
class MNISTClassifier(pl.LightningModule):
    """Two-layer fully connected classifier for 28x28 MNIST digits.

    Args:
        lr: Learning rate used by the Adam optimizer.
        feature_dim: Width of the hidden layer between the flattened
            28 * 28 input and the 10 output logits.
    """

    def __init__(self, lr, feature_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        # torchmetrics >= 0.11 requires an explicit task; the model has 10 classes.
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        # Per-batch validation results, reduced and cleared at epoch end.
        self.eval_loss = []
        self.eval_accuracy = []

    def forward(self, x):
        """Flatten ``x`` to (N, 784) and return raw (unnormalized) logits of shape (N, 10)."""
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy training loss for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute loss/accuracy for one validation batch and buffer them."""
        x, y = val_batch
        logits = self(x)
        # forward() returns raw logits, so use cross_entropy (log_softmax +
        # nll_loss). nll_loss on raw logits yields a meaningless value and
        # disagrees with the training loss.
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def on_validation_epoch_end(self):
        """Log the epoch-mean validation loss/accuracy and reset the buffers.

        Replaces ``validation_epoch_end(outputs)``, which Lightning 2.0 removed.
        """
        if not self.eval_loss:
            # torch.stack rejects an empty list; nothing was validated.
            return
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        # sync_dist reduces the epoch-level values across training workers.
        self.log("ptl/val_loss", avg_loss, sync_dist=True)
        self.log("ptl/val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
# Prepare MNIST datasets: convert images to tensors and normalize them with
# the mean/std constants below.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
mnist_train = MNIST('./data', train=True, download=True, transform=transform)
mnist_val = MNIST('./data', train=False, download=True, transform=transform)

# Take small subsets for a quick smoke test.
# Remove these two lines to train and validate on the full datasets.
mnist_train = Subset(mnist_train, range(1000))
# Bug fix: this previously re-subset mnist_train, leaving the validation
# set at full size and shrinking the training set to 500 samples.
mnist_val = Subset(mnist_val, range(500))

train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=128, shuffle=False)
# Build the Lightning config: model class + constructor kwargs, pl.Trainer
# settings, and the dataloaders handed to fit().
lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
    .build()
)

# Four CPU-only workers, each reserving one CPU.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
)
result = trainer.fit()
# A bare `result` expression only displays inside a notebook; print it so the
# outcome is visible when this file is run as a script.
print(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment