Skip to content

Instantly share code, notes, and snippets.

@nandavelugoti
Created April 17, 2020 20:09
Show Gist options
  • Save nandavelugoti/ce752a9db8aff94d6a631aea0332e28d to your computer and use it in GitHub Desktop.
Save nandavelugoti/ce752a9db8aff94d6a631aea0332e28d to your computer and use it in GitHub Desktop.
Visualize activation filters
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 15 17:20:25 2020
@author: Nanda Velugoti
"""
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import load_model, Model
# Load one sample image from disk, resized to 150x150 on load.
img_path = 'train/Cat/123.jpg'
img = image.load_img(img_path, target_size=(150, 150))

# Convert the image to an array, prepend a leading axis (so the tensor holds
# a single sample at index 0), then rescale the values by 1/255 in place.
img_tensor = np.expand_dims(image.img_to_array(img), axis=0)
img_tensor /= 255

# Preview the preprocessed sample.
plt.imshow(img_tensor[0])
plt.show()
# Load the saved network from its HDF5 file.
model = load_model('cats_dogs_6.h5')

# Build a second model that shares the same input but returns the output of
# each of (up to) the first eight layers instead of the final prediction.
leading_layers = model.layers[:8]
layer_outputs = [lyr.output for lyr in leading_layers]
activation_model = Model(inputs=model.input, outputs=layer_outputs)

# Run the preprocessed image through it; entry 0 is used below as the
# activation of the first exposed layer.
activations = activation_model.predict(img_tensor)
first_layer_activation = activations[0]

# Show two individual channels (indices 4 and 7) of that first activation.
for channel_index in (4, 7):
    plt.matshow(first_layer_activation[0, :, :, channel_index], cmap='viridis')
def _channel_to_uint8(channel):
    """Return a displayable uint8 copy of a single activation channel.

    The channel is standardized (mean subtracted, divided by its standard
    deviation), then scaled by 64 and shifted by 128 so that the mean maps
    to mid-grey, and finally clipped to the range [0, 255].
    """
    # Work on a copy: the argument is a view into the activation array, and
    # in-place arithmetic on it would overwrite the stored activations.
    channel = np.array(channel, dtype=float)
    channel -= channel.mean()
    spread = channel.std()
    # A constant channel has zero spread; skip the division rather than
    # filling the tile with NaN/inf.
    if spread > 0:
        channel /= spread
    channel *= 64
    channel += 128
    return np.clip(channel, 0, 255).astype('uint8')


def _build_activation_grid(layer_activation, images_per_row):
    """Tile every channel of sample 0 of a 4-D activation into one 2-D grid.

    Channels are laid out left to right, top to bottom, `images_per_row`
    tiles per row. The last row is left zero-filled where there are fewer
    channels than tiles.
    """
    _, tile_height, tile_width, n_features = layer_activation.shape
    # Ceiling division so channels beyond a multiple of images_per_row
    # still get a (partially filled) row instead of being dropped.
    n_rows = -(-n_features // images_per_row)
    grid = np.zeros((tile_height * n_rows, tile_width * images_per_row))
    for index in range(n_features):
        grid_row, grid_col = divmod(index, images_per_row)
        top = grid_row * tile_height
        left = grid_col * tile_width
        grid[top:top + tile_height, left:left + tile_width] = _channel_to_uint8(
            layer_activation[0, :, :, index]
        )
    return grid


# Names of the same leading layers whose outputs were collected in
# `activations`, used as figure titles.
layer_names = [layer.name for layer in model.layers[:8]]
images_per_row = 16

for layer_name, layer_activation in zip(layer_names, activations):
    # Only 4-D outputs (indexed as [sample, row, column, channel] below) can
    # be tiled as images; skip anything else instead of crashing.
    if layer_activation.ndim != 4:
        continue
    display_grid = _build_activation_grid(layer_activation, images_per_row)

    # Size the figure at one inch per tile in each direction.
    tile_height, tile_width = layer_activation.shape[1:3]
    plt.figure(figsize=((1. / tile_width) * display_grid.shape[1],
                        (1. / tile_height) * display_grid.shape[0]))
    plt.title(layer_name)
    plt.grid(False)
    plt.imshow(display_grid, aspect='auto', cmap='viridis')

# Without an explicit show() the grids never appear when this file is run as
# a plain script (the image preview earlier in the file calls it too).
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment