Hinton MNIST example with resnet
# generated with gpt4-o, probably still buggy
# testing what Hinton spoke about here https://youtu.be/tP-4njhyGvo?si=9JCVwyiftFayc6mA&t=857
# i.e. 50% label noise on train
# CNN, ~10^8 params i.e. in overparam regime for MNIST, tried adding regularisation
# Changes made to the original code:
# 1. Replaced the CNN architecture with a ResNet-based model (MNIST_ResNet) for state-of-the-art performance.
# 2. Incorporated advanced data augmentation techniques: RandomResizedCrop, RandomHorizontalFlip, and RandomErasing.
# 3. Added label smoothing to the loss function to prevent overconfidence in the model.
# 4. Ensured compatibility with Apple's M1/M2 chips using MPS.
# 5. Used Adam optimizer with L2 regularization (weight decay).
# 6. Adjusted data preprocessing and loading to maintain consistency with the new architecture and augmentations.
# 7. Fixed deprecated parameters for pretrained models and handled AMP and CUDA warnings.
# 8. Ensured transformations are compatible with PIL images by converting to tensors early in the pipeline.
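#
# Sanity check for the experiment: each corrupted label is drawn from the 9
# wrong classes, so a model that learns the true digit function tops out near
# ~50% accuracy on the noisy training set, while a model that memorises the
# noise approaches 100% train accuracy. Generalisation therefore shows up in
# the validation/test curves, not the train curve.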
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler  # newer PyTorch prefers torch.amp, but this import still works
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import os
# Define transformations for the training and validation sets
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop(28),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(p=0.1),
    transforms.Normalize((0.5,), (0.5,))
])
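# (Caveat: horizontal flips are unusual for digit classification, since most
# mirrored digits no longer belong to their original class; kept here to match
# the augmentation list in the header.)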
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Define hyperparameters
batch_size = 64
validation_split = 0.1
# Download and prepare the datasets
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)
# Split the dataset into training and validation sets
train_size = int((1 - validation_split) * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size
train_dataset, val_dataset = random_split(mnist_dataset, [train_size, val_size])
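# Note: random_split shares the underlying dataset object, so val_dataset would
# otherwise be served through transform_train, i.e. augmented at validation
# time. Re-wrap the validation indices (Subset.indices) around a copy of MNIST
# that uses transform_test instead.
val_base = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_test)
val_dataset = torch.utils.data.Subset(val_base, val_dataset.indices)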
# Preprocess data. Note: indexing .data directly bypasses transform_train, so
# the augmentations above are never applied to the training images.
train_images = train_dataset.dataset.data[train_dataset.indices].view(-1, 28*28).float() / 255.0
train_images = (train_images - 0.5) / 0.5  # match the Normalize((0.5,), (0.5,)) applied to val/test inputs
train_labels = train_dataset.dataset.targets[train_dataset.indices]
# Shuffle the training set
indices = torch.randperm(len(train_labels))
train_images_shuffled = train_images[indices]
train_labels_shuffled = train_labels[indices]
# Randomize half of the labels to be wrong in the training set
random_train_labels = train_labels_shuffled.clone()
num_labels = len(train_labels_shuffled)
num_randomize = num_labels // 2
random_indices = np.random.choice(num_labels, num_randomize, replace=False)
for i in random_indices:
    possible_labels = list(range(10))
    possible_labels.remove(train_labels_shuffled[i].item())  # guarantee the new label is wrong
    random_train_labels[i] = np.random.choice(possible_labels)
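# (Equivalent vectorised form, for reference: adding a random offset in [1, 9]
# modulo 10 also guarantees each corrupted label differs from the original.)
# idx = torch.as_tensor(random_indices)
# random_train_labels[idx] = (train_labels_shuffled[idx] + torch.randint(1, 10, (num_randomize,))) % 10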
# Create a dataset with shuffled images and partially shuffled labels
shuffled_train_dataset = TensorDataset(train_images_shuffled, random_train_labels)
# Data loaders
train_loader = DataLoader(shuffled_train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Define the ResNet model for MNIST
class MNIST_ResNet(nn.Module):
    def __init__(self):
        super(MNIST_ResNet, self).__init__()
        # Un-pretrained ResNet-18 (~1.1e7 parameters) with a 1-channel stem for MNIST
        self.resnet = torchvision.models.resnet18(weights=None, num_classes=10)
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        return self.resnet(x)
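# Note: this keeps resnet18's ImageNet-scale stem (7x7 stride-2 conv followed
# by a max-pool), which downsamples 28x28 inputs aggressively. A common MNIST
# adaptation, if accuracy matters more than fidelity to the original setup:
#   self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
#   self.resnet.maxpool = nn.Identity()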
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model = MNIST_ResNet().to(device)
# Define loss function and optimizer with weight decay for L2 regularization
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
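# Label smoothing mixes each one-hot target with the uniform distribution,
# target = (1 - 0.1) * one_hot + 0.1 / 10, capping achievable confidence as a
# mild guard against memorising the 50%-noisy labels. Note that Adam's
# weight_decay is classic (coupled) L2; optim.AdamW is the decoupled variant.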
num_epochs = 2000 # Adjust the number of epochs as necessary
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
def calculate_accuracy(loader, model):
    model.eval()  # BatchNorm should use running stats here; the epoch loop re-enables train mode
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.view(-1, 1, 28, 28).to(device)  # reshape flattened rows back to (N, 1, 28, 28)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total
# Create the GradScaler once, outside the epoch loop, so its loss-scale state
# persists across epochs
scaler = GradScaler(enabled=device.type == 'cuda')

# Open a log file to write the loss data incrementally
with open('training_log_resnet.txt', 'w') as log_file:
    log_file.write('Epoch,Train Loss,Validation Loss,Train Accuracy,Validation Accuracy\n')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for images, labels in train_loader:
            images = images.view(-1, 1, 28, 28).to(device)  # reshape flattened rows back to (N, 1, 28, 28)
            labels = labels.to(device)
            optimizer.zero_grad()
            # autocast/GradScaler are no-ops off CUDA, so MPS and CPU run in full precision
            with autocast(enabled=device.type == 'cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        train_losses.append(train_loss)
        train_accuracy = calculate_accuracy(train_loader, model)
        train_accuracies.append(train_accuracy)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.view(-1, 1, 28, 28).to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        val_accuracy = calculate_accuracy(val_loader, model)
        val_accuracies.append(val_accuracy)

        log_file.write(f'{epoch + 1},{train_loss:.4f},{val_loss:.4f},{train_accuracy:.2f},{val_accuracy:.2f}\n')
        log_file.flush()  # ensure data is written incrementally
        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {val_accuracy:.2f}%')
# Plotting losses and accuracies
fig, ax1 = plt.subplots()
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.plot(range(1, num_epochs + 1), train_losses, label='Train Loss', color='tab:blue')
ax1.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss', color='tab:orange')
ax1.tick_params(axis='y')
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy')
ax2.plot(range(1, num_epochs + 1), train_accuracies, label='Train Accuracy', color='tab:green')
ax2.plot(range(1, num_epochs + 1), val_accuracies, label='Validation Accuracy', color='tab:red')
ax2.tick_params(axis='y')
fig.tight_layout()
fig.legend(loc='upper right', bbox_to_anchor=(1,1), bbox_transform=ax1.transAxes)
plt.title('Loss and Accuracy over Epochs')
# Save the plot as a JPG file
plt.savefig('training_plot_resnet.jpg', format='jpg', dpi=300)
plt.show()
# Calculate test accuracy
test_accuracy = calculate_accuracy(test_loader, model)
# Log the test accuracy
with open('training_log_resnet.txt', 'a') as log_file:  # append to the log file
    log_file.write(f'Test Accuracy: {test_accuracy:.2f}%\n')
print(f'Test Accuracy: {test_accuracy:.2f}%')