@RITIK-12
Created May 8, 2021 18:18
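import numpy as np
import matplotlib.pyplot as plt
from torchvision import utils

# `dataloaders` and `class_names` are assumed to be defined earlier
# (e.g. a dict of DataLoaders keyed by 'train'/'val' and the list of class labels).

def imshow(inp, title=None):
    """Imshow for Tensor."""
    # CHW tensor -> HWC NumPy array, the layout matplotlib expects
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the per-channel ImageNet normalization assumed to be applied by the data transforms
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    # Clip to the valid [0, 1] float range for display
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])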
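The de-normalization inside `imshow` is the inverse of the per-channel `(x - mean) / std` step that `torchvision.transforms.Normalize` performs. The gist does not show its transforms, so this sketch assumes the images were normalized with these same ImageNet statistics. It uses plain NumPy (no PyTorch) and hypothetical helper names `normalize_chw` and `denormalize_to_hwc`, and checks that normalizing a random CHW array and then applying `imshow`'s transpose-and-rescale recovers the original pixels in HWC order.

```python
import numpy as np

# ImageNet channel statistics, the same values hard-coded in imshow.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize_chw(img_chw):
    """Hypothetical helper: per-channel (x - mean) / std on a CHW array."""
    # Reshape stats to (3, 1, 1) so they broadcast over height and width.
    return (img_chw - MEAN[:, None, None]) / STD[:, None, None]


def denormalize_to_hwc(norm_chw):
    """Hypothetical helper: undo normalization the way imshow does."""
    hwc = norm_chw.transpose((1, 2, 0))  # channels last, as plt.imshow expects
    # In HWC layout the (3,) stats broadcast over the last axis directly.
    return np.clip(STD * hwc + MEAN, 0, 1)


# Round trip on a random 3x4x5 "image" with values in [0, 1).
rng = np.random.default_rng(0)
original = rng.random((3, 4, 5))
restored = denormalize_to_hwc(normalize_chw(original))
print(restored.shape)  # (4, 5, 3)
print(np.allclose(restored, original.transpose((1, 2, 0))))  # True
```

Note the broadcasting detail: in CHW layout the three statistics must be reshaped to `(3, 1, 1)`, whereas after the transpose to HWC a plain `(3,)` array lines up with the last axis, which is why `imshow` can write `std * inp + mean` directly.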