@kenorb
Forked from mayankgrwl97/show_tensor_images.py
Created December 4, 2020 18:07
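
# Show Tensor Images utility function
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: given a tensor of images, the number of
    images to show, and the size per image, plots the images in a uniform grid.
    '''
    # Detach from the autograd graph, move to CPU, and reshape flat vectors
    # into a batch of (C, H, W) images.
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    # Tile the first num_images images into a grid with 5 images per row.
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # imshow expects channel-last (H, W, C), so move the channel axis last.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()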