Skip to content

Instantly share code, notes, and snippets.

@FreeFly19
Created July 27, 2023 16:13
Show Gist options
  • Save FreeFly19/9554211dd7bfe74ea6a3ee900f42ad13 to your computer and use it in GitHub Desktop.
MNIST PyTorch CNN with skip connections
import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
# MNIST train / test splits, stored under ./data (downloaded on first use).
# ToTensor() turns each image into a tensor the network can consume.
train_dataset = torchvision.datasets.MNIST('data', train=True, transform=ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST('data', train=False, transform=ToTensor(), download=True)

batch_size = 16

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
# Sample order is irrelevant for evaluation, so the test loader is not
# shuffled; this keeps validation batches in a fixed, reproducible order.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
class Model(torch.nn.Module):
    """Small CNN classifier with two residual (skip) connections.

    Layout: conv -> pool -> two residual conv layers -> pool -> conv -> pool
    -> linear.  The final linear layer is sized for a 3x3 feature map, which
    is what three 2x2 max-pools produce from a 28x28 input (28 -> 14 -> 7 -> 3).

    Args:
        base_filter_size: Number of channels produced by the first
            convolution; the last convolution produces twice as many.
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, base_filter_size: int = 16, in_channels: int = 1,
                 num_classes: int = 10) -> None:
        super().__init__()

        self.conv1 = torch.nn.Conv2d(in_channels, base_filter_size, kernel_size=(3, 3), padding=1)

        # Same channel count in and out so the layer output can be added to
        # its own input (the skip connection) in forward().
        self.conv2_1 = torch.nn.Conv2d(base_filter_size, base_filter_size, kernel_size=(3, 3), padding=1)
        self.conv2_2 = torch.nn.Conv2d(base_filter_size, base_filter_size, kernel_size=(3, 3), padding=1)

        self.conv3 = torch.nn.Conv2d(base_filter_size, base_filter_size * 2, kernel_size=(3, 3), padding=1)

        self.act = torch.nn.ReLU()
        self.max_pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc = torch.nn.Linear(base_filter_size * 2 * 3 * 3, num_classes)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape [batch, num_classes].

        Args:
            X: Image batch of shape [batch, in_channels, 28, 28].
        """
        X = self.conv1(X)
        X = self.act(X)
        X = self.max_pooling(X)
        # [batch, base_filter_size, 14, 14]

        # Residual blocks: add the block input back before the activation.
        X = self.conv2_1(X) + X
        X = self.act(X)

        X = self.conv2_2(X) + X
        X = self.act(X)
        X = self.max_pooling(X)
        # [batch, base_filter_size, 7, 7]

        X = self.conv3(X)
        X = self.act(X)
        X = self.max_pooling(X)
        # [batch, base_filter_size*2, 3, 3]

        X = torch.flatten(X, start_dim=1)
        # [batch, base_filter_size*2*3*3]

        X = self.fc(X)
        # [batch, num_classes]

        return X
def _run_epoch(model, dataloader, critic, optim, device, train):
    """Run one full pass over *dataloader* and return (mean loss, mean accuracy).

    Both values are the average of the per-batch figures.  When *train* is
    true the model parameters are updated after every batch; otherwise the
    pass is evaluation-only and no gradients are tracked.
    """
    model.train(train)
    losses = []
    accuracies = []

    # Autograd bookkeeping is only needed when we are going to backpropagate.
    with torch.set_grad_enabled(train):
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)

            y_pred = model(X)

            accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
            accuracies.append(accuracy.item())

            # CrossEntropyLoss takes integer class indices directly; this
            # gives the same value as a one-hot target without building one
            # element by element.
            loss = critic(y_pred, y)
            losses.append(loss.item())

            if train:
                optim.zero_grad()
                loss.backward()
                optim.step()

    return np.mean(losses), np.mean(accuracies)


# Use the GPU when one is available, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model().to(device)
critic = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):
    # Each epoch is a training pass followed by a validation pass.
    for mode in ['train', 'val']:
        dataloader = train_dataloader if mode == 'train' else test_dataloader
        mean_loss, mean_accuracy = _run_epoch(
            model, dataloader, critic, optim, device, train=(mode == 'train')
        )
        print(f'Mode: {mode}, Epoch: {epoch}, Loss: {mean_loss:.3f}, Accuracy: {mean_accuracy:.3f}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment