/pytorch_tldr.py Secret
Created
June 8, 2022 21:54
A one-file summary of using PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import torchvision.transforms as transforms | |
from matplotlib import pyplot as plt | |
from torch import Tensor | |
from torch.nn import CrossEntropyLoss | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.optim import SGD | |
from torch.utils.data import DataLoader | |
from torchvision.datasets import CIFAR10 | |
from torchvision.utils import make_grid | |
# --- Training / evaluation hyper-parameters ---------------------------------
BATCH_SIZE: int = 4  # images per DataLoader batch (train and test)
# Despite the name, this is a logging interval: the running training loss is
# printed (and reset) once every MINI_BATCH_SIZE batches.
MINI_BATCH_SIZE: int = 2000
EPOCHS: int = 2  # full passes over the training set

# --- Filesystem locations ---------------------------------------------------
DATA_ROOT: str = './data'  # CIFAR10 download / cache directory
MODEL_PATH: str = './cifar_net.pth'  # destination of the trained state_dict

# --- CIFAR10 class names, indexed by integer label ---------------------------
DATA_CLASSES = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
class Net(nn.Module):
    """Small LeNet-style CNN mapping 3x32x32 images to 10 class scores.

    Layout: conv(3->6, 5x5) -> ReLU -> maxpool(2) -> conv(6->16, 5x5) -> ReLU
    -> maxpool(2) -> flatten -> fc(400->120) -> ReLU -> fc(120->84) -> ReLU
    -> fc(84->10).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as always (conv1, pool,
        # conv2, fc1, fc2, fc3). That keeps parameter names and state_dict keys
        # stable, and also the RNG draw order used by default initialisation.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, 10)."""
        features = self.pool(self.conv1(x).relu())
        features = self.pool(self.conv2(features).relu())
        # 16 * 5 * 5 only matches a 32x32 input: 32 -conv-> 28 -pool-> 14
        # -conv-> 10 -pool-> 5. Every dimension except the batch one is merged.
        flat = features.flatten(start_dim=1)
        hidden = self.fc2(self.fc1(flat).relu()).relu()
        return self.fc3(hidden)
def _build_loaders():
    """Return ``(train_loader, test_loader)`` over CIFAR10, downloading it if needed.

    Images go through ToTensor and then Normalize(mean=0.5, std=0.5) per
    channel, i.e. x -> (x - 0.5) / 0.5.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_data = CIFAR10(root=DATA_ROOT, train=True, download=True,
                         transform=transform)
    test_data = CIFAR10(root=DATA_ROOT, train=False, download=True,
                        transform=transform)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=2)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=2)
    return train_loader, test_loader


def _format_labels(indices) -> str:
    """Return class names for a 1-D tensor of label ids, each padded to 5 chars."""
    # .tolist() yields plain ints and also works for tensors on a CUDA device.
    return ' '.join(f'{DATA_CLASSES[i]:5s}' for i in indices.tolist())


def _train(model, loader, device):
    """Train *model* in place for ``EPOCHS`` epochs with cross-entropy + SGD.

    The mean loss is printed every ``MINI_BATCH_SIZE`` batches. Despite its
    name, that constant is a logging interval, not a batch size.
    """
    criterion = CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=1e-3, momentum=0.9)
    model.train()
    for epoch in range(EPOCHS):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % MINI_BATCH_SIZE == MINI_BATCH_SIZE - 1:
                print(f'[epoch={epoch + 1}, {i + 1:5d}] '
                      f'loss: {running_loss / MINI_BATCH_SIZE:.3f}')
                running_loss = 0.0
    print('Training is complete!')


def _predict(model, images, device):
    """Return the highest-scoring class index per image (tensor on *device*)."""
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model(images.to(device)).argmax(dim=1)


def _evaluate(model, loader, device):
    """Return ``(correct, total)`` dicts of prediction counts keyed by class name."""
    correct_predictions = {name: 0 for name in DATA_CLASSES}
    total_predictions = {name: 0 for name in DATA_CLASSES}
    for inputs, labels in loader:
        predictions = _predict(model, inputs, device)
        for label, prediction in zip(labels.tolist(), predictions.tolist()):
            name = DATA_CLASSES[label]
            total_predictions[name] += 1
            if label == prediction:
                correct_predictions[name] += 1
    return correct_predictions, total_predictions


def main():
    """Train a small CNN on CIFAR10, save it, and report test-set accuracy.

    Side effects:
    - downloads CIFAR10 into ``DATA_ROOT`` if it is missing;
    - shows sample batches with matplotlib;
    - writes the trained ``state_dict`` to ``MODEL_PATH``;
    - prints progress and accuracies to stdout.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Shows which device is used: CUDA when available, otherwise the CPU.
    print(device)

    train_loader, test_loader = _build_loaders()

    # Preview one training batch. next(iter(...)) is the portable spelling:
    # DataLoader iterators in current PyTorch no longer have a .next() method.
    images, labels = next(iter(train_loader))
    print(_format_labels(labels))
    show_image(make_grid(images))

    model = Net().to(device)
    _train(model, train_loader, device)
    torch.save(model.state_dict(), MODEL_PATH)

    # Predict on the first test batch. This previously reused the last
    # *training* outputs instead of running the model on these images.
    test_images, test_labels = next(iter(test_loader))
    print(f'Ground truth: {_format_labels(test_labels)}')
    show_image(make_grid(test_images))
    predicted = _predict(model, test_images, device)
    print('Predicted: ', _format_labels(predicted))

    # Single pass over the test set. It yields per-class counts, and the
    # overall accuracy is derived from them. Previously the model was fed the
    # stale preview batch instead of each test batch.
    correct_predictions, total_predictions = _evaluate(model, test_loader, device)
    correct = sum(correct_predictions.values())
    total = sum(total_predictions.values())
    print(
        f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

    for classname, correct_count in correct_predictions.items():
        accuracy = 100 * float(correct_count) / total_predictions[classname]
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
def show_image(image: Tensor):
    """Display a channels-first (C, H, W) image tensor with matplotlib.

    The tensor is assumed to be normalized with mean 0.5 / std 0.5 per channel,
    as done by the transform in ``main``. It may live on any device and may
    require grad; a detached CPU copy is used for plotting.
    """
    # Invert Normalize(0.5, 0.5): x_norm = (x - 0.5) / 0.5, so x = x_norm / 2 + 0.5.
    # .detach().cpu() makes .numpy() safe for GPU tensors and tensors requiring grad.
    image = image.detach().cpu() / 2 + 0.5
    np_image = image.numpy()
    # imshow expects (H, W, C); the tensor is (C, H, W).
    plt.imshow(np.transpose(np_image, (1, 2, 0)))
    plt.show()
if __name__ == "__main__":
    # Run the script only when executed directly, not on import. This guard
    # also matters because DataLoader workers (num_workers > 0) may re-import
    # this module under the "spawn" start method.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment