Skip to content

Instantly share code, notes, and snippets.

@BexTuychiev
Created July 6, 2024 07:03
Show Gist options
  • Save BexTuychiev/e96c81ec851567ba4306aea38ed4f3ac to your computer and use it in GitHub Desktop.
Save BexTuychiev/e96c81ec851567ba4306aea38ed4f3ac to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Device configuration: run on the GPU when CUDA is usable, else on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters
num_epochs, batch_size = 50, 64  # passes over the training set / samples per batch
learning_rate = 1e-3  # initial step size handed to the optimizer
# Data transformations
# Both pipelines finish with the same per-channel normalisation: mean 0.5 and
# std 0.5 on each of the three channels.
_channel_mean = (0.5, 0.5, 0.5)
_channel_std = (0.5, 0.5, 0.5)

# Training pipeline: random 32x32 crop (after 4 pixels of padding) and a
# random horizontal flip as augmentation, then tensor conversion.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: no augmentation, only conversion and normalisation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform_test = transforms.Compose(_test_steps)
# Load CIFAR-10 dataset (downloaded into ./data when requested by download=True)
_data_root = './data'
train_dataset = torchvision.datasets.CIFAR10(
    root=_data_root,
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = torchvision.datasets.CIFAR10(
    root=_data_root,
    train=False,
    download=True,
    transform=transform_test,
)

# Both loaders share batch size and worker count; only the training loader
# shuffles its samples.
_loader_options = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)
# Define the CNN architecture
class CIFAR10CNN(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages and a two-layer classifier.

    Every convolution is 3x3 with padding 1, so it keeps the spatial size;
    every 2x2 max-pool halves height and width. ``fc1`` is sized for a
    64 x 4 x 4 feature map, which is what three halvings of a 32 x 32 input
    produce.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, 32, 32)``.

        Raises:
            RuntimeError: If the pooled feature map does not have
                64 * 4 * 4 values per sample (i.e. the input is not 32 x 32).
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten everything except the batch axis. Unlike view(-1, 1024),
        # this never folds a wrong-sized feature map into extra batch rows
        # (fc1 raises a shape error instead) and it does not require the
        # activations to be laid out contiguously in memory.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Initialize the model, loss function, and optimizer
model = CIFAR10CNN()
model = model.to(device)  # place the parameters on the selected device
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Scheduler in 'min' mode: multiplies the learning rate by 0.1 when the value
# passed to scheduler.step() stops decreasing, with a patience of 5 steps.
_plateau_options = {'mode': 'min', 'factor': 0.1, 'patience': 5}
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **_plateau_options)

# TensorBoard setup: scalars are written under this run directory.
writer = SummaryWriter(log_dir='runs/cifar10_cnn_experiment')
# Training loop
# Each epoch: one optimisation pass over train_loader, then an evaluation
# pass over test_loader whose mean loss drives the learning-rate scheduler.
total_step = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    # criterion is nn.CrossEntropyLoss with its default reduction, so each
    # loss value is a batch mean. Batch means are re-weighted by batch size
    # so that a smaller final batch does not skew the epoch average.
    train_loss = 0.0
    train_samples = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_count = labels.size(0)
        train_loss += loss.item() * batch_count
        train_samples += batch_count
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

    # Per-sample average training loss for the epoch
    avg_train_loss = train_loss / train_samples
    writer.add_scalar('training loss', avg_train_loss, epoch)

    # Validation (runs on test_loader; gradients are not tracked)
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_count = labels.size(0)
            val_loss += loss.item() * batch_count
            # Index of the highest score along the class axis is the prediction.
            predicted = outputs.argmax(dim=1)
            total += batch_count
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    avg_val_loss = val_loss / total  # per-sample mean, same weighting as training
    print(f'Validation Accuracy: {accuracy:.2f}%')
    writer.add_scalar('validation loss', avg_val_loss, epoch)
    writer.add_scalar('validation accuracy', accuracy, epoch)

    # Learning rate scheduling: the scheduler monitors the validation loss.
    scheduler.step(avg_val_loss)
# Final test: accuracy over test_loader in eval mode, without gradient tracking.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(batch_images)
        # Index of the largest score along the class axis is the prediction.
        predicted = logits.max(dim=1).indices
        hits = predicted.eq(batch_labels).sum().item()
        total += batch_labels.size(0)
        correct += hits
print(f'Test Accuracy: {100 * correct / total:.2f}%')
# Close the TensorBoard writer now that no more scalars will be logged.
writer.close()
# Save the model: only the parameter state dict is written, not the module.
trained_weights = model.state_dict()
torch.save(trained_weights, 'cifar10_cnn.pth')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment