Skip to content

Instantly share code, notes, and snippets.

@Mahedi-61
Last active November 18, 2021 20:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save Mahedi-61/e70f08e1f36aa9a4fa575d2a5a3f6c25 to your computer and use it in GitHub Desktop.
Save Mahedi-61/e70f08e1f36aa9a4fa575d2a5a3f6c25 to your computer and use it in GitHub Desktop.
3-layer autoencoder for MNIST digit reconstruction and classification with visualization (PyTorch)
"""
Course: Applications of NN (CpE - 520)
Homework Assignment 9
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
import tqdm
from sklearn.metrics import confusion_matrix
import os
# Which experiment to run: "classify" trains through the classifier head,
# "re_con" trains through the decoder (reconstruction) head.
task = "classify"  # or "re_con"
img_size = 28                 # input images are img_size x img_size pixels
hidden_nodes = 60             # width of the autoencoder's hidden layer
data_dir = "./data"           # torchvision MNIST download/cache root
saved_image_dir = "./images"  # output root for reconstructed images
batch_size = 500
lr = 3e-3
num_epochs = 200
is_save = True  # save model/optimizer state to model_file during training
is_load = True  # restore model/optimizer state from model_file at startup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint path: one file per task and hidden-layer width.
if task == "classify":
    model_file = "h%d_cl_model.pth" % hidden_nodes
elif task == "re_con":
    model_file = "re_%d_cl_model.pth" % hidden_nodes
else:
    # Fail fast here instead of with a NameError on model_file much later.
    raise ValueError("task must be 'classify' or 're_con', got %r" % (task,))
class AutoEncoder(nn.Module):
    """Single-hidden-layer autoencoder with an auxiliary classification head.

    The input image is flattened and projected to ``hidden_nodes`` units
    followed by a ReLU. Two linear heads read that hidden code:

    * ``autoencoder`` decodes it back to ``img_size * img_size`` pixels,
      squashed into [0, 1] by a sigmoid;
    * ``classifier`` produces ``num_classes`` unnormalized logits.

    Args:
        img_size: side length of the square input image.
        hidden_nodes: width of the hidden code.
        num_classes: number of classifier logits (default 10, one per digit).
    """

    def __init__(self, img_size, hidden_nodes, num_classes=10):
        super().__init__()
        self.input_node = img_size * img_size
        self.hidden_node = hidden_nodes
        # Attribute names are kept as-is: they are the state_dict keys stored
        # in saved checkpoints.
        self.linear = nn.Linear(self.input_node, self.hidden_node)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.autoencoder = nn.Linear(self.hidden_node, self.input_node)
        self.classifier = nn.Linear(self.hidden_node, num_classes)

    def forward(self, x):
        """Encode ``x`` and return ``(reconstruction, class_logits)``.

        ``x`` must hold ``img_size * img_size`` values per sample after the
        batch dimension (e.g. shape ``(B, 1, img_size, img_size)``). The
        reconstruction is flat, shape ``(B, img_size * img_size)``; the logits
        have shape ``(B, num_classes)``.
        """
        # Flatten using the configured size rather than a hard-coded 784;
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), self.input_node)
        x = self.relu(self.linear(x))
        image = self.autoencoder(x)
        output = self.classifier(x)
        return self.sigmoid(image), output
class MNISTDataset(Dataset):
    """One MNIST split held in memory, yielding ``(image_tensor, label)``.

    Images are converted to tensors and normalized with a fixed mean/std; the
    training split additionally gets a small random rotation. Labels are
    returned as 0-d ``float64`` NumPy arrays.

    Args:
        data_dir: root directory handed to ``torchvision.datasets.MNIST``
            (the data is downloaded there if missing).
        train: truthy for the training split (with augmentation), falsy for
            the test split.
    """

    def __init__(self, data_dir, train=True):
        super().__init__()
        self.images, self.labels = self.get_data(data_dir, train)
        # Per-channel normalization constants applied to every image.
        mean = (0.1307, )
        std = (0.3081, )
        self.train_trans = transforms.Compose([
            # Light augmentation: rotate by a random angle in [0, 10] degrees,
            # filling uncovered corners with 0.
            transforms.RandomRotation((0, 10), fill=(0, )),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.test_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.trans = self.train_trans if train else self.test_trans

    def get_data(self, data_dir, train):
        """Load every untransformed ``(image, label)`` pair of one split.

        Returns:
            ``(images, labels)`` as two parallel lists.
        """
        # bool() keeps split selection consistent with the truthiness test
        # used to pick the transform in __init__.
        dataset = torchvision.datasets.MNIST(root=data_dir, train=bool(train), download=True)
        images = []
        labels = []
        for img, label in dataset:
            images.append(img)
            labels.append(label)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return the transformed image and its label as a float64 array."""
        img = self.trans(self.images[index])
        # np.float (an alias of the builtin float, i.e. float64) was removed
        # in NumPy 1.24; np.float64 keeps the same dtype.
        label = np.array(self.labels[index], dtype=np.float64)
        return img, label
class Train(nn.Module):
    """Train and evaluate an ``AutoEncoder`` on MNIST, plus plotting helpers.

    Behaviour is driven by the module-level configuration (``task``,
    ``batch_size``, ``lr``, ``num_epochs``, ``is_save``, ``is_load``,
    ``model_file``, ``img_size``, ``hidden_nodes``, ``device``, ...).

    NOTE(review): this class subclasses ``nn.Module`` but defines a
    zero-argument ``train()``, which shadows ``nn.Module.train(mode)``, so
    ``Train.eval()`` would fail. Switch the mode of ``self.model`` (the
    wrapped network), not of this object.
    """

    def __init__(self):
        super().__init__()
        self.train_loader = DataLoader(MNISTDataset(data_dir, train=True),
                                       batch_size=batch_size,
                                       shuffle=True, num_workers=4)
        # Evaluation results do not depend on sample order, so keep it fixed.
        self.test_loader = DataLoader(MNISTDataset(data_dir, train=False),
                                      batch_size=batch_size,
                                      shuffle=False, num_workers=4)
        self.model = AutoEncoder(img_size=img_size,
                                 hidden_nodes=hidden_nodes).to(device)
        self.classifier_loss = nn.CrossEntropyLoss()
        self.MSE = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        self.schedular = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.92)
        self.iter = 0  # not read anywhere in this class; kept for compatibility
        if is_load:
            if os.path.exists(model_file):
                # map_location lets a GPU-saved checkpoint load on a CPU host.
                checkpoint = torch.load(model_file, map_location=device)
                self.model.load_state_dict(checkpoint["model"])
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            else:
                # First run: there is nothing to resume from yet.
                print("no checkpoint at {}, training from scratch".format(model_file))

    def _model_device(self):
        """Return the device holding the wrapped model's parameters."""
        return next(self.model.parameters()).device

    def _save_checkpoint(self):
        """Write the model and optimizer state dicts to ``model_file``."""
        checkpoint = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(checkpoint, model_file)
        print("saving model")

    def train(self):
        """Run ``num_epochs`` epochs of training for the configured ``task``.

        Per epoch: one pass over the training set, a mean-loss printout, an
        LR decay every 10th epoch, a test-set evaluation for ``classify`` on
        epochs > 100 that are multiples of 10, and a checkpoint write when
        ``is_save``. For ``re_con``, sample reconstructions are written once
        training finishes.

        Raises:
            ValueError: if ``task`` is neither "classify" nor "re_con".
        """
        dev = self._model_device()
        for e in range(num_epochs):
            total_loss = 0
            # A fresh progress bar per epoch: a tqdm wrapper created once
            # stops displaying after its first complete pass.
            for img, label in tqdm.tqdm(self.train_loader, desc="epoch {}".format(e)):
                img = img.to(dev)
                # Labels arrive as float64 tensors; CrossEntropyLoss needs int64.
                label = label.long().to(dev)
                reconstructed_img, scores = self.model(img)
                if task == "re_con":
                    # NOTE(review): the decoder output is sigmoid-bounded to
                    # [0, 1], but `img` is mean/std-normalized by the dataset,
                    # so many target pixels fall outside that range. Confirm
                    # whether the target should be de-normalized.
                    loss = self.MSE(reconstructed_img.view_as(img), img)
                elif task == "classify":
                    loss = self.classifier_loss(scores, label)
                else:
                    raise ValueError("unknown task: {!r}".format(task))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            print("Epoch {} | total loss {:.4f}".format(e, total_loss / len(self.train_loader)))

            if e % 10 == 0 and e != 0:
                self.schedular.step()
                # get_last_lr() is the public way to read the current LR;
                # get_lr() is meant to be called only from inside step().
                print(self.schedular.get_last_lr())

            # change according to your experiment
            if task == "classify" and e > 100 and e % 10 == 0:
                self.test()

            if is_save:
                self._save_checkpoint()

        if task == "re_con":
            self.do_reconstruction()

    def test(self):
        """Print and return classification accuracy on ``self.test_loader``.

        The model is switched to eval mode for the pass and back to train
        mode afterwards, even if evaluation raises.

        Returns:
            Fraction of correctly classified samples (0.0 for an empty loader).
        """
        dev = self._model_device()
        correct = 0
        total = 0
        self.model.eval()
        try:
            with torch.no_grad():
                for img, label in self.test_loader:
                    _, scores = self.model(img.to(dev))
                    preds = torch.argmax(scores, dim=1).cpu()
                    correct += (preds == label.long().cpu()).sum().item()
                    # Count actual samples; len(loader) * batch_size would
                    # over-count a partial final batch.
                    total += label.size(0)
        finally:
            self.model.train()
        acc = correct / total if total else 0.0
        print("accuracy: {:.4f}".format(acc))
        return acc

    def plot_acc_vs_hidden_nodes(self, hidden_nodes, acc):
        """Plot test accuracy against the number of hidden nodes and show it."""
        plt.plot(hidden_nodes, acc, 'g', label='Test accuracy')
        plt.title('Test accuracy on MNIST dataset')
        plt.xlabel('Number of hidden nodes in the autoencoder')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

    def draw_confusion_matrix(self):
        """Show a seaborn heatmap of the test-set confusion matrix."""
        dev = self._model_device()
        true_labels = []
        pred_labels = []
        self.model.eval()
        try:
            with torch.no_grad():
                for img, label in self.test_loader:
                    _, scores = self.model(img.to(dev))
                    pred_labels += torch.argmax(scores, dim=1).cpu().tolist()
                    # Plain ints, so y_true and y_pred share one label type.
                    true_labels += label.long().tolist()
        finally:
            self.model.train()

        # make confusion matrix
        c_matrix = confusion_matrix(y_true=true_labels, y_pred=pred_labels)
        plt.figure(figsize=(10, 10))
        sns.set(font_scale=1.4)
        sns.heatmap(c_matrix, annot=True, fmt='g', linewidths=.5)
        # labels, title
        plt.xlabel('Predicted Label', fontsize=10, labelpad=11)
        plt.ylabel('True Label', fontsize=10)
        # Widen the y-limits by half a cell at each end. NOTE(review):
        # presumably a workaround for matplotlib versions that cropped the
        # first/last heatmap rows; on other versions this may only add margin.
        b, t = plt.ylim()
        b += 0.5
        t -= 0.5
        plt.ylim(b, t)
        plt.show()

    def display_weights_as_image(self):
        """Save each hidden unit's input weights as a grayscale image.

        One file per hidden unit: ``w_images/<hidden_nodes>/<unit>.png``.
        """
        print("saving weight images")
        img_dir = os.path.join("w_images", str(hidden_nodes))
        os.makedirs(img_dir, exist_ok=True)
        weights = self.model.linear.weight.detach().cpu()
        for index, row in enumerate(weights):
            plt.figure()
            plt.imshow(row.view(img_size, img_size).numpy(), cmap="gray")
            plt.savefig(os.path.join(img_dir, str(index) + ".png"))
            # Close each figure so images do not pile up on shared axes.
            plt.close()

    def do_reconstruction(self):
        """Save one reconstructed training image for each digit 0-9.

        Scans training batches until one contains every digit, then writes
        the decoder output for the first sample of each digit to
        ``<saved_image_dir>/<hidden_nodes>/<digit>.png``.

        Raises:
            RuntimeError: if no training batch contains all ten digits.
        """
        print("reconstructing images ...")
        img_dir = os.path.join(saved_image_dir, str(hidden_nodes))
        os.makedirs(img_dir, exist_ok=True)
        dev = self._model_device()
        digits = list(range(10))
        with torch.no_grad():
            for img, label in self.train_loader:
                labels = label.long().tolist()
                if set(digits).issubset(labels):
                    re_img, _ = self.model(img.to(dev))
                    break
            else:
                raise RuntimeError("no training batch contains all ten digits")

        for digit in digits:
            index = labels.index(digit)
            picture = re_img[index].cpu().numpy().reshape(img_size, img_size)
            plt.figure()
            plt.imshow(picture, cmap="gray")
            # Name the file by digit so each output is identifiable.
            plt.savefig(os.path.join(img_dir, str(digit) + ".png"))
            plt.close()
if __name__ == "__main__":
    # Build loaders/model/optimizer from the module configuration, then train.
    trainer = Train()
    trainer.train()

    # Recorded test accuracy per hidden-layer width; to plot it, run:
    #     hidden_nodes = [5, 10, 20, 30, 40, 50, 60]
    #     acc = [0.9010, 0.9463, 0.9685, 0.9739, 0.9763, 0.9782, 0.9807]
    #     trainer.plot_acc_vs_hidden_nodes(hidden_nodes, acc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment