
@mirceast
Created May 22, 2019 14:16
Transfer Learning 2
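import numpy as np
import matplotlib.pyplot as plt
import torchvision

# Helper function for displaying images
def imshow(inp, title=None):
    """Imshow for Tensor: un-normalize a (C, H, W) image tensor and display it."""
    # Convert from PyTorch's (C, H, W) layout to matplotlib's (H, W, C)
    inp = inp.numpy().transpose((1, 2, 0))
    # ImageNet per-channel mean and standard deviation used for normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Un-normalize the images
    inp = std * inp + mean
    # Clip just in case
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
# (dataloaders and class_names are assumed to be defined earlier)
inputs, classes = next(iter(dataloaders['train']))
# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])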
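The un-normalization step in `imshow` inverts the per-channel standardization z = (x - mean) / std that was presumably applied when the images were loaded (for example with `transforms.Normalize` and the ImageNet statistics above), giving x = std * z + mean. A minimal NumPy sketch of that round trip, using a random (C, H, W) array as a stand-in for an image tensor; the helper names `normalize_chw` and `unnormalize_hwc` are hypothetical:

```python
import numpy as np

# ImageNet per-channel statistics, as used in imshow
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize_chw(img):
    """Standardize a (C, H, W) image in [0, 1]: z = (x - mean) / std per channel."""
    # Reshape the stats to (C, 1, 1) so they broadcast over height and width
    return (img - MEAN[:, None, None]) / STD[:, None, None]

def unnormalize_hwc(z_chw):
    """Undo normalize_chw and return an (H, W, C) array, as imshow does."""
    # Move channels last; the (C,) stats then broadcast along the last axis
    z_hwc = z_chw.transpose((1, 2, 0))
    x = STD * z_hwc + MEAN
    # Clip to the valid display range in case of floating-point error
    return np.clip(x, 0, 1)

rng = np.random.default_rng(0)
img = rng.random((3, 4, 5))  # hypothetical 3-channel 4x5 "image" in [0, 1)
restored = unnormalize_hwc(normalize_chw(img))
print(restored.shape)  # (4, 5, 3)
print(np.allclose(restored, img.transpose((1, 2, 0))))  # True
```

The order matters: because the transpose happens first, the length-3 `mean` and `std` arrays broadcast along the last (channel) axis without any reshaping.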
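`torchvision.utils.make_grid` tiles the batch into a single (C, H, W) image that `imshow` can display in one call. The sketch below is a simplified NumPy re-implementation written for illustration, not torchvision's own code; it assumes equally sized images and mirrors the defaults `nrow=8`, `padding=2`, and a pad value of 0. The function name `grid_numpy` and the batch shape are hypothetical:

```python
import math
import numpy as np

def grid_numpy(batch, nrow=8, padding=2, pad_value=0.0):
    """Tile a (B, C, H, W) batch into one (C, H', W') grid image.

    Simplified sketch of make_grid's layout: at most `nrow` images per row,
    with `padding` pixels before, between, and after the images.
    """
    b, c, h, w = batch.shape
    cols = min(nrow, b)
    rows = math.ceil(b / cols)
    cell_h, cell_w = h + padding, w + padding
    grid = np.full((c, rows * cell_h + padding, cols * cell_w + padding),
                   pad_value, dtype=batch.dtype)
    for k in range(b):
        # Fill the grid row by row, left to right
        r, col = divmod(k, cols)
        top = r * cell_h + padding
        left = col * cell_w + padding
        grid[:, top:top + h, left:left + w] = batch[k]
    return grid

# A hypothetical batch of 4 RGB images of size 224x224
batch = np.ones((4, 3, 224, 224), dtype=np.float32)
grid = grid_numpy(batch)
print(grid.shape)  # (3, 228, 906)
```

With fewer images than `nrow`, everything lands on one row, so the titles passed to `imshow` read left to right in the same order as `classes`.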