# Source: GitHub Gist maxidl/b1019ecb27d5887221500d611f21cf0d
# (last active February 9, 2022 13:12).
import sys
import os # noqa
sys.path.insert(0, ".") # noqa
import torch
from utils.styled_plot import plt
from utils.dataset import load_test_image, preprocess_image, normalize_image, convert_idx_to_label
from classifiers.cnn_classifier import ImageNetClassifier
from solutions.explainers import plot_attributions, aggregate_attribution, normalize_attribution
def get_custom_gradient(model, image, absolute=False):
    """Generate a saliency map from the input gradient of the model's predicted class.

    The model output for its own predicted class is differentiated with respect
    to the input pixels. The caller's ``image`` is left untouched: the gradient
    is taken w.r.t. a detached view of it that is marked as requiring grad.

    Args:
        model (ImageNetClassifier): Image classification model. Has a 'predict'
            method that returns the predicted label index for an image, and can
            be called as ``model(x)`` to obtain the network's outputs.
        image (torch.tensor): Single image with shape (1, 3, ?, ?).
        absolute (bool): If True, return the absolute value of the gradients.
            If False, return the signed gradients.

    Returns:
        attribution (torch.tensor): The gradient, of the same shape as the image,
            detached from the autograd graph.
    """
    prediction = model.predict(image)
    # Differentiate w.r.t. a detached leaf instead of setting ``requires_grad`` on
    # the caller's tensor: this does not mutate the input, and it also works when
    # ``image`` is a non-leaf tensor (for which setting ``requires_grad`` raises).
    inputs = image.detach().requires_grad_(True)
    output = model(inputs)
    # output[:, prediction] selects the single predicted-class score of the one
    # image, so autograd can create the implicit scalar grad_output.
    grad = torch.autograd.grad(output[:, prediction], inputs)[0]
    if absolute:
        grad = grad.abs()
    return grad.detach()
def get_path(image, baseline, num_samples):
    """Build the straight-line interpolation path used by Integrated Gradients.

    Each element is ``baseline + alpha * (image - baseline)``, where ``alpha``
    takes ``num_samples`` evenly spaced values from 0 to 1 (``torch.linspace``).
    The first element is therefore the baseline itself. When
    ``num_samples >= 2``, the last element corresponds to ``alpha == 1``, i.e.
    the image.

    Args:
        image (torch.tensor): Single image with shape (1, 3, ?, ?).
        baseline (torch.tensor): Baseline image with the same shape as image.
        num_samples (int): The number of samples on the path.

    Returns:
        path (List[torch.tensor]): ``num_samples`` tensors ordered from the
            baseline towards the image.
    """
    delta = image - baseline
    return [baseline + fraction * delta for fraction in torch.linspace(0, 1, num_samples)]
def get_custom_integrated_gradients(model, image, num_samples):
    """Generate an attribution map with Integrated Gradients w.r.t. the model's prediction.

    The baseline is an all-zeros (black) image passed through 'normalize_image',
    so it is expressed in the same normalized space as ``image``. The gradient of
    the predicted class score is evaluated at every point of the path returned
    by 'get_path'. These gradients are averaged and then multiplied element-wise
    by ``image - baseline``.

    Args:
        model (ImageNetClassifier): Image classification model. Has a 'predict'
            method that returns the predicted label index for an image, and can
            be called as ``model(x)`` to obtain the network's outputs.
        image (torch.tensor): Single image with shape (1, 3, ?, ?).
        num_samples (int): Number of points on the baseline-to-image path (as
            built by 'get_path') at which the gradient is evaluated. Must be at
            least 1.

    Returns:
        attributions (torch.tensor): The integrated gradients, of the same shape
            as the image, detached from the autograd graph.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1 (the average over the
            path would otherwise be undefined).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    prediction = model.predict(image)
    # NOTE(review): normalize_image is applied here to a batched (1, 3, H, W)
    # tensor; assumes it broadcasts over the leading batch dim — TODO confirm.
    baseline = normalize_image(torch.zeros_like(image))
    path = get_path(image, baseline, num_samples)

    # Keep a running sum instead of a list of gradients so memory does not grow
    # with num_samples.
    grad_sum = torch.zeros_like(image)
    for point in path:
        # Detach so each path point is a fresh leaf we can differentiate w.r.t.,
        # even if the path tensors carry autograd history.
        point = point.detach().requires_grad_(True)
        output = model(point)
        grad_sum += torch.autograd.grad(output[:, prediction], point)[0]

    # Riemann-style approximation of the path integral: mean gradient times the
    # input displacement from the baseline.
    avg_grad = grad_sum / len(path)
    ig = (image - baseline) * avg_grad
    return ig.detach()
if __name__ == "__main__":
    # Load the bundled test image, preprocess it, normalize it and add a batch dim.
    raw_image = load_test_image()
    image_preprocessed = preprocess_image(raw_image)
    model_input = normalize_image(image_preprocessed).unsqueeze(0)

    model = ImageNetClassifier()
    y_pred, y_prob = model.predict(model_input, return_probs=True)
    print(f'Predicted class: "{convert_idx_to_label(y_pred.item())}". Confidence: {y_prob.item() * 100:.2f}%')
    # Sanity-check the expected prediction (class index 13, ~94.83% confidence).
    assert y_pred == torch.tensor([13])
    assert torch.allclose(y_prob, torch.tensor([0.9483]), atol=1e-4)

    # Each explainer receives its own clone, so any change it makes to its input
    # tensor (e.g. requires_grad) cannot leak into model_input.
    print('Run `get_custom_gradient` ...')
    gradient = get_custom_gradient(model, model_input.clone())
    gradient_abs = get_custom_gradient(model, model_input.clone(), absolute=True)

    print('Run `get_custom_integrated_gradients` ...')
    ig = get_custom_integrated_gradients(model, model_input.clone(), num_samples=50)

    print('Run `plot_attributions` ...')
    titles = ['Gradient', 'abs. Gradient', 'IG']
    # Aggregate every attribution first, then normalize every aggregated map.
    aggregated = [aggregate_attribution(attr) for attr in (gradient, gradient_abs, ig)]
    normalized = [normalize_attribution(attr) for attr in aggregated]
    plot_attributions(plt, image_preprocessed, normalized, titles)
# (End of GitHub Gist page.)