Skip to content

Instantly share code, notes, and snippets.

@vashineyu
Created July 25, 2019 08:38
Show Gist options
  • Save vashineyu/d80ab3c946ca212311d30c58c390e88e to your computer and use it in GitHub Desktop.
Save vashineyu/d80ab3c946ca212311d30c58c390e88e to your computer and use it in GitHub Desktop.
def gradcam_plus(model, im, class_select, layer, image_size, preproc_fn, alpha=0.6, filter_threshold=0.5):
    """Compute a Grad-CAM++ saliency map for one image and overlay it on the image.

    The pixel-wise Grad-CAM++ weights are built from exp(score) multiplied by
    the first, second and third powers of the gradient of the class score
    with respect to the activations of ``layer``.

    Args:
        model: Keras model (graph-mode backend; ``K.gradients`` is used).
        im: a single image ``[H, W, C]`` (RGB). An already batched
            ``[1, H, W, C]`` array is also accepted.
        class_select: index of the model output unit to explain.
        layer: name of the layer whose activations are weighted.
        image_size: ``(H, W)`` the heat map is resized to. It presumably
            should equal the image's own size so the overlay lines up.
        preproc_fn: preprocessing applied to the float32 batch before the
            forward pass.
        alpha: opacity (0-1) of the heat-map overlay where it is shown.
        filter_threshold: normalized CAM values below this are set to 0 and
            drawn fully transparent.

    Returns:
        Tuple ``(cam, background)``. ``cam`` is the ``[H, W]`` map normalized
        to ``[0, 1]`` (all zeros if no positive evidence). ``background`` is a
        PIL image of the original input with the colored map pasted on top.

    Raises:
        ValueError: if ``im`` is neither 3-D nor 4-D.
    """
    H, W = image_size[0], image_size[1]

    image = np.asarray(im)
    # Add a batch axis only for a single [H, W, C] image (check the rank,
    # not the length of the first axis).
    if image.ndim == 3:
        image = image[np.newaxis, :, :, :]
    if image.ndim != 4:
        raise ValueError("expected an image of shape [H, W, C] or [1, H, W, C], got %r" % (image.shape,))
    image_original = image[0].astype("uint8")
    image = preproc_fn(image.astype("float32"))  # astype copies, so `im` is never mutated

    # NOTE(review): these ops are added to the backend graph on every call, so
    # repeated calls grow the graph; consider caching `gradient_function`.
    y_c = model.output[0, class_select]
    conv_output = model.get_layer(layer).output
    grads = K.gradients(y_c, conv_output)[0]
    # Grad-CAM++ uses Y = exp(S_c); with higher-order derivatives of S_c taken
    # as zero, the n-th derivative of Y w.r.t. A is exp(S_c) * (dS_c/dA) ** n.
    exp_score = K.exp(y_c)
    derv1 = exp_score * grads
    derv2 = exp_score * K.pow(grads, 2)
    derv3 = exp_score * K.pow(grads, 3)
    gradient_function = K.function([model.input], [conv_output, derv1, derv2, derv3])
    # NOTE(review): the graph is already built here, so this device scope
    # probably has no effect -- confirm before relying on it.
    with tf.device("/gpu:0"):
        A, d1, d2, d3 = gradient_function([image])
    A, d1, d2, d3 = A[0], d1[0], d2[0], d3[0]  # (n, h, w, c) -> (h, w, c)

    # Pixel-wise weights; the epsilon keeps 0/0 finite where gradients vanish.
    grad_weight_alpha = d2 / (2.0 * d2 + np.sum(A, axis=(0, 1)) * d3 + 1e-8)
    # ReLU on the first-order term: only positive gradients contribute.
    wc = grad_weight_alpha * np.maximum(d1, 0)
    cam = np.dot(A, np.sum(wc, axis=(0, 1)))  # (h, w, c) . (c,) -> (h, w)

    # cv2.resize takes dsize as (width, height); the interpolation flag must be
    # passed by keyword because the third positional parameter is `dst`.
    cam = cv2.resize(cam, (W, H), interpolation=cv2.INTER_CUBIC)
    cam = np.maximum(cam, 0)
    cam_max = cam.max()
    if cam_max > 0:
        cam = cam / cam_max
    # else: the map is all zeros; keep it instead of dividing by zero (NaN).

    # Filter weak responses.
    cam[cam < filter_threshold] = 0

    # NOTE(review): applyColorMap returns BGR while PIL reads the array as RGB;
    # together with (1 - cam) this appears to render high saliency as red --
    # confirm before changing either part.
    mapping = cv2.applyColorMap(np.uint8(255 * (1 - cam)), cv2.COLORMAP_JET)
    # Alpha channel: fully transparent where the map was filtered to 0,
    # `255 * alpha` elsewhere. Computed in float to avoid uint8 wrap-around.
    overlay_alpha = np.where(cam > 0, 255.0 * alpha, 0.0)
    rgba = np.concatenate((mapping, overlay_alpha[:, :, np.newaxis]), axis=-1).astype("uint8")

    background = Image.fromarray(image_original)
    foreground = Image.fromarray(rgba)
    background.paste(foreground, (0, 0), foreground)  # foreground doubles as its own alpha mask
    return cam, background
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment