Skip to content

Instantly share code, notes, and snippets.

@AnandAwasthi
Last active September 3, 2021 13:17
Show Gist options
  • Save AnandAwasthi/7188ca3c2d1eabd4e7f7837c0fb49695 to your computer and use it in GitHub Desktop.
Save AnandAwasthi/7188ca3c2d1eabd4e7f7837c0fb49695 to your computer and use it in GitHub Desktop.
Convolutional Neural Network using PyTorch (Fashion-MNIST)
from __future__ import print_function
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.set_printoptions(linewidth = 120)
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
# Fashion-MNIST training split. With download=True the files are fetched into
# `root` when they are not already present; every image is passed through
# ToTensor() so samples come back as tensors.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Small batches of 10 images, used only to preview data below.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=10)
# Number of training samples.
print(len(train_set))
# Training labels. `targets` is the current attribute; `train_labels` is a
# deprecated alias that emits a warning on every access.
print(train_set.targets)
# Frequency of each label in the training split.
print(train_set.targets.bincount())

# Inspect a single (image, label) pair.
sample = next(iter(train_set))
print(len(sample))  # the sample unpacks into exactly two items below
sample_image, sample_label = sample
print(sample_image.shape)  # image size
plt.imshow(sample_image.squeeze(), cmap='gray')
print('label:', sample_label)

# Inspect one batch from the preview loader.
batch = next(iter(train_loader))
images, labels = batch
print(images.shape)

# Display the images of the batch, 10 per row.
grid = torchvision.utils.make_grid(images, nrow=10)
plt.figure(figsize=(15, 15))
# The grid is channels-first; imshow wants channels-last, so move axis 0 to
# the end (same axes order as the former np.transpose(grid, (1, 2, 0)), but
# without relying on numpy's fallback for torch tensors).
plt.imshow(grid.permute(1, 2, 0).numpy())
#build neural network
class FMnistNetwork(nn.Module):
    """Small CNN classifier for single-channel 28x28 images (Fashion-MNIST).

    Two conv -> ReLU -> 2x2 max-pool stages feed three fully connected
    layers that emit 10 unnormalized class scores (logits) per image.
    """

    def __init__(self):
        super(FMnistNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        # A 28x28 input shrinks 28 -> 24 (conv1) -> 12 (pool) -> 8 (conv2)
        # -> 4 (pool), leaving 12 feature maps of 4x4 to flatten.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, tensor):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) class logits."""
        # hidden layer 1
        tensor = self.conv1(tensor)
        tensor = F.relu(tensor)
        tensor = F.max_pool2d(tensor, kernel_size=2, stride=2)
        # hidden layer 2
        tensor = self.conv2(tensor)
        tensor = F.relu(tensor)
        tensor = F.max_pool2d(tensor, kernel_size=2, stride=2)
        # hidden layer 3
        # Flatten all but the batch dimension. reshape(-1, 12*4*4) would
        # silently regroup a mis-sized input into the wrong number of rows;
        # flatten keeps the batch size so a mismatch fails loudly in fc1.
        tensor = torch.flatten(tensor, start_dim=1)
        tensor = self.fc1(tensor)
        tensor = F.relu(tensor)
        # hidden layer 4
        tensor = self.fc2(tensor)
        tensor = F.relu(tensor)
        # output layer: raw logits (no softmax; cross_entropy applies it)
        tensor = self.out(tensor)
        return tensor
# Sanity-check an untrained network on one training image.
fnn_dry_test_nn = FMnistNetwork()
print(fnn_dry_test_nn)
dry_test_sample = next(iter(train_set))
dry_test_image, dry_test_label = dry_test_sample
print(dry_test_image.shape)
# Add a leading batch dimension so the rank is (Batch, Channel, H, W).
dry_test_image = dry_test_image.unsqueeze(0)
print(dry_test_image.shape)
# Scoped no_grad: autograd is off only for this forward pass and the previous
# grad mode is restored afterwards, instead of flipping the global switch.
with torch.no_grad():
    dry_test_pred = fnn_dry_test_nn(dry_test_image)
print(dry_test_pred)
print(dry_test_label)
# Index of the largest logit; meaningless here since the weights are untrained.
print(dry_test_pred.argmax(dim=1))
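# Output size formula for a convolution or pooling layer
# n x n input
# f x f filter/kernel
# p is padding
# s is stride
# then the output size is floor((n - f + 2p) / s) + 1
# e.g. 28x28 -> 5x5 conv -> 24 -> 2x2/2 pool -> 12 -> 5x5 conv -> 8
# -> 2x2/2 pool -> 4, which is why fc1 expects 12 * 4 * 4 inputs.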
# Make sure autograd is on for training, whatever earlier code set globally.
torch.set_grad_enabled(True)
fnn = FMnistNetwork()
optimizer = optim.Adam(fnn.parameters(), lr=0.001)
# Deliberately not shuffled: predictions are later gathered over this same
# loader and compared position-by-position with train_set.targets.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
for epoch in range(5):
    # Reset every epoch so the printed value is this epoch's summed batch
    # loss rather than a running total since training began.
    total_loss = 0
    for images, labels in train_loader:
        preds = fnn(images)
        # calculate loss (mean over the batch by default)
        loss = F.cross_entropy(preds, labels)
        # clear gradients left over from the previous step
        optimizer.zero_grad()
        # calculate gradients
        loss.backward()
        # update weights
        optimizer.step()
        total_loss += loss.item()
    print("epoch:", epoch, "loss:", total_loss)
@torch.no_grad()
def get_all_prediction(model, loader):
    """Run ``model`` over every batch of ``loader`` and stack the outputs.

    Args:
        model: callable mapping a batch of inputs to a batch of outputs.
        loader: iterable of ``(inputs, labels)`` pairs; labels are ignored.

    Returns:
        The per-batch outputs concatenated along dim 0, in loader order, or
        an empty tensor when ``loader`` yields no batches.
    """
    batch_outputs = [model(images) for images, _labels in loader]
    if not batch_outputs:
        return torch.tensor([])
    # Concatenate once at the end; growing a tensor with torch.cat per batch
    # re-copies every earlier prediction each iteration (quadratic).
    return torch.cat(batch_outputs, dim=0)
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues,
                          figsize=None):
    """Print and plot the confusion matrix of ``y_true`` against ``y_pred``.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        classes: ordered sequence of tick labels; entry ``i`` labels row and
            column ``i`` of the matrix, so it must be in label order.
        normalize: if True, scale each row to sum to 1. Rows for labels that
            never occur in ``y_true`` are left at 0 rather than NaN.
        title: plot title; a default naming the normalization mode is used
            when omitted.
        cmap: matplotlib colormap used to shade the cells.
        figsize: optional ``(width, height)`` in inches for the new figure.
            This function creates its own figure, so a ``plt.figure(...)``
            call made beforehand does not size this plot.

    Returns:
        The matplotlib ``Axes`` containing the plot.
    """
    if not title:
        title = ('Normalized confusion matrix' if normalize
                 else 'Confusion matrix, without normalization')

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard all-zero rows to avoid 0/0 -> NaN and a RuntimeWarning.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # Show every tick and label it with the matching entry of `classes`.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')
    # Rotate the x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Annotate each cell with its value, switching text colour for cells
    # above half the maximum so the number stays readable.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
np.set_printoptions(precision=2)
train_preds = get_all_prediction(fnn, train_loader)
# Class names indexed by label value 0..9 (the same order as torchvision's
# FashionMNIST.classes). This must be an ordered sequence: a set literal has
# no defined iteration order, so tick labels would land on the wrong rows.
label_dict = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]
train_pred_labels = train_preds.argmax(dim=1)

# NOTE(review): these metrics use the training split, so they measure fit,
# not generalization; consider evaluating on FashionMNIST(train=False) too.

# Plot non-normalized confusion matrix. plot_confusion_matrix opens its own
# figure, so size that figure directly instead of calling plt.figure() first
# (which only produced an extra empty figure).
ax = plot_confusion_matrix(train_set.targets, train_pred_labels, classes=label_dict,
                           title='Confusion matrix')
ax.figure.set_size_inches(10, 10)
plt.show()
print('Accuracy:', accuracy_score(train_set.targets, train_pred_labels))
print('F1:', f1_score(train_set.targets, train_pred_labels, average='weighted'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment