@ShoaibMerajSami, forked from Mahedi-61/autoencoder_mnist.py, created November 16, 2021
3-layer autoencoder for MNIST digit reconstruction and classification with visualization (PyTorch)
"""
Course: Applications of NN (CpE - 520)
Homework Assignment 9
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
from torchvision.utils import save_image
import tqdm
from sklearn.metrics import confusion_matrix
import os
# classification accuracy (test set)
# 60 nodes: 0.9807
# 50 nodes: 0.9782
# 40 nodes: 0.9763
# 30 nodes: 0.9739
# 20 nodes: 0.9685
# 10 nodes: 0.9463
#  5 nodes: 0.9010
# reconstruction loss (MSE)
# 60 nodes: 0.5000
# 30 nodes: 0.5198
# 10 nodes: 0.5937
#  5 nodes: 0.6529
task = "re_con"  # or "classify"
img_size = 28
hidden_nodes = 30
data_dir = "./data"
saved_image_dir = "./images"
batch_size = 500
lr = 3e-3
num_epochs = 200
is_save = True
is_load = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_file = "re_%d_cl_model.pth" % hidden_nodes
class AutoEncoder(nn.Module):
    def __init__(self, img_size, hidden_nodes):
        super().__init__()
        self.input_node = img_size * img_size
        self.hidden_node = hidden_nodes

        # shared single-layer encoder with two heads:
        # a decoder for reconstruction and a linear classifier
        self.linear = nn.Linear(self.input_node, self.hidden_node)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.autoencoder = nn.Linear(self.hidden_node, self.input_node)
        self.classifier = nn.Linear(self.hidden_node, 10)

    def forward(self, x):
        x = x.view(x.size(0), self.input_node)
        x = self.relu(self.linear(x))
        image = self.autoencoder(x)
        output = self.classifier(x)
        return self.sigmoid(image), output
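
# A minimal sanity check of the forward pass (not part of the original assignment;
# the interpreter-style lines and the _model/_recon/_logits names below are
# illustrative only): with the defaults above, a (B, 1, 28, 28) batch is flattened,
# encoded to hidden_nodes units, and the two heads return a sigmoid reconstruction
# of shape (B, 784) and classification logits of shape (B, 10).
# >>> _model = AutoEncoder(img_size=28, hidden_nodes=30)
# >>> _recon, _logits = _model(torch.randn(4, 1, 28, 28))
# >>> _recon.shape, _logits.shape
# (torch.Size([4, 784]), torch.Size([4, 10]))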
class MNISTDataset(Dataset):
    def __init__(self, data_dir, train=True):
        super().__init__()
        self.images, self.labels = self.get_data(data_dir, train)

        mean = (0.1307, )
        std = (0.3081, )
        self.train_trans = transforms.Compose([
            transforms.RandomRotation((0, 10), fill=(0, )),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        self.test_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        self.trans = (self.train_trans if train else self.test_trans)

    def get_data(self, data_dir, train):
        images = []
        labels = []
        dataset = torchvision.datasets.MNIST(root=data_dir, train=train, download=True)
        for img, label in dataset:
            images.append(img)
            labels.append(label)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img = self.images[index]
        img = self.trans(img)
        label = np.array(self.labels[index], dtype=np.float32)  # np.float is removed in recent NumPy
        return img, label
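
# A small usage sketch for the dataset wrapper (illustrative only; the _test_set,
# _x, _y names are not part of the original code, and it assumes ./data is
# writable since torchvision downloads MNIST there on first use): each item is
# a normalized (1, 28, 28) float tensor plus a scalar float32 label.
# >>> _test_set = MNISTDataset(data_dir, train=False)
# >>> _x, _y = _test_set[0]
# >>> _x.shape, _x.dtype
# (torch.Size([1, 28, 28]), torch.float32)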
class Train(nn.Module):
    def __init__(self):
        super().__init__()
        self.train_loader = DataLoader(MNISTDataset(data_dir, train=True),
                                       batch_size=batch_size,
                                       shuffle=True, num_workers=4)

        self.test_loader = DataLoader(MNISTDataset(data_dir, train=False),
                                      batch_size=batch_size,
                                      shuffle=True, num_workers=4)

        self.model = AutoEncoder(img_size=img_size,
                                 hidden_nodes=hidden_nodes).to(device)

        self.classifier_loss = nn.CrossEntropyLoss()
        self.MSE = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.92)
        self.iter = 0

        if is_load:
            checkpoint = torch.load(model_file)
            self.model.load_state_dict(checkpoint["model"])
            self.optimizer.load_state_dict(checkpoint["optimizer"])

    def train(self):
        for e in range(num_epochs):
            total_loss = 0
            train_loop = tqdm.tqdm(self.train_loader)
            for img, label in train_loop:
                img = img.to(device)
                label = label.long().to(device)

                reconstructed_img, scores = self.model(img)
                reconstructed_img = reconstructed_img.view(-1, 1, 28, 28)

                if task == "re_con":
                    loss = self.MSE(reconstructed_img, img)
                elif task == "classify":
                    loss = self.classifier_loss(scores, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            print("Epoch {} | total loss {:.4f}".format(e, total_loss / len(self.train_loader)))
            if (e % 10 == 0 and e != 0):
                self.scheduler.step()
                print(self.scheduler.get_last_lr())

            # change according to your experiment
            if task == "classify":
                if (e > 100 and e % 10 == 0):
                    self.test()

        if is_save:
            checkpoint = {}
            checkpoint["model"] = self.model.state_dict()
            checkpoint["optimizer"] = self.optimizer.state_dict()
            torch.save(checkpoint, model_file)
            print("saving model")

        if task == "re_con":
            self.do_reconstruction()
    def test(self):
        self.model.eval()
        result = 0
        for img, label in self.test_loader:
            img = img.to(device)
            _, scores = self.model(img)
            output = torch.argmax(scores, dim=1)
            output = output.cpu().detach().numpy()
            result += np.sum(output == label.numpy())

        acc = result / (len(self.test_loader) * batch_size)
        print("accuracy: {:.4f}".format(acc))
        self.model.train()
    def plot_acc_vs_hidden_nodes(self, hidden_nodes, acc):
        plt.plot(hidden_nodes, acc, 'g', label='Test accuracy')
        plt.title('Test accuracy on MNIST dataset')
        plt.xlabel('Number of hidden nodes in the autoencoder')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()
    def draw_confusion_matrix(self):
        self.model.eval()
        true_labels = []
        pred_labels = []
        for img, label in self.test_loader:
            img = img.to(device)
            _, scores = self.model(img)
            output = torch.argmax(scores, dim=1).cpu().detach().tolist()
            pred_labels += output
            true_labels += label.tolist()

        # make confusion matrix
        c_matrix = confusion_matrix(y_true=true_labels, y_pred=pred_labels)
        plt.figure(figsize=(10, 12))
        sns.heatmap(c_matrix, annot=True, fmt='g', linewidths=.5)

        # labels, title
        plt.xlabel('Predicted Label', fontsize=10, labelpad=11)
        plt.ylabel('True Label', fontsize=10)

        b, t = plt.ylim()  # discover the values for bottom and top
        b += 0.5           # add 0.5 to the bottom
        t -= 0.5           # subtract 0.5 from the top
        plt.ylim(b, t)     # update the ylim(bottom, top) values
        plt.show()
    def display_weights_as_image(self):
        print("saving weight images")
        img_dir = os.path.join("w_images", str(hidden_nodes))
        os.makedirs(img_dir, exist_ok=True)

        for index in range(hidden_nodes):
            img = self.model.linear.weight[index]
            img = img.view(28, 28).cpu().detach().numpy()
            plt.imshow(img, cmap="gray")
            plt.savefig(os.path.join(img_dir, str(index) + ".png"))
            #plt.show()
    def do_reconstruction(self):
        print("reconstructing images ...")
        data_iter = iter(self.train_loader)
        img_dir = os.path.join(saved_image_dir, str(hidden_nodes))
        os.makedirs(img_dir, exist_ok=True)

        # keep drawing batches until one contains all ten digits
        while True:
            img, label = next(data_iter)
            re_img, _ = self.model(img.to(device))
            label = label.detach().tolist()
            if set(label) == set(float(i) for i in range(0, 10)):
                digit_indexes = [label.index(float(i)) for i in range(0, 10)]
                break

        for index in digit_indexes:
            img = re_img[index]
            img = img.cpu().detach().numpy()
            img = img.reshape(28, 28)
            plt.imshow(img, cmap="gray")
            plt.savefig(os.path.join(img_dir, str(index) + ".png"))
            #plt.show()
if __name__ == "__main__":
    t = Train()
    #t.train()
    t.display_weights_as_image()

    #hidden_nodes = [5, 10, 20, 30, 40, 50, 60]
    #acc = [0.9010, 0.9463, 0.9685, 0.9739, 0.9763, 0.9782, 0.9807]
    #t.plot_acc_vs_hidden_nodes(hidden_nodes, acc)
    #t.draw_confusion_matrix()