3-layer autoencoder for MNIST digit reconstruction and classification with visualization (PyTorch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Course: Applications of NN (CpE - 520) | |
Homework Assignment 9 | |
""" | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import torch | |
from torch import nn | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
import torchvision | |
import tqdm | |
from sklearn.metrics import confusion_matrix | |
import os | |
# ---------------------------------------------------------------------------
# Experiment configuration: module-level settings read by the classes below.
# ---------------------------------------------------------------------------
task = "classify"  # either "classify" (digit classification) or "re_con" (reconstruction)
img_size = 28  # side length of the square input image, in pixels
hidden_nodes = 60  # width of the hidden layer
data_dir = "./data"  # root directory handed to torchvision's MNIST dataset
saved_image_dir = "./images"  # reconstructed digit images are written below this directory
batch_size = 500
lr = 3e-3  # learning rate given to the Adam optimizer
num_epochs = 200
is_save = True  # write model/optimizer state to `model_file`
is_load = True  # restore model/optimizer state from `model_file` when the trainer is built
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint file name encodes the task and the hidden-layer width.
if task == "classify":
    model_file = "h%d_cl_model.pth" % hidden_nodes
elif task == "re_con":
    model_file = "re_%d_cl_model.pth" % hidden_nodes
else:
    # Fail immediately instead of hitting a NameError on `model_file` later.
    raise ValueError("unknown task %r; expected 'classify' or 're_con'" % (task,))
class AutoEncoder(nn.Module):
    """One hidden layer shared by a reconstruction head and a 10-way classification head.

    Args:
        img_size: Side length of the square input image; the flattened input
            width is ``img_size * img_size``.
        hidden_nodes: Number of units in the shared hidden layer.
    """

    def __init__(self, img_size, hidden_nodes):
        super().__init__()
        self.input_node = img_size * img_size
        self.hidden_node = hidden_nodes
        # Attribute names double as state_dict keys, so they are kept stable
        # for compatibility with previously saved checkpoints.
        self.linear = nn.Linear(self.input_node, self.hidden_node)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.autoencoder = nn.Linear(self.hidden_node, self.input_node)
        self.classifier = nn.Linear(self.hidden_node, 10)

    def forward(self, x):
        """Return ``(reconstruction, class_scores)`` for a batch of images.

        Args:
            x: Batch whose per-sample element count equals ``img_size ** 2``.

        Returns:
            A tuple of the sigmoid-activated reconstruction with shape
            ``(batch, img_size * img_size)`` and the raw classification scores
            with shape ``(batch, 10)``.
        """
        # Flatten with the configured width rather than a hard-coded 784 so
        # that the `img_size` constructor argument is actually honoured.
        flat = x.view(x.size(0), self.input_node)
        hidden = self.relu(self.linear(flat))
        image = self.sigmoid(self.autoencoder(hidden))
        output = self.classifier(hidden)
        return image, output
class MNISTDataset(Dataset):
    """MNIST split held in memory, with a transform pipeline chosen per split.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.MNIST``
            (the data is downloaded there if missing).
        train: Truthy for the training split (random rotation, then tensor
            conversion and normalization), falsy for the test split (tensor
            conversion and normalization only).
    """

    def __init__(self, data_dir, train=True):
        super().__init__()
        self.images, self.labels = self.get_data(data_dir, train)
        mean = (0.1307, )
        std = (0.3081, )
        self.train_trans = transforms.Compose([
            transforms.RandomRotation((0, 10), fill=(0, )),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.test_trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.trans = self.train_trans if train else self.test_trans

    def get_data(self, data_dir, train):
        """Load the requested split and return ``(images, labels)`` as two lists."""
        # The two original branches differed only in the `train` flag, so the
        # flag is forwarded directly.
        dataset = torchvision.datasets.MNIST(root=data_dir, train=bool(train), download=True)
        images = []
        labels = []
        for img, label in dataset:
            images.append(img)
            labels.append(label)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return the transformed image and its label as a 0-d float64 array."""
        img = self.trans(self.images[index])
        # `np.float` was removed in NumPy 1.24; it was an alias of the builtin
        # float, so float64 keeps the label dtype the rest of the file sees.
        label = np.array(self.labels[index], dtype=np.float64)
        return img, label
class Train(nn.Module):
    """Bundle the data loaders, model, optimizer and the train/evaluate/plot routines.

    Every setting (task, sizes, paths, device, ...) is read from the
    module-level configuration; the constructor takes no arguments.

    NOTE(review): this class subclasses ``nn.Module`` and its ``train()``
    takes no ``mode`` argument, so it shadows ``nn.Module.train(mode)``.
    Call ``eval()`` / ``train()`` for mode switching on ``self.model`` only,
    never on the ``Train`` instance itself.
    """

    def __init__(self):
        super().__init__()
        self.train_loader = DataLoader(MNISTDataset(data_dir, train=True),
                                       batch_size=batch_size,
                                       shuffle=True, num_workers=4)
        self.test_loader = DataLoader(MNISTDataset(data_dir, train=False),
                                      batch_size=batch_size,
                                      shuffle=True, num_workers=4)
        self.model = AutoEncoder(img_size=img_size,
                                 hidden_nodes=hidden_nodes).to(device)
        self.classifier_loss = nn.CrossEntropyLoss()
        self.MSE = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        self.schedular = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.92)
        self.iter = 0
        if is_load:
            if os.path.isfile(model_file):
                # map_location lets a checkpoint saved on GPU load on a CPU-only host.
                checkpoint = torch.load(model_file, map_location=device)
                self.model.load_state_dict(checkpoint["model"])
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            else:
                # First run: there is nothing to restore yet, so start fresh
                # instead of crashing inside torch.load.
                print("checkpoint {} not found, training from scratch".format(model_file))

    def train(self):
        """Run ``num_epochs`` epochs on the configured task, then save and/or reconstruct.

        With ``task == "classify"`` the cross-entropy of the class scores is
        minimised; with ``task == "re_con"`` the MSE between the reconstruction
        and the input batch is minimised.
        """
        for e in range(num_epochs):
            total_loss = 0
            # Wrap the loader inside the epoch loop so each epoch gets its own
            # progress bar.
            for img, label in tqdm.tqdm(self.train_loader):
                img = img.to(device)
                # `.to` converts the existing tensor; torch.tensor(tensor)
                # would copy-construct it and emit a warning.
                label = label.to(device=device, dtype=torch.long)
                reconstructed_img, scores = self.model(img)
                if task == "re_con":
                    # Give the flat reconstruction the same shape as the input batch.
                    loss = self.MSE(reconstructed_img.view_as(img), img)
                elif task == "classify":
                    loss = self.classifier_loss(scores, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            print("Epoch {} | total loss {:.4f}".format(e, total_loss / len(self.train_loader)))
            # Decay the learning rate once every 10 epochs (skipping epoch 0).
            if e % 10 == 0 and e != 0:
                self.schedular.step()
                print(self.schedular.get_last_lr())
            # change according to your experiment
            if task == "classify" and e > 100 and e % 10 == 0:
                self.test()
        # NOTE(review): the source's indentation was lost; saving and
        # reconstruction are assumed to happen once, after the last epoch.
        if is_save:
            checkpoint = {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            torch.save(checkpoint, model_file)
            print("saving model")
        if task == "re_con":
            self.do_reconstruction()

    def test(self):
        """Print and return the classification accuracy over ``self.test_loader``."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for img, label in self.test_loader:
                _, scores = self.model(img.to(device))
                predicted = torch.argmax(scores, dim=1).cpu()
                correct += (predicted == label.to(torch.long)).sum().item()
                # Count real samples so a short final batch is not over-counted.
                total += label.size(0)
        acc = correct / total if total else 0.0
        print("accuracy: {:.4f}".format(acc))
        self.model.train()
        return acc

    def plot_acc_vs_hidden_nodes(self, hidden_nodes, acc):
        """Plot accuracy values ``acc`` against hidden-layer widths ``hidden_nodes``."""
        plt.plot(hidden_nodes, acc, 'g', label='Test accuracy')
        plt.title('Test accuracy on MNIST dataset')
        plt.xlabel('Number of hidden nodes in the autoencoder')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

    def draw_confusion_matrix(self):
        """Predict every test sample and show the confusion matrix as a heatmap."""
        self.model.eval()
        true_labels = []
        pred_labels = []
        with torch.no_grad():
            for img, label in self.test_loader:
                _, scores = self.model(img.to(device))
                pred_labels += torch.argmax(scores, dim=1).cpu().tolist()
                # Plain ints on both sides; the dataset hands labels out as floats.
                true_labels += label.to(torch.long).tolist()
        self.model.train()
        # make confusion matrix
        c_matrix = confusion_matrix(y_true=true_labels, y_pred=pred_labels)
        plt.figure(figsize=(10, 10))
        sns.set(font_scale=1.4)
        sns.heatmap(c_matrix, annot=True, fmt='g', linewidths=.5)
        # labels, title
        plt.xlabel('Predicted Label', fontsize=10, labelpad=11)
        plt.ylabel('True Label', fontsize=10)
        # NOTE(review): widening the y-limits by half a cell looks like a
        # workaround for clipped first/last heatmap rows - confirm it is still needed.
        b, t = plt.ylim()  # discover the values for bottom and top
        b += 0.5  # Add 0.5 to the bottom
        t -= 0.5  # Subtract 0.5 from the top
        plt.ylim(b, t)  # update the ylim(bottom, top) values
        plt.show()

    def display_weights_as_image(self):
        """Save each hidden unit's input weights as ``w_images/<hidden_nodes>/<unit>.png``."""
        print("saving weight images")
        img_dir = os.path.join("w_images", str(hidden_nodes))
        os.makedirs(img_dir, exist_ok=True)
        weights = self.model.linear.weight.detach().cpu().numpy()
        for index, unit_weights in enumerate(weights):
            # Clear the current figure so images do not pile up on one axes.
            plt.clf()
            plt.imshow(unit_weights.reshape(img_size, img_size), cmap="gray")
            plt.savefig(os.path.join(img_dir, str(index) + ".png"))

    def do_reconstruction(self):
        """Save one reconstructed image per digit under ``saved_image_dir/<hidden_nodes>/``.

        Scans the training loader for the first batch that contains all ten
        digits, reconstructs that batch, and writes the reconstruction of the
        first occurrence of each digit as ``<digit>.png``.
        """
        print("reconstructing images ...")
        img_dir = os.path.join(saved_image_dir, str(hidden_nodes))
        os.makedirs(img_dir, exist_ok=True)
        all_digits = set(range(10))
        digit_indexs = None
        re_img = None
        with torch.no_grad():
            for img, label in self.train_loader:
                labels = label.to(torch.long).tolist()
                if set(labels) == all_digits:
                    re_img, _ = self.model(img.to(device))
                    digit_indexs = {digit: labels.index(digit) for digit in range(10)}
                    break
        if digit_indexs is None:
            # Previously this surfaced as a bare StopIteration from the iterator.
            print("no batch contained all ten digits, nothing reconstructed")
            return
        for digit, index in digit_indexs.items():
            picture = re_img[index].cpu().numpy().reshape(img_size, img_size)
            plt.clf()
            plt.imshow(picture, cmap="gray")
            # Name the file after the digit it shows; the position inside a
            # shuffled batch is arbitrary and differs between runs.
            plt.savefig(os.path.join(img_dir, str(digit) + ".png"))
if __name__ == "__main__":
    # Script entry point: build the trainer from the module-level settings
    # and run the full training loop.
    t = Train()
    t.train()
    # Optional plot of accuracy against hidden-layer width (presumably values
    # measured in earlier runs); uncomment to display it.
    #hidden_nodes = [5, 10, 20, 30, 40, 50, 60]
    #acc = [0.9010, 0.9463, 0.9685, 0.9739, 0.9763, 0.9782, 0.9807]
    #t.plot_acc_vs_hidden_nodes(hidden_nodes, acc)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment