# https://github.com/samiraabnar/attention_flow
# https://github.com/google-research/vision_transformer/issues/27
# https://github.com/google-research/vision_transformer/issues/18
# https://github.com/faustomorales/vit-keras/blob/65724adcfd3979067ce24734f08df0afa745637d/vit_keras/visualize.py#L7-L45
# https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import cv2

# Per-layer softmax attention maps captured by the forward hooks, keyed by "attn{idx}".
activation = {}

def get_attn_softmax(name):
    # Returns a forward hook that recomputes the softmax attention of a timm
    # Attention block from its qkv projection and stores it in `activation`.
    def hook(model, input, output):
        with torch.no_grad():
            input = input[0]
            B, N, C = input.shape
            qkv = (
                model.qkv(input)
                .detach()
                .reshape(B, N, 3, model.num_heads, C // model.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
            q, k, v = (
                qkv[0],
                qkv[1],
                qkv[2],
            )  # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * model.scale
            attn = attn.softmax(dim=-1)
            activation[name] = attn
    return hook


# expects a timm vision transformer model
def add_attn_vis_hook(model):
    for idx, module in enumerate(list(model.blocks.children())):
        module.attn.register_forward_hook(get_attn_softmax(f"attn{idx}"))
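
# A quick sanity check of what the hooks capture (a sketch; assumes
# vit_base_patch16_224, i.e. 12 blocks, 12 heads, 196 patch tokens + 1 class token):
#
#   model = timm.create_model("vit_base_patch16_224", pretrained=True)
#   add_attn_vis_hook(model)
#   _ = model(torch.randn(1, 3, 224, 224))
#   print(activation["attn0"].shape)  # expected: torch.Size([1, 12, 197, 197])
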

def get_mask(im, att_mat):
    # Average the attention weights across all heads.
    # att_mat, _ = torch.max(att_mat, dim=1)
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices (attention rollout).
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    # Attention from the output token to the input space.
    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")
    return result, joint_attentions, grid_size
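
# A minimal sketch of the rollout recursion above, on a toy stack of two 3x3
# attention maps (hypothetical numbers; paste into a separate session to check
# that each row of the result is still a probability distribution):
#
#   att = torch.full((2, 1, 3, 3), 1.0 / 3)    # (layers, heads, tokens, tokens), uniform
#   aug = att.mean(dim=1) + torch.eye(3)       # add identity for the residual path
#   aug = aug / aug.sum(dim=-1, keepdim=True)  # re-normalize rows
#   joint = aug[0]
#   for n in range(1, aug.size(0)):
#       joint = aug[n] @ joint                 # multiply layer by layer
#   print(joint)                               # rows still sum to 1
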

def show_attention_map(model, img_path, shape):
    activation.clear()  # drop maps captured for any previously visualized model
    add_attn_vis_hook(model)

    # Convert to RGB so the 3-channel Normalize below always applies.
    im = Image.open(os.path.expandvars(img_path)).convert("RGB")
    im = im.resize((shape, shape))
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    logits = model(transform(im).unsqueeze(0))

    # Concatenate the per-layer maps into one (layers, heads, tokens, tokens) tensor.
    attn_weights_list = list(activation.values())
    result, joint_attentions, grid_size = get_mask(im, torch.cat(attn_weights_list))

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map')
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)

    probs = torch.nn.Softmax(dim=-1)(logits)
    top5 = torch.argsort(probs, dim=-1, descending=True)
    print("Prediction Label and Attention Map!\n")
    for idx in top5[0, :5]:
        print(f'{probs[0, idx.item()]:.5f} : {idx.item()}')

    for i, v in enumerate(joint_attentions):
        # Attention from the output token to the input space, per layer.
        mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
        mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
        result = (mask * im).astype("uint8")
        fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
        ax1.set_title('Original')
        ax2.set_title('Attention Map (Layer %d)' % (i + 1))
        _ = ax1.imshow(im)
        _ = ax2.imshow(result)

    plt.show()

if __name__ == "__main__":
    import sys
    import timm

    model_names = timm.list_models("vit*")
    for model_name in model_names:
        print(f"\n{model_name}\n")
        m = timm.create_model(model_name, pretrained=True)
        # Assumes the model name ends in its input resolution, e.g. "vit_base_patch16_224".
        shape = int(model_name[-3:])
        show_attention_map(m, sys.argv[1], shape)
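
# Example usage (a sketch; the script filename and image path below are
# placeholders, and pretrained weights are downloaded by timm on first use):
#
#   python this_gist.py /path/to/image.jpg
#
# or, for a single model from an interactive session:
#
#   import timm
#   m = timm.create_model("vit_base_patch16_224", pretrained=True)
#   show_attention_map(m, "/path/to/image.jpg", 224)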