from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input, decode_predictions
# To switch to InceptionV3, also swap in its preprocess_input and
# decode_predictions; importing both versions unconditionally would let the
# later import shadow the ResNet50 one:
# from keras.applications.inception_v3 import InceptionV3
# from keras.applications.inception_v3 import preprocess_input, decode_predictions
from keras.preprocessing import image
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt
import cv2
def load_image(path, target_size=(224, 224)):
    x = image.load_img(path, target_size=target_size)
    x = image.img_to_array(x)
    x = np.expand_dims(x, axis=0)  # add batch dimension -> (1, H, W, 3)
    x = preprocess_input(x)
    return x
def generate_gradcam(img_tensor, model, class_index, activation_layer):
    model_input = model.input

    # y_c: the input to the model's final activation op (softmax, linear, ...)
    # for the given class_index, i.e. the pre-softmax class score
    y_c = model.outputs[0].op.inputs[0][0, class_index]

    # A_k: output feature maps of the chosen activation (conv) layer
    A_k = model.get_layer(activation_layer).output

    # For the model input, compute the activation layer's output (A_k)
    # and the gradient of the class score (y_c) with respect to A_k
    get_output = K.function([model_input], [A_k, K.gradients(y_c, A_k)[0]])
    [conv_output, grad_val] = get_output([img_tensor])

    # The outputs include the batch dimension, shape (1, width, height, k),
    # so drop it to get (width, height, k); width and height here are the
    # spatial dimensions of the A_k feature maps
    conv_output = conv_output[0]
    grad_val = grad_val[0]

    # Global average pooling: average the gradients over width and height
    # (the 1/Z factor) to obtain the channel weights a^c_k
    weights = np.mean(grad_val, axis=(0, 1))

    # Weighted combination over k of the feature maps (conv_output) and the
    # class-specific weights a^c_k; initialize with the feature map's
    # (width, height) shape
    grad_cam = np.zeros(dtype=np.float32, shape=conv_output.shape[0:2])
    for k, w in enumerate(weights):
        grad_cam += w * conv_output[:, :, k]

    # Apply ReLU to the weighted combination
    grad_cam = np.maximum(grad_cam, 0)

    return grad_cam, weights
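
# For reference, the quantities above implement the Grad-CAM definition
# (Selvaraju et al., 2017):
#
#     a^c_k = (1/Z) * sum_{i,j} d(y^c) / d(A^k_{ij})   # GAP of the gradients
#     L^c   = ReLU( sum_k a^c_k * A^k )                # weighted combination
#
# A minimal vectorized sketch of the combination step (assuming conv_output
# and weights as computed inside generate_gradcam; equivalent to the loop):
#
#     grad_cam = np.maximum(np.sum(conv_output * weights, axis=-1), 0)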
def generate_cam(img_tensor, model, class_index, activation_layer):
    model_input = model.input

    # A_k: output feature maps of the last conv layer
    A_k = model.get_layer(activation_layer).output

    # For the model input, compute the last conv layer's output (A_k)
    get_output = K.function([model_input], [A_k])
    [conv_output] = get_output([img_tensor])

    # The output includes the batch dimension, shape (1, width, height, k),
    # so drop it to get (width, height, k); width and height here are the
    # spatial dimensions of the A_k feature maps
    conv_output = conv_output[0]

    # From the weight matrix between the GAP layer and the final (softmax/
    # dense) layer, take the column of weights w^c_k for class_index,
    # e.g. w^2_1, w^2_2, w^2_3, ..., w^2_k
    weights = model.layers[-1].get_weights()[0][:, class_index]

    # Weighted combination over k of the feature maps (conv_output) and the
    # class-specific weights w^c_k; initialize with the feature map's
    # (width, height) shape
    cam = np.zeros(dtype=np.float32, shape=conv_output.shape[0:2])
    for k, w in enumerate(weights):
        cam += w * conv_output[:, :, k]

    return cam, weights
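
# For reference, this is the original CAM definition (Zhou et al., 2016):
#
#     M_c(x, y) = sum_k w^c_k * f_k(x, y)
#
# where f_k is the k-th feature map of the last conv layer and w^c_k is the
# dense-layer weight connecting GAP'd channel k to class c. CAM therefore
# requires the conv -> GAP -> dense structure (as in ResNet50), whereas
# Grad-CAM works for arbitrary architectures.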
if __name__ == "__main__":
    img_width = 224
    img_height = 224

    model = ResNet50(weights='imagenet')
    # model = InceptionV3(weights='imagenet')
    model.summary()

    img_path = '../image/elephant.jpg'
    img = load_image(path=img_path, target_size=(img_width, img_height))

    preds = model.predict(img)
    predicted_class = preds.argmax(axis=1)[0]
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print("predicted top1 class:", predicted_class)
    print('Predicted:', decode_predictions(preds, top=1)[0])
    # Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]

    conv_name = 'activation_49'  # last activation layer of ResNet50
    # conv_name = 'mixed10'      # last mixed layer of InceptionV3

    grad_cam, grad_val = generate_gradcam(img, model, predicted_class, conv_name)
    cam, cam_weight = generate_cam(img, model, predicted_class, conv_name)

    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (img_width, img_height))

    cam = cv2.resize(cam, (img_width, img_height))
    plt.figure(0)
    plt.imshow(img)
    plt.imshow(cam, cmap="jet", alpha=.5)
    plt.axis('off')

    grad_cam = cv2.resize(grad_cam, (img_width, img_height))
    plt.figure(1)
    plt.imshow(img)
    plt.imshow(grad_cam, cmap="jet", alpha=.5)
    plt.axis('off')

    plt.show()
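
# ---------------------------------------------------------------------------
# Second script in this gist: Guided Backpropagation / Guided Grad-CAM on
# VGG16 (stored as a separate file in the original gist).
# ---------------------------------------------------------------------------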
import keras
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input, decode_predictions
from keras.preprocessing import image
import keras.backend as K
import tensorflow as tf
from tensorflow.python.framework import ops
import numpy as np
import matplotlib.pyplot as plt
import cv2
from utils import deprocess_image
def load_image(path, target_size=(224, 224)):
    x = image.load_img(path, target_size=target_size)
    x = image.img_to_array(x)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    return x
def register_gradient():
    if "GuidedBackProp" not in ops._gradient_registry._registry:
        @ops.RegisterGradient("GuidedBackProp")
        def _GuidedBackProp(op, grad):
            dtype = op.inputs[0].dtype
            return grad * tf.cast(grad > 0., dtype) * \
                tf.cast(op.inputs[0] > 0., dtype)
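
# The override implements the guided-backprop rule for ReLU: gradients flow
# only where the incoming gradient is positive (grad > 0, the "deconvnet"
# condition) AND the forward-pass input was positive (op.inputs[0] > 0, the
# standard ReLU backprop condition).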
def modify_backprop(model, name):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Relu': name}):
        # get layers that have an activation
        layer_dict = [layer for layer in model.layers[1:]
                      if hasattr(layer, 'activation')]

        # replace Keras's relu activation with tf.nn.relu so that the
        # gradient override on 'Relu' ops takes effect
        for layer in layer_dict:
            if layer.activation == keras.activations.relu:
                layer.activation = tf.nn.relu

        # re-instantiate a new model
        new_model = VGG16(weights='imagenet')
    return new_model
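
# Note: the fresh VGG16 must be built INSIDE the gradient_override_map
# context; the override only affects 'Relu' ops created while the context is
# active, so the original model keeps its standard gradients.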
def guided_backpropagation(img_tensor, model, activation_layer):
    model_input = model.input
    layer_output = model.get_layer(activation_layer).output

    # one_output = layer_output[:, :, :, 256]
    max_output = K.max(layer_output, axis=3)

    get_output = K.function([model_input], [K.gradients(max_output, model_input)[0]])
    # get_output = K.function([model_input], [K.gradients(one_output, model_input)[0]])
    saliency = get_output([img_tensor])

    return saliency[0]
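
# The saliency target here is the channel-wise maximum of the chosen layer's
# feature maps (K.max over axis 3); the commented-out one_output variant
# backpropagates from a single filter (index 256) instead. Either way, the
# returned gradient has the same shape as the input image tensor.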
def generate_gradcam(img_tensor, model, class_index, activation_layer):
    model_input = model.input

    # y_c: the input to the model's final activation op (softmax, linear, ...)
    # for the given class_index, i.e. the pre-softmax class score
    y_c = model.outputs[0].op.inputs[0][0, class_index]
    # y_c = model.outputs[0].op.inputs[0].op.inputs[0][0, class_index]

    # A_k: output feature maps of the chosen activation (conv) layer
    A_k = model.get_layer(activation_layer).output

    # For the model input, compute the activation layer's output (A_k),
    # the gradient of the class score (y_c) with respect to A_k,
    # and the model's final output (prediction)
    get_output = K.function([model_input], [A_k, K.gradients(y_c, A_k)[0], model.output])
    [conv_output, grad_val, prediction] = get_output([img_tensor])

    # Drop the batch dimension: (1, width, height, k) -> (width, height, k);
    # width and height here are the spatial dimensions of the A_k feature maps
    conv_output = conv_output[0]
    grad_val = grad_val[0]

    # Global average pooling: average the gradients over width and height
    # (the 1/Z factor) to obtain the channel weights a^c_k
    weights = np.mean(grad_val, axis=(0, 1))

    # Weighted combination over k of the feature maps (conv_output) and the
    # class-specific weights a^c_k; initialize with the feature map's
    # (width, height) shape
    grad_cam = np.zeros(dtype=np.float32, shape=conv_output.shape[0:2])
    for k, w in enumerate(weights):
        grad_cam += w * conv_output[:, :, k]

    # Apply ReLU to the weighted combination
    grad_cam = np.maximum(grad_cam, 0)

    return grad_cam, weights
if __name__ == "__main__":
    img_width = 224
    img_height = 224

    model = VGG16(weights='imagenet')
    model.summary()

    img_path = '../image/cat_dog.jpg'
    img = load_image(path=img_path, target_size=(img_width, img_height))

    preds = model.predict(img)
    predicted_class = preds.argmax(axis=1)[0]
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print("predicted top1 class:", predicted_class)
    print('Predicted:', decode_predictions(preds, top=1)[0])

    conv_name = 'block5_conv3'  # last conv layer of VGG16

    grad_cam, grad_val = generate_gradcam(img, model, predicted_class, conv_name)

    register_gradient()
    guided_model = modify_backprop(model, 'GuidedBackProp')
    gradient = guided_backpropagation(img, guided_model, conv_name)

    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (img_width, img_height))

    grad_cam = cv2.resize(grad_cam, (img_width, img_height))
    plt.figure(0)
    plt.imshow(img)
    plt.imshow(grad_cam, cmap="jet", alpha=.5)
    plt.axis('off')

    plt.figure(1)
    # pass a copy: deprocess_image normalizes in place, and gradient is
    # reused for guided Grad-CAM below
    plt.imshow(deprocess_image(gradient.copy()))
    plt.axis('off')

    plt.figure(2)
    plt.imshow(grad_cam)
    plt.axis('off')

    # Guided Grad-CAM: elementwise product of the guided-backprop gradients
    # and the upsampled Grad-CAM heatmap
    guided_gradcam = gradient * grad_cam[..., np.newaxis]
    plt.figure(3)
    plt.imshow(deprocess_image(guided_gradcam))
    plt.axis('off')

    plt.show()

    # cv2.imshow expects float images in [0, 1], so normalize the heatmap
    cv2.imshow('heatmap', grad_cam / (grad_cam.max() + 1e-10))
    cv2.waitKey()
    cv2.destroyAllWindows()
# utils.py (imported above via `from utils import deprocess_image`)
import keras.backend as K
import numpy as np


def deprocess_image(x):
    '''
    Same normalization as in:
    https://github.com/fchollet/keras/blob/master/examples/conv_filter_visualization.py
    '''
    if np.ndim(x) > 3:
        x = np.squeeze(x)

    # normalize tensor: center on 0., ensure std is 0.1
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1

    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)

    # convert to RGB array
    x *= 255
    if K.image_dim_ordering() == 'th':
        x = x.transpose((1, 2, 0))
    x = np.clip(x, 0, 255).astype('uint8')
    return x
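
# Caveat: the augmented assignments above (x -= ..., x /= ...) modify the
# input array in place (even after np.squeeze, which returns a view), which
# is why the caller above passes a copy when the gradients are reused.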