Skip to content

Instantly share code, notes, and snippets.

@BexTuychiev
Created July 6, 2024 10:44
Show Gist options
  • Save BexTuychiev/4fc7b5e11a51b0a1820646f236a7dffd to your computer and use it in GitHub Desktop.
Save BexTuychiev/4fc7b5e11a51b0a1820646f236a7dffd to your computer and use it in GitHub Desktop.
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
# Define the CNN architecture
class CIFAR10CNN(L.LightningModule):
    """Small convolutional classifier for 3-channel images such as CIFAR-10.

    Three ``conv3x3 (padding=1) -> ReLU -> 2x2 max-pool`` stages map the input
    to 64 channels on a 4x4 grid; the flattened 64*4*4 features pass through a
    512-unit hidden layer and a linear head producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits (10 for CIFAR-10).
        lr: Adam learning rate.
        lr_patience: ``ReduceLROnPlateau`` patience, in epochs without
            ``val_loss`` improvement, before the learning rate is multiplied
            by 0.1. Keep it smaller than any ``EarlyStopping`` patience on
            ``val_loss``; otherwise training is stopped before the learning
            rate is ever reduced.
    """

    def __init__(self, num_classes=10, lr=1e-3, lr_patience=2):
        super().__init__()
        # Records the constructor arguments in ``self.hparams`` and in saved
        # checkpoints, so ``load_from_checkpoint`` can rebuild the same model.
        self.save_hyperparameters()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected as ``(batch, 3, H, W)`` with H and W such that three
        2x poolings leave a 4x4 map (32x32 for CIFAR-10), matching ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample with the batch dimension kept explicit, so a
        # wrongly-sized input raises a shape error in fc1 instead of being
        # folded into a different batch size; flatten also handles
        # non-contiguous (e.g. channels_last) tensors, unlike view.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def _shared_step(self, batch, stage):
        """Compute and log cross-entropy loss and accuracy for one batch.

        Args:
            batch: ``(inputs, targets)`` pair; targets are class indices.
            stage: Metric prefix, e.g. ``"train"`` logs ``train_loss`` and
                ``train_acc``.

        Returns:
            The scalar loss tensor.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # Explicit batch size so epoch-level averages weight a smaller final
        # batch correctly without Lightning having to infer it from the batch.
        batch_size = y.size(0)
        self.log(f"{stage}_loss", loss, batch_size=batch_size)
        self.log(f"{stage}_acc", acc, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as train_loss/train_acc)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log val_loss/val_acc for ``batch``."""
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log test_loss/test_acc for ``batch``."""
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau schedule driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=self.hparams.lr_patience
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
# Data transformations.
# Evaluation pipeline: convert each image to a tensor, then normalize every
# channel with mean 0.5 and standard deviation 0.5.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# Training pipeline: a random 32x32 crop of the 4-pixel-padded image and a
# random horizontal flip, followed by exactly the evaluation steps above.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        *transform_test.transforms,
    ]
)
# Load CIFAR-10 and build the train / validation / test loaders.
def _build_dataloaders(
    root="./data", batch_size=64, num_workers=4, val_size=5000, seed=0
):
    """Return ``(train_loader, val_loader, test_loader)`` for CIFAR-10.

    The official training split is divided once, with a fixed seed, into a
    training subset (augmented with ``transform_train``) and ``val_size``
    validation images (``transform_test``, no augmentation). The official test
    split is used only for the final evaluation, so checkpoint selection,
    early stopping and LR scheduling never see it.

    Raises:
        ValueError: If ``val_size`` does not leave both subsets non-empty.
    """
    train_full = datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train
    )
    # Same images with the evaluation transform, so validation samples are
    # not randomly cropped or flipped.
    train_full_eval = datasets.CIFAR10(
        root=root, train=True, download=False, transform=transform_test
    )
    test_dataset = datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test
    )

    if not 0 < val_size < len(train_full):
        raise ValueError(
            f"val_size must be in (0, {len(train_full)}), got {val_size}"
        )

    # Seeded permutation: the train/validation split is identical across runs.
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(train_full), generator=generator).tolist()
    train_dataset = Subset(train_full, order[val_size:])
    val_dataset = Subset(train_full_eval, order[:val_size])

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Reuse worker processes across epochs instead of re-spawning them;
        # only valid when workers exist.
        persistent_workers=num_workers > 0,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader


def main():
    """Train ``CIFAR10CNN`` and evaluate its best checkpoint on the test split."""
    train_loader, val_loader, test_loader = _build_dataloaders()

    model = CIFAR10CNN()

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        filename="cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",
        save_top_k=3,
        mode="min",
    )
    # Patience must exceed the model's ReduceLROnPlateau patience on the same
    # metric; otherwise training stops before the learning rate is lowered.
    early_stopping = EarlyStopping(
        monitor="val_loss", patience=10, mode="min", verbose=False
    )

    logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")

    trainer = L.Trainer(
        max_epochs=50,
        callbacks=[checkpoint_callback, early_stopping],
        logger=logger,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices="auto",
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    # Evaluate the best checkpoint (lowest val_loss), not the final-epoch
    # weights, on data that played no part in training or model selection.
    trainer.test(model, dataloaders=test_loader, ckpt_path="best")


# Required with num_workers > 0: under the "spawn" start method (default on
# Windows and macOS) each DataLoader worker re-imports this module, and
# unguarded top-level code would re-run the whole script in every worker.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment