@RaphaelMeudec
Last active May 14, 2020 18:44
Visualize convolutional kernels with TensorFlow 2.0
import numpy as np
import tensorflow as tf
# Layer name to inspect
layer_name = 'block3_conv1'
# Number of gradient-ascent steps
epochs = 100
step_size = 1.
# Index of the filter (output channel) whose activation is maximized
filter_index = 0
# Create a connection between the input and the target layer
model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)
submodel = tf.keras.models.Model([model.inputs[0]], [model.get_layer(layer_name).output])
# Initiate random noise: a grayish image with values in [118, 138)
input_img_data = np.random.random((1, 224, 224, 3))
input_img_data = (input_img_data - 0.5) * 20 + 128.
# Cast random noise from np.float64 to a tf.float32 Variable (watched by the tape)
input_img_data = tf.Variable(tf.cast(input_img_data, tf.float32))
# Iterate gradient ascent steps
for _ in range(epochs):
    with tf.GradientTape() as tape:
        outputs = submodel(input_img_data)
        # Mean activation of the chosen filter over the batch and spatial positions
        loss_value = tf.reduce_mean(outputs[:, :, :, filter_index])
    grads = tape.gradient(loss_value, input_img_data)
    # Divide by the gradient's RMS so every step has a comparable size
    normalized_grads = grads / (tf.sqrt(tf.reduce_mean(tf.square(grads))) + 1e-5)
    # Move the image in the direction that increases the activation
    input_img_data.assign_add(normalized_grads * step_size)
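The update inside the loop is normalized gradient ascent: the gradient is divided by its root-mean-square, so each step moves the image by roughly `step_size` in RMS terms, whatever the raw scale of the layer's activations. A NumPy-only sketch of that single step (the helper name `normalized_ascent_step` is hypothetical, not part of the gist) lets you check this without loading VGG16:

```python
import numpy as np

def normalized_ascent_step(x, grads, step_size=1.0, eps=1e-5):
    """Hypothetical helper: one gradient-ascent step with RMS-normalized gradients."""
    # Root-mean-square of the gradient over every element
    rms = np.sqrt(np.mean(np.square(grads)))
    # eps avoids division by zero when the gradient vanishes
    return x + step_size * grads / (rms + eps)

def rms(a):
    return float(np.sqrt(np.mean(np.square(a))))

rng = np.random.default_rng(0)
x = np.zeros((1, 8, 8, 3), dtype=np.float32)
# Gradients at very different scales...
small = rng.normal(scale=1e-3, size=x.shape).astype(np.float32)
large = rng.normal(scale=1e3, size=x.shape).astype(np.float32)
# ...still produce steps whose RMS is close to step_size (1.0)
rms_small = rms(normalized_ascent_step(x, small) - x)
rms_large = rms(normalized_ascent_step(x, large) - x)
print(rms_small, rms_large)
```

Without this normalization, the step length would depend on the magnitude of the activations in the chosen layer, so a fixed `step_size` could be far too small for one layer and far too large for another.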
@melwazir

melwazir commented Dec 13, 2019

Excuse my stupid question, but how do we actually see/save the image(s)?

@RaphaelMeudec
Author

At the end of the loop, input_img_data is a 4D tensor of shape (1, 224, 224, 3) holding the generated image. Convert it to a NumPy array with .numpy(), drop the batch dimension, and display it, for example with matplotlib.
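One way to turn the final array into something matplotlib can show or save is a contrast-stretching step: center the values, shrink them to a small standard deviation around 0.5, clip to [0, 1], and convert to 8-bit. A similar `deprocess_image` function appears in François Chollet's *Deep Learning with Python*; the version below is a minimal sketch that assumes the variable has already been converted with `.numpy()`.

```python
import numpy as np

def deprocess_image(x):
    """Convert a (1, H, W, 3) float array into an (H, W, 3) uint8 image."""
    # Drop the batch dimension; np.array copies, so the input is left untouched
    img = np.array(x[0], dtype=np.float64)
    # Standardize to zero mean and unit std (epsilon avoids division by zero)
    img -= img.mean()
    img /= img.std() + 1e-5
    # Squeeze into a narrow band around 0.5 so most pixels land inside [0, 1]
    img *= 0.1
    img += 0.5
    img = np.clip(img, 0.0, 1.0)
    # Scale to [0, 255] and convert to 8-bit integers for display or saving
    return (img * 255).astype(np.uint8)
```

Usage, assuming matplotlib is installed: `img = deprocess_image(input_img_data.numpy())`, then `plt.imshow(img)` to display it or `plt.imsave('filter.png', img)` to write it to disk.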

@konradsemsch

Is this the expected outcome in this case?

[Attached image]

@RaphaelMeudec
Author

For matplotlib to display the image properly, you need to either normalize the values to the range [0, 1] or convert the image to integers. As the warning says, matplotlib is clipping values that lie in the 0-255 range down to [0, 1], which is why the image looks so poor.
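For the first option, a min-max rescaling maps the generated image linearly into [0, 1], the range `imshow` expects for float RGB data. A minimal sketch, with `to_unit_range` as a hypothetical helper name; it also drops the batch dimension if one is present:

```python
import numpy as np

def to_unit_range(x, eps=1e-8):
    """Hypothetical helper: rescale an image array linearly so it spans [0, 1]."""
    img = np.asarray(x, dtype=np.float32)
    if img.ndim == 4:
        img = img[0]  # (1, H, W, 3) -> (H, W, 3)
    lo, hi = img.min(), img.max()
    # eps keeps a constant image from causing a division by zero
    return (img - lo) / (hi - lo + eps)
```

Because it stretches the full value range, a few outlier pixels can leave the rest of the image low-contrast.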

@lballore

@konradsemsch I solved it by converting input_img_data this way:

input_img_data = input_img_data.numpy().astype(np.uint8)
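One caveat with a bare `astype(np.uint8)`: gradient ascent can push pixels below 0 or above 255, and NumPy does not clamp out-of-range floats when casting to `uint8` (the result is platform-dependent and often wraps around). Clipping first avoids that. The sketch below uses a hypothetical helper `to_uint8` that also drops the batch dimension so the result can go straight to `imshow`:

```python
import numpy as np

def to_uint8(x):
    """Hypothetical helper: clip to the 8-bit range, then cast."""
    img = np.asarray(x, dtype=np.float32)
    if img.ndim == 4:
        img = img[0]  # (1, H, W, 3) -> (H, W, 3)
    # Clip before casting so out-of-range values saturate instead of wrapping
    return np.clip(img, 0, 255).astype(np.uint8)
```

Usage: `img = to_uint8(input_img_data.numpy())`.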
