Skip to content

Instantly share code, notes, and snippets.

@alexlee-gk
Created October 6, 2018 20:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexlee-gk/6149fbc13f1935098b832632d7afa756 to your computer and use it in GitHub Desktop.
Save alexlee-gk/6149fbc13f1935098b832632d7afa756 to your computer and use it in GitHub Desktop.
Convert a feature map into an image.
import matplotlib.pyplot as plt
import numpy as np
def vis_square(data, grid_shape=None, padsize=1, padval=0, cmap=None, data_min=None, data_max=None):
    """Tile the channels of a feature map into a single RGB image.

    The last three axes of ``data`` are interpreted as
    (height, width, num_channels). Any leading axes are preserved, so a batch
    of feature maps yields a batch of images. Each channel is rendered through
    ``cmap`` and placed in an ``nrows x ncols`` grid. ``padsize`` pixels of
    ``padval`` are appended after (below / to the right of) every tile.

    Args:
        data: array-like of shape ``(..., height, width, num_channels)``.
        grid_shape: optional ``(nrows, ncols)``. Defaults to the smallest
            square grid that holds all channels.
        padsize: number of padding pixels appended after each tile along both
            spatial axes.
        padval: value written into the padding pixels of the uint8 image.
        cmap: a matplotlib colormap, a colormap name, or any callable that
            maps floats in [0, 1] to RGB(A) floats in [0, 1]. Defaults to
            ``'viridis'``.
        data_min, data_max: values mapped to the low / high end of the
            colormap. They default to the global min / max of ``data``.
            Values outside this range are clipped.

    Returns:
        ``np.uint8`` array of shape
        ``(..., nrows * (height + padsize), ncols * (width + padsize), 3)``.

    Raises:
        ValueError: if ``grid_shape`` has fewer cells than there are channels.
    """
    data = np.asarray(data, dtype=float)
    data_min = data.min() if data_min is None else data_min
    data_max = data.max() if data_max is None else data_max
    span = data_max - data_min
    if span != 0:
        data = (data - data_min) / span
    else:
        # Constant input (or data_min == data_max). Avoid 0/0 -> NaN, which
        # would become garbage after the uint8 cast below. Map everything
        # to the low end of the colormap instead.
        data = np.zeros_like(data)
    # Saturate values outside [data_min, data_max] instead of letting them
    # wrap around in the uint8 conversion.
    data = np.clip(data, 0.0, 1.0)

    lead_shape = data.shape[:-3]
    height, width, num_channels = data.shape[-3:]
    if grid_shape is None:
        # force the number of filters to be square
        nrows = ncols = int(np.ceil(np.sqrt(num_channels)))
    else:
        nrows, ncols = grid_shape
    if num_channels > nrows * ncols:
        raise ValueError(
            f"grid_shape {tuple(grid_shape)} has only {nrows * ncols} cells "
            f"but data has {num_channels} channels")

    if cmap is None:
        cmap = 'viridis'
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    # Quantize to 256 levels first. For a 256-entry colormap, calling it
    # with level / 255 selects exactly entry ``level``, matching a direct
    # lookup into the colormap's colour table. Unlike ``cmap.colors``, this
    # also works for continuous colormaps and plain callables.
    levels = (data * 255).astype(np.uint8)
    rgb = np.asarray(cmap(levels / 255.0))[..., :3]
    data = (rgb * 255).astype(np.uint8)

    # Pad each tile at the bottom/right, and pad the channel axis up to a
    # full grid. Leading axes and the colour axis are left untouched.
    padding = ([(0, 0)] * len(lead_shape)
               + [(0, padsize), (0, padsize), (0, nrows * ncols - num_channels), (0, 0)])
    data = np.pad(data, padding, mode='constant', constant_values=padval)

    # Split the channel axis into (nrows, ncols), then interleave the grid
    # axes with the spatial ones: (..., nrows, H+p, ncols, W+p, 3).
    shape = lead_shape + (height + padsize, width + padsize, nrows, ncols, 3)
    data = np.reshape(data, shape)
    data = np.transpose(
        data, tuple(range(len(lead_shape))) + (-3, -5, -2, -4, -1))
    shape = lead_shape + (nrows * (height + padsize), ncols * (width + padsize), 3)
    return np.reshape(data, shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment