
@RaphaelMeudec
Created July 18, 2019 15:11
Grad-CAM implementation with TensorFlow 2
import cv2
import numpy as np
import tensorflow as tf
IMAGE_PATH = './cat.jpg'
LAYER_NAME = 'block5_conv3'
CAT_CLASS_INDEX = 281
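# Load the image at VGG16's input size; keep this raw RGB array for the final overlay
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)

model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)
# VGG16 expects mean-subtracted BGR input; preprocess_input can modify its argument
# in place, so work on a copy and leave img untouched
x = tf.keras.applications.vgg16.preprocess_input(np.expand_dims(img.copy(), axis=0))

# Model mapping the input to (target conv layer activations, class predictions)
grad_model = tf.keras.models.Model(model.inputs, [model.get_layer(LAYER_NAME).output, model.output])

# Record the forward pass so the class score can be differentiated w.r.t. the conv activations
with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(x)
    loss = predictions[:, CAT_CLASS_INDEX]

output = conv_outputs[0]                        # (14, 14, 512) feature maps
grads = tape.gradient(loss, conv_outputs)[0]    # same shape as output

# Guided gating: keep gradients only where both activation and gradient are positive
gate_f = tf.cast(output > 0, 'float32')
gate_r = tf.cast(grads > 0, 'float32')
guided_grads = gate_f * gate_r * grads

# One weight per channel: spatial average of the guided gradients
weights = tf.reduce_mean(guided_grads, axis=(0, 1))

# Weighted sum of the feature maps over the channel axis, as in Grad-CAM
cam = tf.reduce_sum(output * weights, axis=-1).numpy()

# Upsample to the input size, apply ReLU, then min-max normalise (epsilon avoids 0/0)
cam = cv2.resize(cam, (224, 224))
cam = np.maximum(cam, 0)
heatmap = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

# Colour the heatmap and blend it with the original (BGR) image for OpenCV
cam = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
output_image = cv2.addWeighted(cv2.cvtColor(img.astype('uint8'), cv2.COLOR_RGB2BGR), 0.5, cam, 1, 0)
cv2.imwrite('cam.png', output_image)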
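To check the CAM arithmetic without downloading VGG16, the same steps can be mirrored in plain NumPy. This is a minimal sketch, assuming `activations` and `grads` are `(H, W, K)` arrays like the script's `output` and `grads`; the name `guided_grad_cam` is hypothetical, and the resize and colour-map steps are left out.

```python
import numpy as np


def guided_grad_cam(activations: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (H, W, K) activations and gradients -> (H, W) map in [0, 1]."""
    if activations.shape != grads.shape or activations.ndim != 3:
        raise ValueError("expected two arrays of identical shape (H, W, K)")
    # Guided gating: keep gradients only where activation and gradient are both positive
    guided = (activations > 0) * (grads > 0) * grads
    # One weight per channel: spatial mean of the gated gradients
    weights = guided.mean(axis=(0, 1))  # shape (K,)
    # Weighted sum over channels, followed by ReLU
    cam = np.maximum((activations * weights).sum(axis=-1), 0.0)
    # Min-max normalise; the epsilon guards against a constant map
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)


# Tiny example: 2x2 spatial grid with 2 channels
acts = np.array([[[1.0, 0.0], [2.0, 1.0]],
                 [[0.0, 3.0], [1.0, 1.0]]])
grads = np.array([[[0.5, 1.0], [0.5, -1.0]],
                  [[0.5, 1.0], [0.5, 1.0]]])
print(guided_grad_cam(acts, grads))
```

In this example the bottom-left cell, where channel 1 (the channel with the larger weight) is strongly active, comes out as the maximum of about 1.0. Dropping the gating line and averaging the raw `grads` gives the plain Grad-CAM weights instead of the guided ones used in the gist.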
@Corne173

How would you find the feature importance in terms of the RGB colour channels? Here you get pixel importance, which is a combination of the RGB inputs. I'm very interested in the answer, as it relates to a problem I have where I want to find the Grad-CAM for a multivariate time series.
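One possible approach, sketched here as an assumption rather than part of the gist, is to differentiate the class score with respect to the input instead of a conv layer, then aggregate gradient × input over every axis except the channel axis. In TensorFlow the input gradients can be obtained by calling `tape.watch(x)` on the input tensor inside the `tf.GradientTape` block and then `tape.gradient(loss, x)`. The hypothetical helper `channel_importance` below only performs the aggregation, so it applies equally to an `(H, W, 3)` RGB image and to a `(T, C)` multivariate time series whose last axis holds the variables.

```python
import numpy as np


def channel_importance(x: np.ndarray, input_grads: np.ndarray) -> np.ndarray:
    """Hypothetical helper: share of |gradient x input| attributed to each channel.

    x, input_grads: arrays of identical shape (..., C), e.g. (H, W, 3) or (T, C).
    Returns a length-C vector that sums to 1 (or all zeros if nothing is attributed).
    """
    if x.shape != input_grads.shape:
        raise ValueError("x and input_grads must have the same shape")
    # Gradient x input attribution per element, in absolute value
    attribution = np.abs(x * input_grads)
    # Sum over every axis except the last (channel / variable) axis
    per_channel = attribution.reshape(-1, x.shape[-1]).sum(axis=0)
    total = per_channel.sum()
    return per_channel / total if total > 0 else per_channel


# Example: 3 time steps, 2 variables; only variable 0 receives gradient
series = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
grads = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
print(channel_importance(series, grads))  # [1. 0.]
```

Taking absolute values measures magnitude only; removing `np.abs` keeps the sign of each channel's contribution. For a time-series model built from `Conv1D` layers, the gist's Grad-CAM recipe applied to the last conv layer (averaging gradients over the time axis only) gives importance over time steps rather than over variables.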

@sneh-debug

@RaphaelMeudec Hi. What if the input size is (192, 152, 4), where 4 is the number of 2D slices? How can we obtain the CAM for each of the images?
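If the four slices are stacked as channels of a 2D network, the first convolution already mixes them, so Grad-CAM at a deep layer yields a single spatial map shared by all slices. A separate map per slice needs attributions computed at the input. The sketch below assumes `input_grads` holds the gradient of the class score with respect to the `(192, 152, 4)` input (for example from `tf.GradientTape` with `tape.watch` on the input); `per_slice_maps` is a hypothetical name, and gradient × input is a swapped-in attribution method, not Grad-CAM.

```python
import numpy as np


def per_slice_maps(x: np.ndarray, input_grads: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (H, W, S) input and gradients -> (S, H, W) maps, each in [0, 1]."""
    if x.shape != input_grads.shape or x.ndim != 3:
        raise ValueError("expected two arrays of identical shape (H, W, S)")
    # Gradient x input, keeping only positive evidence, computed per slice
    attribution = np.maximum(x * input_grads, 0.0)
    maps = np.moveaxis(attribution, -1, 0)  # (S, H, W)
    # Normalise each slice on its own so weakly attributed slices stay visible
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    return (maps - lo) / (hi - lo + 1e-8)


# Example with the input size from the question and random stand-in data
rng = np.random.default_rng(0)
x = rng.random((192, 152, 4)).astype(np.float32)
g = rng.standard_normal((192, 152, 4)).astype(np.float32)
maps = per_slice_maps(x, g)
print(maps.shape)  # (4, 192, 152)
```

Normalising all slices jointly (one `lo` and `hi` for the whole volume) would instead make the maps directly comparable across slices. An architecture with 3D convolutions, whose feature maps keep a depth axis, would allow a Grad-CAM volume rather than a single shared map.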
