Skip to content

Instantly share code, notes, and snippets.

@wookayin
Last active March 19, 2019 20:49
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wookayin/db1800194c0e44b12316696376ecd01a to your computer and use it in GitHub Desktop.
Save wookayin/db1800194c0e44b12316696376ecd01a to your computer and use it in GitHub Desktop.
IPython notebook snippet for plotting multiple images in a grid.
# Adapted, courtesy of Caffe's filter visualization example:
# http://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb
def imshow_grid(data, height=None, width=None, normalize=False, padsize=1, padval=0):
    """Tile a batch of images into one grid image and show it with ``plt.imshow``.

    Relies on ``np`` (NumPy) and ``plt`` (``matplotlib.pyplot``) being
    available in the enclosing namespace, as in a typical notebook session.

    Args:
        data: Array-like of shape (N, H, W) or (N, H, W, C). The caller's
            array is never modified.
        height: Number of grid rows. When omitted it is derived from
            ``width``, or set to ``ceil(sqrt(N))`` if ``width`` is omitted too.
        width: Number of grid columns. When omitted it is derived from
            ``height`` as ``ceil(N / height)``.
        normalize: If true, shift and scale a copy of ``data`` so that its
            global minimum becomes 0 and its global maximum becomes 1.
        padsize: Number of padding pixels appended to the bottom and right
            edge of every image.
        padval: Fill value for the padding and for unused grid cells.

    Raises:
        ValueError: If ``data`` is not 3-D or 4-D, if ``height`` or ``width``
            is smaller than 1, or if ``height * width`` cells cannot hold
            all N images.
    """
    data = np.asarray(data)
    if data.ndim not in (3, 4):
        raise ValueError(
            "data must have shape (N, H, W) or (N, H, W, C); got %d dimension(s)"
            % data.ndim
        )
    for label, value in (("height", height), ("width", width)):
        if value is not None and value < 1:
            raise ValueError("%s must be at least 1; got %r" % (label, value))

    if normalize and data.size:
        # Work on a floating-point copy: the in-place form would overwrite the
        # caller's array and cannot store true-division results in int arrays.
        if not np.issubdtype(data.dtype, np.floating):
            data = data.astype(np.float64)
        data = data - data.min()
        peak = data.max()
        if peak > 0:
            # A constant input has peak == 0; skip the division to avoid NaNs.
            data = data / peak

    N = data.shape[0]
    if height is None:
        if width is None:
            # Near-square layout; max(1, ...) keeps an empty batch drawable.
            height = max(1, int(np.ceil(np.sqrt(N))))
        else:
            height = max(1, int(np.ceil(N / float(width))))
    if width is None:
        width = max(1, int(np.ceil(N / float(height))))

    n_cells = height * width
    if n_cells < N:
        raise ValueError(
            "a %d x %d grid has %d cells but data holds %d images"
            % (height, width, n_cells, N)
        )

    # Append blank images to fill the grid, and pad the bottom/right edge of
    # each image; any trailing channel axis is left unpadded.
    padding = ((0, n_cells - N), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))

    # Tile: (rows, cols, h, w[, c]) -> (rows, h, cols, w[, c]) -> (rows*h, cols*w[, c])
    cell_h, cell_w = data.shape[1:3]
    channels = data.shape[3:]
    grid = data.reshape((height, width, cell_h, cell_w) + channels)
    grid = grid.transpose((0, 2, 1, 3) + tuple(range(4, grid.ndim)))
    grid = grid.reshape((height * cell_h, width * cell_w) + channels)
    plt.imshow(grid)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment