Last active
June 18, 2019 11:10
-
-
Save oeway/f0ed87d3df671b351b533108bf4d9d5d to your computer and use it in GitHub Desktop.
Plot weights of a convolutional layer in Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%matplotlib inline
import pylab as pl
import matplotlib.cm as cm
import numpy as np
import numpy.ma as ma
def make_mosaic(imgs, nrows, ncols, border=1):
    """
    Tile a stack of equally-shaped images into a single masked mosaic.

    Images are placed row-major (left to right, then top to bottom) into an
    ``nrows`` x ``ncols`` grid, separated by ``border`` pixels.  Border
    pixels and grid cells that receive no image stay masked.

    Parameters
    ----------
    imgs : array_like
        Stack of images indexed along the first axis, i.e. shape
        ``(nimgs, height, width)`` or ``(nimgs, height, width, ...)``; any
        trailing axes (e.g. colour channels) are carried into the mosaic.
    nrows, ncols : int
        Grid dimensions; both must be at least 1.
    border : int, optional
        Width in pixels of the masked gap between neighbouring tiles.

    Returns
    -------
    numpy.ma.MaskedArray
        float32 array of shape
        ``(nrows*height + (nrows-1)*border, ncols*width + (ncols-1)*border, ...)``.

    Raises
    ------
    ValueError
        If the grid is empty, ``border`` is negative, ``imgs`` has fewer
        than 3 dimensions, or there are more images than grid cells.
    """
    # asanyarray keeps masked-array inputs (and their masks) intact.
    imgs = np.asanyarray(imgs)
    if nrows < 1 or ncols < 1:
        raise ValueError(
            "nrows and ncols must be >= 1, got {} x {}".format(nrows, ncols))
    if border < 0:
        raise ValueError("border must be >= 0, got {}".format(border))
    if imgs.ndim < 3:
        raise ValueError(
            "imgs must have shape (nimgs, height, width, ...), got {}".format(
                imgs.shape))

    nimgs = imgs.shape[0]
    height, width = imgs.shape[1:3]
    extra = imgs.shape[3:]
    if nimgs > nrows * ncols:
        raise ValueError(
            "{} images do not fit in a {} x {} grid".format(nimgs, nrows, ncols))

    # Start fully masked; each assignment below unmasks one tile.
    mosaic = ma.masked_all(
        (nrows * height + (nrows - 1) * border,
         ncols * width + (ncols - 1) * border) + extra,
        dtype=np.float32)

    paddedh = height + border
    paddedw = width + border
    for i, img in enumerate(imgs):
        row, col = divmod(i, ncols)
        top = row * paddedh
        left = col * paddedw
        mosaic[top:top + height, left:left + width] = img
    return mosaic
# utility functions | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
def nice_imshow(ax, data, vmin=None, vmax=None, cmap=None):
    """
    Show ``data`` on ``ax`` with a colour bar attached to its right side.

    Thin wrapper around ``ax.imshow`` using nearest-neighbour interpolation,
    so individual weights stay visible as crisp pixels.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into; a narrow colour-bar axes is carved out of it.
    data : numpy array (a masked array is accepted)
        2-D image to display.  Masked entries are drawn with the colormap's
        "bad" colour.
    vmin, vmax : float, optional
        Colour-scale limits; default to ``data.min()`` / ``data.max()``.
    cmap : matplotlib colormap, optional
        Defaults to ``cm.jet``.

    Returns
    -------
    matplotlib.image.AxesImage
        The image artist, so callers can adjust it further.
    """
    if cmap is None:
        cmap = cm.jet
    if vmin is None:
        vmin = data.min()
    if vmax is None:
        vmax = data.max()
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    im = ax.imshow(data, vmin=vmin, vmax=vmax, interpolation='nearest', cmap=cmap)
    # Attach the colour bar to the figure that owns ``ax`` rather than to
    # pyplot's "current" figure, which is not necessarily the same one.
    ax.figure.colorbar(im, cax=cax)
    return im
#pl.imshow(make_mosaic(np.random.random((9, 10, 10)), 3, 3, border=1)) | |
def plot_conv_weights(model, layer):
    """
    Plot every 2-D slice of a convolutional kernel as one mosaic image.

    Targets Keras 2+: the legacy Theano-only ``layer.W.get_value()`` API is
    gone there, so the kernel is read with ``get_weights()[0]``.

    Parameters
    ----------
    model : keras Model
        Model containing the layer.
    layer : int or str
        Index into ``model.layers`` or the layer's name.

    Raises
    ------
    ValueError
        If the kernel cannot be arranged as a stack of 2-D images.
    """
    if isinstance(layer, str):
        target = model.get_layer(name=layer)
    else:
        target = model.layers[layer]
    # The first weight array of a Keras conv layer is its kernel; the bias,
    # if present, comes after it.
    W = np.asarray(target.get_weights()[0])
    kernel_shape = W.shape
    if W.ndim == 4:
        # Keras 2 stores Conv2D kernels as (rows, cols, in_channels, filters)
        # regardless of data_format.  Move to (filters, in_channels, rows,
        # cols) so that each (filter, channel) pair becomes one 2-D image.
        # Reshaping before any squeeze keeps single-channel kernels intact.
        W = np.transpose(W, (3, 2, 0, 1))
        W = W.reshape((-1, W.shape[2], W.shape[3]))
    else:
        W = np.squeeze(W)
    if W.ndim != 3:
        raise ValueError(
            "cannot plot kernel of shape {} as a stack of 2-D images".format(
                kernel_shape))

    print("W shape : ", W.shape)
    pl.figure(figsize=(15, 15))
    pl.title('conv weights')
    # Smallest square grid holding all images (int(sqrt(n) + 1) wasted a
    # whole row and column when n was a perfect square).
    s = int(np.ceil(np.sqrt(W.shape[0])))
    nice_imshow(pl.gca(), make_mosaic(W, s, s), cmap=cm.binary)
# Example usage.  `model` must be a Keras model built or loaded earlier
# (not shown here); layer index 2 is assumed to be convolutional -- TODO confirm.
plot_conv_weights(model=model, layer=2)
I came up with this code; it takes the model and the layer name:
def plot_conv_weights(model, layer, max_plots=25):
    """
    Plot individual 2-D slices of a convolutional kernel in a square grid.

    Parameters
    ----------
    model : keras Model
        Model containing the layer.
    layer : str or int
        Layer name (looked up with ``model.get_layer``) or index into
        ``model.layers``.
    max_plots : int, optional
        Maximum number of slices to draw.  The default of 25 reproduces the
        original 5x5 grid whenever the kernel has at least 25 slices.

    Raises
    ------
    ValueError
        If the kernel cannot be viewed as a stack of 2-D slices.
    """
    if isinstance(layer, str):
        target = model.get_layer(name=layer)
    else:
        target = model.layers[layer]
    W = np.asarray(target.get_weights()[0])
    kernel_shape = W.shape
    if W.ndim == 4:
        # (rows, cols, in, out) -> (rows, cols, in*out): one slice per
        # (input channel, filter) pair.  Deliberately no np.squeeze here:
        # it would drop any size-1 axis and break the reshape below.
        W = W.reshape((W.shape[0], W.shape[1], W.shape[2] * W.shape[3]))
    if W.ndim != 3:
        raise ValueError(
            "cannot plot kernel of shape {} as 2-D slices".format(kernel_shape))

    count = min(W.shape[2], max_plots)
    side = max(1, int(np.ceil(np.sqrt(count))))
    # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid,
    # so .ravel() always works.  (``pl`` is the pylab module imported at
    # the top of the file; the original referenced an undefined ``plt``.)
    fig, axs = pl.subplots(side, side, figsize=(8, 8), squeeze=False)
    fig.subplots_adjust(hspace=.5, wspace=.001)
    axs = axs.ravel()
    for i, ax in enumerate(axs):
        if i < count:
            ax.imshow(W[:, :, i])
            ax.set_title(str(i))
        else:
            # Hide grid cells that have no slice to show.
            ax.axis('off')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It seems that W is deprecated (Keras 2.0+). How should I change it?