Skip to content

Instantly share code, notes, and snippets.

@BeBeBerr
Last active June 12, 2021 09:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BeBeBerr/5af065430dece675f2b585f260108998 to your computer and use it in GitHub Desktop.
Save BeBeBerr/5af065430dece675f2b585f260108998 to your computer and use it in GitHub Desktop.
Grad-CAM with MobileNet v2
import torch
import numpy as np
from torchvision import datasets, transforms, models
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM) for a PyTorch model.

    One sub-module of ``model`` is selected by its position in
    ``model.named_modules()``. Its forward output (the "feature map") and the
    gradient of a class score with respect to that output are captured and
    combined into a coarse localisation heatmap.

    Usage: construct the object, run ``prediction = model(x)`` with autograd
    enabled, then call ``grad_cam(prediction, class_index)``.
    """

    def __init__(self, model, layer_index=-6):
        # -6 is the last conv2d layer for mobilenet v2
        self.model = model
        self.layer_index = layer_index
        self.feature_map = None       # output of the hooked layer, last forward
        self.feature_map_grad = None  # d(score)/d(feature_map), last backward
        self._handles = []            # RemovableHandles of hooks we attached
        self.register_hooks()

    def _forward_hook(self, module, inputs, output):
        """Store the hooked layer's output and arrange to capture its gradient."""
        self.feature_map = output
        # A tensor hook receives exactly the gradient w.r.t. this output and
        # avoids the deprecated Module.register_backward_hook. Registering a
        # hook on a tensor that autograd does not track raises, so skip it
        # when the forward ran without gradients (e.g. under torch.no_grad()).
        if output.requires_grad:
            output.register_hook(self._save_gradient)

    def _save_gradient(self, grad):
        """Tensor hook: remember the gradient flowing into the hooked output."""
        self.feature_map_grad = grad

    def register_hooks(self):
        """Attach the forward hook to the layer selected by ``layer_index``.

        Hooks previously attached by this instance are removed first, so
        calling this again does not stack duplicate hooks.
        """
        self.remove_hooks()
        _, layer = list(self.model.named_modules())[self.layer_index]
        self._handles.append(layer.register_forward_hook(self._forward_hook))

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in getattr(self, '_handles', []):
            handle.remove()
        self._handles = []

    def __call__(self, prediction, class_index, retain_graph=False):
        """Compute the Grad-CAM heatmap for ``class_index``.

        Args:
            prediction: output of a forward pass through ``self.model`` made
                with autograd enabled. Only ``prediction[0, class_index]`` is
                back-propagated, so only the first sample of a batch is
                explained (other samples receive zero gradient).
            class_index: index of the class score to explain.
            retain_graph: keep the autograd graph after backward so the same
                ``prediction`` can be explained for another class.

        Returns:
            A non-negative tensor detached from the autograd graph, obtained by
            summing the weighted feature map over dim 1; for an (N, C, H, W)
            activation this is (N, H, W).

        Raises:
            RuntimeError: if no activation or no gradient was captured at the
                hooked layer.
        """
        if self.feature_map is None:
            raise RuntimeError(
                'GradCAM: no activations captured; run a forward pass '
                'through the model (with autograd enabled) first')
        self.model.zero_grad()
        # Clear any stale gradient so a missing hook call is detected below.
        self.feature_map_grad = None
        score = prediction[0, class_index]
        score.backward(retain_graph=retain_graph)
        if self.feature_map_grad is None:
            raise RuntimeError(
                'GradCAM: no gradient reached the hooked layer; check '
                'layer_index and that the forward pass tracked gradients')
        with torch.no_grad():
            # Channel weights: gradients averaged over the two spatial dims.
            alpha = self.feature_map_grad.mean(dim=(-1, -2), keepdim=True)
            heatmap = (self.feature_map * alpha).sum(1)
            # Keep only regions with a positive influence on the class score.
            heatmap = F.relu(heatmap)
        return heatmap
def main(checkpoint_path='checkpoints/baseline.pth.tar',
         image_path='image_0042.jpg',
         num_classes=102):
    """Explain the top-1 prediction of a MobileNetV2 checkpoint on one image.

    Loads ``checkpoint['state_dict']`` into a MobileNetV2 with ``num_classes``
    outputs, runs ``image_path`` through it, computes a Grad-CAM heatmap for
    the arg-max class and writes to the working directory:
      * ``heatmap_small.jpg`` - raw heatmap at the hooked layer's resolution,
      * ``heatmap.jpg``       - heatmap upsampled to the 224x224 network input,
      * ``blend.jpg``         - heatmap overlaid 50/50 on that same 224x224 crop.

    Args:
        checkpoint_path: path of a ``torch.save``-d dict with a 'state_dict' key.
        image_path: image to explain.
        num_classes: number of outputs of the trained classifier.
    """
    crop_size = 224

    model = models.MobileNetV2(num_classes=num_classes)
    # Everything below runs on the CPU; map_location lets a checkpoint that
    # was saved from a GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    # The geometric steps are kept separate so the exact same crop can be
    # reused for the overlay; otherwise the heatmap would not line up.
    geometry = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(crop_size),
    ])
    to_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # convert('RGB') guarantees 3 channels for Normalize and a mode that
    # matches the JPEG heatmap in Image.blend (L/RGBA/CMYK inputs would fail).
    with Image.open(image_path) as raw_image:
        image_origin = raw_image.convert('RGB')
    image_cropped = geometry(image_origin)
    image = to_input(image_cropped).unsqueeze(0)

    grad_cam = GradCAM(model)
    model.eval()
    # Autograd must stay enabled here: Grad-CAM back-propagates the score.
    output = model(image)
    index = output.argmax()
    heatmap = grad_cam(output, index)[0].detach().numpy()

    plt.imsave('heatmap_small.jpg', heatmap, cmap='rainbow')
    heatmap = cv2.resize(heatmap, (crop_size, crop_size), interpolation=cv2.INTER_CUBIC)
    plt.imsave('heatmap.jpg', heatmap, cmap='rainbow')

    with Image.open('heatmap.jpg') as raw_heatmap:
        heatmap_image = raw_heatmap.convert('RGB').resize((crop_size, crop_size))
    # Blend onto the Resize(256)+CenterCrop(224) region the network actually
    # saw, not a squashed copy of the whole original image.
    blend = Image.blend(image_cropped, heatmap_image, 0.5)
    blend.save('blend.jpg')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment