Created
July 6, 2024 07:03
-
-
Save BexTuychiev/e96c81ec851567ba4306aea38ed4f3ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader | |
from torch.utils.tensorboard import SummaryWriter | |
# Device configuration: prefer the GPU when CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters for the training run.
num_epochs = 50        # total passes over the training set
batch_size = 64        # samples per mini-batch
learning_rate = 0.001  # initial Adam learning rate

# Data transformations.
# Training-time augmentation: random 32x32 crop (after 4-pixel padding) and
# random horizontal flip; ToTensor scales to [0, 1] and Normalize(0.5, 0.5)
# then maps each channel to roughly [-1, 1].
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Evaluation transform: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset (downloaded into ./data on first run).
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# NOTE(review): worker processes are created at module level with no
# `if __name__ == "__main__":` guard — fine on Linux (fork), but this can
# fail on spawn-based platforms (Windows/macOS); confirm target platform.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# Define the CNN architecture | |
class CIFAR10CNN(nn.Module):
    """Small three-conv-layer CNN for 10-class CIFAR-10 classification.

    Input:  float tensor of shape (N, 3, 32, 32).
    Output: raw logits of shape (N, 10) — no softmax; pair with
            ``nn.CrossEntropyLoss``, which applies log-softmax internally.
    """

    def __init__(self):
        super().__init__()  # zero-argument super(), standard since Python 3
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)   # 3 -> 32 channels, same spatial size
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  # 32 -> 64 channels
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)  # 64 -> 64 channels
        self.pool = nn.MaxPool2d(2, 2)                # halves H and W each time
        # Three conv+pool stages shrink 32 -> 16 -> 8 -> 4 spatially,
        # so the flattened feature vector is 64 channels * 4 * 4.
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits for a batch of 32x32 RGB images."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample while keeping the batch dimension explicit.
        # The original `x.view(-1, 64 * 4 * 4)` infers the batch size and
        # would silently produce a wrong batch dimension (instead of a
        # clear error) if a spatially mismatched input ever got this far.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Initialize the model, loss function, and optimizer.
model = CIFAR10CNN().to(device)
criterion = nn.CrossEntropyLoss()  # expects raw logits plus integer class labels
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Cut the learning rate by 10x whenever the monitored (validation) loss
# fails to improve for 5 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# TensorBoard setup: event files are written under runs/cifar10_cnn_experiment.
writer = SummaryWriter('runs/cifar10_cnn_experiment')
# Training loop
total_step = len(train_loader)  # batches per epoch, used for progress printing
for epoch in range(num_epochs):
    model.train()  # training-mode behavior (relevant if dropout/BN layers are added)
    train_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Print a progress line every 100 batches.
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

    # Calculate average training loss for the epoch
    avg_train_loss = train_loss / len(train_loader)
    writer.add_scalar('training loss', avg_train_loss, epoch)

    # Validation: the CIFAR-10 test split is reused as the validation set
    # here — no separate validation split is carved out of the data.
    model.eval()
    with torch.no_grad():  # no gradients needed during evaluation
        correct = 0
        total = 0
        val_loss = 0.0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Predicted class = argmax over the 10 logits.
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        avg_val_loss = val_loss / len(test_loader)
        print(f'Validation Accuracy: {accuracy:.2f}%')
        writer.add_scalar('validation loss', avg_val_loss, epoch)
        writer.add_scalar('validation accuracy', accuracy, epoch)

    # Learning rate scheduling: ReduceLROnPlateau monitors the average
    # validation loss computed above.
    scheduler.step(avg_val_loss)
# Final test: one full pass over the test set with the trained weights.
# (Same data as the per-epoch validation above, so this accuracy will
# match the last validation accuracy.)
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')
writer.close()  # flush and close the TensorBoard event files

# Save the model weights (state_dict only, not the pickled module object).
torch.save(model.state_dict(), 'cifar10_cnn.pth')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment