Created
July 6, 2024 10:44
-
-
Save BexTuychiev/4fc7b5e11a51b0a1820646f236a7dffd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import lightning as L | |
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms | |
from lightning.pytorch.callbacks import ModelCheckpoint | |
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger | |
from lightning.pytorch.callbacks.early_stopping import EarlyStopping | |
# Define the CNN architecture | |
class CIFAR10CNN(L.LightningModule):
    """Small convolutional classifier for 3-channel 32x32 images with 10 classes.

    The network is three ``conv -> ReLU -> 2x2 max-pool`` stages followed by two
    fully connected layers. Each 3x3 convolution uses ``padding=1`` and each pool
    halves the spatial size, so a 32x32 input reaches the classifier head as a
    64 x 4 x 4 feature map, which is what ``fc1`` is sized for.

    The train/validation/test steps all compute cross-entropy loss and accuracy
    and log them as ``<stage>_loss`` and ``<stage>_acc``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return unnormalized class scores (logits), one row of 10 per image.

        Args:
            x: Image tensor shaped ``(batch, 3, 32, 32)``, or a single unbatched
                ``(3, 32, 32)`` image, which is treated as a batch of one.

        Returns:
            Tensor of shape ``(batch, 10)``.
        """
        # A lone (C, H, W) image is promoted to a batch of one so that the
        # per-sample flatten below is well defined; the result is (1, 10),
        # the same as the previous ``view(-1, 1024)`` produced for this case.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Flatten per sample rather than ``view(-1, 64 * 4 * 4)``: with -1 in
        # the batch position, an input of the wrong spatial size is silently
        # re-chunked into a different number of rows. Keeping the batch
        # dimension fixed makes ``fc1`` raise a shape error instead.
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log both under ``stage``.

        Args:
            batch: ``(inputs, targets)`` pair as yielded by the dataloader.
            stage: Metric-name prefix: ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            The cross-entropy loss tensor for the batch.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss``/``train_acc`` and return the loss to optimize."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss``/``val_acc`` for one validation batch."""
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss``/``test_acc`` for one test batch."""
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Return Adam (lr=1e-3) with a plateau-based learning-rate scheduler.

        The scheduler multiplies the learning rate by 0.1 when the monitored
        metric stops decreasing (``patience=5``). ``"monitor"`` must name a
        metric that is actually logged; ``val_loss`` is logged by
        ``validation_step``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
# Data transformations
def _tensor_and_normalize():
    """Return the steps both pipelines end with: tensor conversion, then
    per-channel normalization with mean 0.5 and std 0.5."""
    half = (0.5, 0.5, 0.5)
    return [transforms.ToTensor(), transforms.Normalize(half, half)]


# Training pipeline: random crop (with 4 px padding) and random horizontal flip
# for augmentation, followed by the shared conversion/normalization steps.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize(),
    ]
)

# Evaluation pipeline: no augmentation, only the shared steps.
transform_test = transforms.Compose(_tensor_and_normalize())
def main():
    """Download CIFAR-10, train ``CIFAR10CNN``, and evaluate the best checkpoint.

    Side effects: downloads the dataset into ``./data``, writes checkpoints to
    ``checkpoints/`` and TensorBoard logs to ``lightning_logs/cifar10_cnn``.
    """
    # Load CIFAR-10 dataset
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform_train
    )
    # NOTE: the train=False split is used both for validation (model selection,
    # early stopping, LR scheduling) and for the final test below, so the
    # reported test metrics are not from data unseen during model selection.
    val_dataset = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )
    train_loader = DataLoader(
        train_dataset, batch_size=64, shuffle=True, num_workers=4
    )
    val_loader = DataLoader(
        val_dataset, batch_size=64, shuffle=False, num_workers=4
    )

    # Initialize the model
    model = CIFAR10CNN()

    # Define callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",
        monitor="val_loss",
        filename="cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",
        save_top_k=3,
        mode="min",
    )
    # The model's ReduceLROnPlateau scheduler uses patience=5, i.e. it lowers
    # the learning rate only once val_loss has failed to improve for more than
    # 5 epochs. Early stopping with the same patience would end training at the
    # 5th such epoch, before the scheduler ever acts. A longer patience lets
    # the reduction happen and gives the lower learning rate time to help.
    early_stopping = EarlyStopping(
        monitor="val_loss", patience=10, mode="min", verbose=False
    )

    # Initialize the logger
    logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")

    # Initialize the Trainer
    trainer = L.Trainer(
        max_epochs=50,
        callbacks=[checkpoint_callback, early_stopping],
        logger=logger,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices="auto",
    )

    # Train the model
    trainer.fit(model, train_loader, val_loader)

    # Test the model using the best checkpoint tracked by ModelCheckpoint.
    # Without ckpt_path the in-memory weights from the final epoch are used,
    # which after early stopping are several epochs past the best val_loss.
    trainer.test(model, val_loader, ckpt_path="best")


# The guard is required because the DataLoaders use worker processes: on
# platforms that start workers by spawning (Windows, macOS) each worker
# re-imports this module, and unguarded top-level code would run again there.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment