@pratos
Last active December 28, 2017 10:36
Matplotlib Helper Functions

Creating multiple subplots to display numpy images:

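import math
import matplotlib.pyplot as plt

def create_subplots(images, labels, classes, ncols=2):
    """Show 2-D numpy images in a grid, each titled with its class name."""
    nrows = math.ceil(len(images) / ncols)   # enough rows for every image
    plt.figure(figsize=(10, 20))             # set the size before drawing
    for i, (npimg, label) in enumerate(zip(images, labels)):
        plt.subplot(nrows, ncols, i + 1)     # subplot indices start at 1
        plt.xticks([]), plt.yticks([])
        plt.imshow(npimg, cmap='gray')
        plt.title("Label:{}".format(classes[label]))
    plt.tight_layout()                       # call once, after all subplots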
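The original layout call `plt.subplot(len(images)/2, len(images)/5, ...)` produces float arguments under Python 3 and only has (n/2)·(n/5) = n²/10 cells, which is enough for all n images only when n ≥ 10. A minimal sketch of sizing the grid from a fixed column count instead; `grid_shape` is a hypothetical helper for illustration, not part of matplotlib:

```python
import math

def grid_shape(n_images, ncols=2):
    """Return (nrows, ncols) for a grid that can hold n_images subplots.

    Hypothetical helper: keeps the column count fixed and adds rows as needed.
    """
    if n_images < 1 or ncols < 1:
        raise ValueError("n_images and ncols must be positive")
    nrows = math.ceil(n_images / ncols)
    return nrows, ncols

print(grid_shape(10))          # (5, 2): the layout the original gave for 10 images
print(grid_shape(4))           # (2, 2): the original formula gave (2.0, 0.8)
print(grid_shape(7, ncols=3))  # (3, 3)
```

Fixing the column count keeps the figure's aspect predictable; the last row is simply left partly empty when the number of images is not a multiple of `ncols`.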

PyTorch-specific implementation, which takes PyTorch tensors as input (the color-image path still needs to be checked):

import math
import matplotlib.pyplot as plt
import numpy as np

def imshow(images, labels, classes, channels=1, ncols=2):
    """Show a batch of normalized PyTorch image tensors of shape (N, C, H, W).

    channels=1 displays grayscale images; any other value (e.g. 3) displays
    color images. Assumes the tensors were normalized with mean 0.5 and
    std 0.5, so img / 2 + 0.5 maps them back to [0, 1].
    """
    nrows = math.ceil(len(images) / ncols)
    plt.figure(figsize=(10, 20))
    for i in range(len(images)):
        img = images[i] / 2 + 0.5                       # unnormalize
        npimg = img.detach().cpu().numpy()
        plt.subplot(nrows, ncols, i + 1)
        plt.xticks([]), plt.yticks([])
        if channels == 1:
            plt.imshow(npimg[0], cmap='gray')           # (1, H, W) -> (H, W)
        else:
            plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
        plt.title("Label:{}".format(classes[int(labels[i])]))
    plt.tight_layout()
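The `img / 2 + 0.5` step suggests (an assumption, not stated in the gist) that the tensors were normalized with something like torchvision's `transforms.Normalize` using mean 0.5 and std 0.5, i.e. x_norm = (x - 0.5) / 0.5. A NumPy-only sketch of the per-image conversion, with a round-trip check on random data; `to_displayable` is a hypothetical name:

```python
import numpy as np

def to_displayable(chw):
    """Convert one normalized (C, H, W) array into what plt.imshow expects.

    Assumes normalization with mean=0.5 and std=0.5 per channel, so
    x = x_norm / 2 + 0.5 undoes it and returns values to [0, 1].
    """
    img = chw / 2 + 0.5                   # undo (x - 0.5) / 0.5
    if img.shape[0] == 1:
        return img[0]                     # grayscale: (1, H, W) -> (H, W)
    return np.transpose(img, (1, 2, 0))   # color: (C, H, W) -> (H, W, C)

# Round trip: normalize random [0, 1] data, then convert it back.
rng = np.random.default_rng(0)
original = rng.random((3, 4, 5))          # C=3, H=4, W=5
normalized = (original - 0.5) / 0.5
restored = to_displayable(normalized)
print(restored.shape)                                           # (4, 5, 3)
print(np.allclose(restored, np.transpose(original, (1, 2, 0))))  # True
```

The transpose is needed because PyTorch stores images channel-first while `plt.imshow` expects channel-last arrays for color data.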