# Mean attention distance
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# Author: Maithra Raghu <>
def compute_distance_matrix(patch_size, num_patches, length):
    """Helper function to compute the pairwise patch distance matrix.

    Patch index ``k`` is mapped to grid coordinates
    ``(k // length, k % length)``. Entry ``[i, j]`` of the result is the
    Euclidean distance between the grid coordinates of patches ``i`` and
    ``j``, multiplied by ``patch_size``.

    Args:
        patch_size: Scale factor applied to every grid distance.
        num_patches: Number of patches; the result has shape
            ``(num_patches, num_patches)``.
        length: Number of patches per grid row, used to convert a flat
            patch index into (row, column) coordinates.

    Returns:
        A float ``np.ndarray`` of shape ``(num_patches, num_patches)``.
        The diagonal is zero (a patch is at zero distance from itself).
    """
    patch_ids = np.arange(num_patches)
    # Flat index -> (row, column) on the patch grid.
    rows, cols = np.divmod(patch_ids, length)

    # Broadcast a column vector against a row vector to get all pairwise
    # coordinate differences at once instead of looping over (i, j).
    row_diff = rows[:, np.newaxis] - rows[np.newaxis, :]
    col_diff = cols[:, np.newaxis] - cols[np.newaxis, :]

    # Euclidean norm of (row_diff, col_diff); the diagonal is sqrt(0) == 0,
    # so no special case is needed for i == j.
    return patch_size * np.sqrt(row_diff**2 + col_diff**2)
def compute_mean_attention_dist(patch_size, attention_weights):
    """Compute the attention-weighted mean patch distance.

    The size of the last axis of ``attention_weights`` is taken as the
    number of patches and must be a perfect square (patches on a square
    grid). The pairwise distance matrix from ``compute_distance_matrix`` is
    multiplied elementwise with the attention weights, summed over the last
    axis, then averaged over the next-to-last axis.

    Args:
        patch_size: Scale factor forwarded to ``compute_distance_matrix``.
        attention_weights: Array whose last two axes are broadcast against
            the ``(num_patches, num_patches)`` distance matrix.
            NOTE(review): presumably shaped (batch, heads, tokens, tokens)
            with each row summing to 1 -- confirm against the caller.

    Returns:
        The weighted distances with the last two axes reduced away.

    Raises:
        AssertionError: If the number of patches is not a perfect square.
    """
    token_count = attention_weights.shape[-1]
    grid_side = int(np.sqrt(token_count))
    assert grid_side * grid_side == token_count, "Num patches is not perfect square"

    pairwise = compute_distance_matrix(patch_size, token_count, grid_side)
    n_rows, n_cols = pairwise.shape
    # Two leading singleton axes so the matrix broadcasts over the leading
    # axes of the attention weights.
    pairwise = pairwise.reshape((1, 1, n_rows, n_cols))

    # Weight each distance by its attention, then sum over the last axis to
    # get one distance per token.
    per_token = np.sum(attention_weights * pairwise, axis=-1)
    # Average the per-token distances across all tokens.
    return np.mean(per_token, axis=-1)
