@patharanordev
Last active March 27, 2021 09:58
Using matplotlib to plot image grid M x N

Plot image grid M x N via matplotlib

The example below shows the predicted label and confidence for each test image, using a VGG19 model. The grid size is set to N x 5:

  • multiple rows (N)
  • 5 columns
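
The number of rows N follows from the image count: divide by the column count and round up. A minimal sketch of that arithmetic, where `grid_shape` and `cell_position` are hypothetical helper names used only for illustration:

```python
import math

def grid_shape(num_images, num_col=5):
    """Return (rows, cols) for a grid that fits num_images in num_col columns."""
    num_row = math.ceil(num_images / num_col)
    return num_row, num_col

def cell_position(index, num_col=5):
    """Return the 0-based (row, col) of the image at 0-based position index."""
    return divmod(index, num_col)

print(grid_shape(12))     # (3, 5): 12 images need 3 rows of 5 columns
print(cell_position(7))   # (1, 2): the 8th image sits in row 1, column 2
```

Note that `plt.subplot` counts cells from 1, so the image at 0-based position `index` goes into subplot `index + 1`.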

I have a directory structure like this:

content
`- product
   |- train
   |  |- class_name1
   |  |  |- image_file_1
   |  |  :  ...
   |  |  `- image_file_n
   |  `- class_name2
   |     |- image_file_1
   |     :  ...
   |     `- image_file_n
   |- valid
   |  |- class_name1
   |  |  |- image_file_1
   |  |  :  ...
   |  |  `- image_file_n
   |  `- class_name2
   |     |- image_file_1
   |     :  ...
   |     `- image_file_n
   `- test
      |- class_name1
      |  |- image_file_1
      |  :  ...
      |  `- image_file_n
      `- class_name2
         |- image_file_1
         :  ...
         `- image_file_n
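
A layout like this can be walked with the standard library alone. The sketch below collects the image paths for each class under one split (for example `test`), skipping stray files such as macOS `.DS_Store`. The function name `list_images_by_class` is an assumption made for illustration:

```python
import os

def list_images_by_class(root_dir):
    """Map each class sub-directory of root_dir to its sorted image file paths."""
    images = {}
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        # Skip anything that is not a class directory (e.g. '.DS_Store')
        if not os.path.isdir(class_dir):
            continue
        images[class_name] = [
            os.path.join(class_dir, name)
            for name in sorted(os.listdir(class_dir))
            if not name.startswith('.')  # ignore hidden files inside a class
        ]
    return images
```

For the tree above, `list_images_by_class('/content/product/test')` would return a dictionary with one key per class directory.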

Let's plot the image grid:

import math
import os

import matplotlib.pyplot as plt
import numpy as np
# Assumes TensorFlow's bundled Keras; adjust the import for standalone Keras
from tensorflow.keras.preprocessing.image import img_to_array, load_img

# `model`, `class_names` and `IMAGE_SIZE` must already be defined:
# the trained VGG19 model, its list of labels and its input size.

plt.rcParams['figure.figsize'] = (20, 20)

TEST_DIR = '/content/product/test'
num_col = 5

for class_name in os.listdir(TEST_DIR):
  class_dir = os.path.join(TEST_DIR, class_name)

  # Skip stray files such as macOS '.DS_Store'
  if not os.path.isdir(class_dir):
    continue

  img_files = [f for f in os.listdir(class_dir) if not f.startswith('.')]
  # Rows needed to fit every image of this class in 5 columns
  num_row = math.ceil(len(img_files) / num_col)

  for idx, img_file in enumerate(img_files):
    # Subplot indices are 1-based
    plt.subplot(num_row, num_col, idx + 1)

    img = load_img(os.path.join(class_dir, img_file), target_size=(IMAGE_SIZE, IMAGE_SIZE))
    # Add a batch dimension: (H, W, C) -> (1, H, W, C)
    img_arr = np.expand_dims(img_to_array(img), axis=0)
    pred = model.predict(img_arr)
    max_idx = np.argmax(pred[0])
    pred_label = class_names[max_idx]

    plt.imshow(img)
    plt.title('{} {:.2f}%'.format(pred_label, pred[0][max_idx]*100))
    plt.axis('off')

  # Show one grid per class
  plt.show()
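
The title of each cell combines the top predicted label with its score as a percentage. That step can be tried in isolation with plain Python lists. In this sketch `format_title` is a hypothetical helper, and the scores and labels are made-up sample values rather than real model output:

```python
def format_title(pred_row, class_names):
    """Return 'label confidence%' for the highest-scoring class in pred_row."""
    # Index of the largest score, equivalent to np.argmax on a 1-D array
    max_idx = max(range(len(pred_row)), key=lambda i: pred_row[i])
    return '{} {:.2f}%'.format(class_names[max_idx], pred_row[max_idx] * 100)

# Made-up scores and labels, for illustration only
print(format_title([0.1, 0.85, 0.05], ['bag', 'shoe', 'hat']))  # shoe 85.00%
```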