@mayankgrwl97
Created October 2, 2020 07:15
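```python
# Show Tensor Images utility function
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots the first num_images images in a uniform grid
    with 5 images per row.
    '''
    # Detach from the autograd graph, move to CPU, and reshape to (N, C, H, W)
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    # Tile the first num_images images into a grid with 5 images per row
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    # Float pixel values should be in [0, 1]; imshow clips RGB data outside that range.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()
```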