Last active
May 19, 2024 10:57
-
-
Save ecsplendid/d4da1199a99d7521ecbba3cdfee154c5 to your computer and use it in GitHub Desktop.
Hinton MNIST example with resnet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generated with GPT-4o; probably still buggy.
# Tests what Hinton spoke about here: https://youtu.be/tP-4njhyGvo?si=9JCVwyiftFayc6mA&t=857
# i.e. 50% label noise on the training set.
# CNN, ~10^8 params, i.e. in the over-parameterised regime for MNIST; tried adding regularisation.
# Changes made to the original code: | |
# 1. Replaced the CNN architecture with a ResNet-based model (MNIST_ResNet) for state-of-the-art performance. | |
# 2. Incorporated advanced data augmentation techniques: RandomResizedCrop, RandomHorizontalFlip, and RandomErasing. | |
# 3. Added label smoothing to the loss function to prevent overconfidence in the model. | |
# 4. Ensured compatibility with Apple's M1/M2 chips using MPS. | |
# 5. Used Adam optimizer with L2 regularization (weight decay). | |
# 6. Adjusted data preprocessing and loading to maintain consistency with the new architecture and augmentations. | |
# 7. Fixed deprecated parameters for pretrained models and handled AMP and CUDA warnings. | |
# 8. Ensured transformations are compatible with PIL images by converting to tensors early in the pipeline. | |
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.cuda.amp import autocast, GradScaler | |
import matplotlib.pyplot as plt | |
from torch.utils.data import DataLoader, TensorDataset, random_split | |
import numpy as np | |
import os | |
# Per-sample preprocessing pipelines.
# Training pipeline: convert to a tensor first so every later step works on
# tensors, apply the random augmentations, then normalise with mean 0.5 / std 0.5.
_train_steps = [
    transforms.ToTensor(),
    transforms.RandomResizedCrop(28),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(p=0.1),
    transforms.Normalize((0.5,), (0.5,)),
]
transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic, tensor conversion and normalisation only.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform_test = transforms.Compose(_test_steps)
# Define hyperparameters
batch_size = 64
validation_split = 0.1  # fraction of the training set held out for validation

# Download and prepare the datasets
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

# Split the dataset into training and validation sets
train_size = int((1 - validation_split) * len(mnist_dataset))
val_size = len(mnist_dataset) - train_size
train_dataset, val_dataset = random_split(mnist_dataset, [train_size, val_size])

# random_split returns subsets that share mnist_dataset's transform, so the
# held-out samples would be randomly cropped/flipped/erased on every evaluation
# pass. Re-point the same held-out indices at a view of the training images
# that uses the deterministic evaluation transform instead.
_mnist_eval_view = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transform_test)
val_dataset = torch.utils.data.Subset(_mnist_eval_view, val_dataset.indices)
# Preprocess data
# The training tensors are read straight from the dataset's raw image array, so
# none of the torchvision transforms run on them. Scale to [0, 1] and then apply
# the same (x - 0.5) / 0.5 normalisation that transform_test uses, so training
# inputs share the value range of the validation and test inputs.
train_images = train_dataset.dataset.data[train_dataset.indices].view(-1, 28*28).float() / 255.0
train_images = (train_images - 0.5) / 0.5
train_labels = train_dataset.dataset.targets[train_dataset.indices]

# Shuffle the training set
indices = torch.randperm(len(train_labels))
train_images_shuffled = train_images[indices]
train_labels_shuffled = train_labels[indices]

# Randomize half of the labels to be wrong in the training set
random_train_labels = train_labels_shuffled.clone()
num_labels = len(train_labels_shuffled)
num_randomize = num_labels // 2
random_indices = np.random.choice(num_labels, num_randomize, replace=False)
# Adding an offset drawn uniformly from 1..9 (modulo 10) selects uniformly among
# the nine *other* classes, so every selected label is guaranteed to be wrong.
_noisy_positions = torch.from_numpy(random_indices).long()
_label_offsets = torch.randint(1, 10, (num_randomize,))
random_train_labels[_noisy_positions] = (train_labels_shuffled[_noisy_positions] + _label_offsets) % 10

# Create a dataset with shuffled images and partially shuffled labels
shuffled_train_dataset = TensorDataset(train_images_shuffled, random_train_labels)
# Batch iterators: only the training loader reshuffles each epoch; the
# validation and test loaders keep a fixed order.
train_loader = DataLoader(dataset=shuffled_train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
# Define the ResNet model for MNIST | |
class MNIST_ResNet(nn.Module):
    """ResNet-18 wrapper with a single-channel input stem and 10 output classes."""

    def __init__(self):
        super().__init__()
        # Untrained ResNet-18 (weights=None) with a 10-way classification head.
        backbone = torchvision.models.resnet18(weights=None, num_classes=10)
        # Replace the first convolution with one that accepts 1 input channel.
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.resnet = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its output."""
        return self.resnet(x)
# Pick the compute backend: Apple MPS if available, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = MNIST_ResNet().to(device)

# Loss with label smoothing; Adam with weight decay for L2 regularisation.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

num_epochs = 2000  # Adjust the number of epochs as necessary

# Per-epoch history, appended to by the training loop.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
def calculate_accuracy(loader, model):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards, so measuring accuracy
    in the middle of training neither uses per-batch normalisation statistics
    nor changes the mode the caller left the model in.

    Args:
        loader: Iterable of ``(images, labels)`` batches. Each image batch must
            be reshapeable to ``(-1, 1, 28, 28)``.
        model: ``nn.Module`` mapping an image batch to per-class scores.

    Returns:
        Percentage (0-100) of samples whose arg-max prediction equals the
        label, or ``0.0`` when the loader yields no samples.
    """
    # Run on whatever device the model's parameters live on, so the function
    # does not depend on a module-level ``device`` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.view(-1, 1, 28, 28).to(target_device)  # Reshape to original dimensions
                labels = labels.to(target_device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100 * correct / total
# Open a log file to write the loss data incrementally
with open('training_log_resnet.txt', 'w') as log_file:
    log_file.write('Epoch,Train Loss,Validation Loss,Train Accuracy,Validation Accuracy\n')

    # Mixed precision is only enabled on CUDA. The scaler is created once,
    # outside the epoch loop, so its adaptive loss scale carries over from one
    # epoch to the next instead of being reset every epoch.
    use_amp = device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    for epoch in range(num_epochs):
        # ---- optimisation pass over the training set ----
        model.train()
        train_loss = 0
        for images, labels in train_loader:
            images = images.view(-1, 1, 28, 28).to(device)  # Reshape to original dimensions
            labels = labels.to(device)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
        train_loss /= len(train_loader)  # mean of the per-batch losses
        train_losses.append(train_loss)

        # ---- evaluation ----
        # Switch to eval mode *before* measuring training accuracy, so the
        # measurement pass uses the network's inference behaviour and does not
        # update any normalisation statistics.
        model.eval()
        train_accuracy = calculate_accuracy(train_loader, model)
        train_accuracies.append(train_accuracy)

        val_loss = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.view(-1, 1, 28, 28).to(device)  # Reshape to original dimensions
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        val_accuracy = calculate_accuracy(val_loader, model)
        val_accuracies.append(val_accuracy)

        log_file.write(f'{epoch + 1},{train_loss:.4f},{val_loss:.4f},{train_accuracy:.2f},{val_accuracy:.2f}\n')
        log_file.flush()  # Ensure data is written incrementally
        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {val_accuracy:.2f}%')
# Plot the recorded history: losses on the left y-axis, accuracies on a
# twinned right y-axis, both against the epoch number.
epoch_axis = range(1, num_epochs + 1)

fig, ax1 = plt.subplots()
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
for history, series_label, series_color in (
    (train_losses, 'Train Loss', 'tab:blue'),
    (val_losses, 'Validation Loss', 'tab:orange'),
):
    ax1.plot(epoch_axis, history, label=series_label, color=series_color)
ax1.tick_params(axis='y')

ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy')
for history, series_label, series_color in (
    (train_accuracies, 'Train Accuracy', 'tab:green'),
    (val_accuracies, 'Validation Accuracy', 'tab:red'),
):
    ax2.plot(epoch_axis, history, label=series_label, color=series_color)
ax2.tick_params(axis='y')

fig.tight_layout()
fig.legend(loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)
plt.title('Loss and Accuracy over Epochs')

# Write the figure to disk as a JPG before showing it.
plt.savefig('training_plot_resnet.jpg', format='jpg', dpi=300)
plt.show()
# Evaluate once on the held-out test set.
test_accuracy = calculate_accuracy(test_loader, model)
test_summary = f'Test Accuracy: {test_accuracy:.2f}%'

# Append the result to the training log, then echo it to the console.
with open('training_log_resnet.txt', 'a') as log_handle:
    log_handle.write(test_summary + '\n')
print(test_summary)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment