Created
July 17, 2018 05:18
-
-
Save philoniare/85d5c8319d441ae345f693a63ab4de16 to your computer and use it in GitHub Desktop.
helper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
from torch import nn, optim | |
from torch.autograd import Variable | |
def test_network(net, trainloader): | |
criterion = nn.MSELoss() | |
optimizer = optim.Adam(net.parameters(), lr=0.001) | |
dataiter = iter(trainloader) | |
images, labels = dataiter.next() | |
# Create Variables for the inputs and targets | |
inputs = Variable(images) | |
targets = Variable(images) | |
# Clear the gradients from all Variables | |
optimizer.zero_grad() | |
# Forward pass, then backward pass, then update weights | |
output = net.forward(inputs) | |
loss = criterion(output, targets) | |
loss.backward() | |
optimizer.step() | |
return True | |
def imshow(image, ax=None, title=None, normalize=True,
           mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a channels-first ``(C, H, W)`` image tensor on a matplotlib axis.

    Args:
        image: Image tensor in ``(C, H, W)`` layout. It may require grad or
            live on a GPU; it is detached and copied to the CPU before
            plotting.
        ax: Axis to draw on. A new figure and axis are created when ``None``.
        title: Optional title for the axis; nothing is set when ``None``.
        normalize: If true, undo per-channel normalisation
            (``image * std + mean``) and clip the result to ``[0, 1]``.
        mean, std: Per-channel statistics used when ``normalize`` is true.
            The defaults are the commonly used ImageNet statistics.

    Returns:
        The axis the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # matplotlib expects channels-last (H, W, C).
    image = image.detach().cpu().numpy().transpose((1, 2, 0))

    if normalize:
        image = np.asarray(std) * image + np.asarray(mean)
        image = np.clip(image, 0, 1)

    # imshow rejects an (H, W, 1) array; show a single channel as 2-D data.
    if image.shape[-1] == 1:
        image = image[..., 0]

    ax.imshow(image)
    if title is not None:
        ax.set_title(title)

    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    # Hide tick marks and tick labels without swapping in a FixedFormatter.
    ax.tick_params(axis='both', length=0, labelbottom=False, labelleft=False)
    return ax
def view_recon(img, recon):
    """Show an image and its reconstruction side by side.

    Both arguments are PyTorch tensors. Singleton dimensions are squeezed
    away before plotting, and tensors that require grad or live on a GPU
    are detached and copied to the CPU first.

    Args:
        img: The original image tensor.
        recon: The reconstructed image tensor.
    """
    _, axes = plt.subplots(ncols=2, sharex=True, sharey=True)
    axes[0].imshow(img.detach().cpu().numpy().squeeze())
    axes[1].imshow(recon.detach().cpu().numpy().squeeze())
    for ax in axes:
        ax.axis('off')
        # 'box-forced' was deprecated in matplotlib 2.2 and later removed
        # (it now raises); 'box' is the documented replacement.
        ax.set_adjustable('box')
def view_classify(img, ps, version="MNIST"):
    """Show an image next to a horizontal bar chart of its class probabilities.

    Args:
        img: Image tensor with 28 * 28 elements (e.g. ``(1, 28, 28)`` or
            ``(784,)``). It is reshaped for display without modifying the
            caller's tensor.
        ps: Class-probability tensor with one entry per class after
            squeezing (e.g. shape ``(1, 10)``). It may require grad or live
            on a GPU.
        version: ``"MNIST"`` labels the bars with their class indices, and
            ``"Fashion"`` labels them with the Fashion-MNIST class names.
            Any other value leaves the default tick labels.
    """
    ps = ps.detach().cpu().numpy().squeeze()
    positions = np.arange(ps.size)

    _, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    # reshape (not resize_) so the caller's tensor is not mutated in place.
    ax1.imshow(img.detach().cpu().reshape(1, 28, 28).numpy().squeeze())
    ax1.axis('off')

    ax2.barh(positions, ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(positions)
    if version == "MNIST":
        ax2.set_yticklabels(positions)
    elif version == "Fashion":
        ax2.set_yticklabels(['T-shirt/top',
                             'Trouser',
                             'Pullover',
                             'Dress',
                             'Coat',
                             'Sandal',
                             'Shirt',
                             'Sneaker',
                             'Bag',
                             'Ankle Boot'], size='small')
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment