Last active
August 31, 2024 19:43
-
-
Save simonster/155894d48aef2bd36bd2dd8267e62391 to your computer and use it in GitHub Desktop.
Mean attention distance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2022 Google LLC.
# SPDX-License-Identifier: Apache-2.0
# Author: Maithra Raghu <maithra@google.com>
def compute_distance_matrix(patch_size, num_patches, length):
    """Compute pairwise scaled Euclidean distances between grid patches.

    Patch ``k`` is placed at grid cell ``(row, col) = divmod(k, length)``,
    i.e. patches are numbered row-major with ``length`` patches per row.

    Args:
        patch_size: Multiplier applied to every grid-unit distance (the side
            length of one patch).
        num_patches: Number of patches; the result is
            ``(num_patches, num_patches)``.
        length: Number of patches per grid row.

    Returns:
        A float64 array ``d`` of shape ``(num_patches, num_patches)`` with
        ``d[i, j] = patch_size * ||cell(i) - cell(j)||_2`` (zero on the
        diagonal).

    Raises:
        ValueError: If ``length`` is not positive while ``num_patches > 0``.
    """
    import numpy as np

    if num_patches > 0 and length < 1:
        raise ValueError(f"length must be positive, got {length}")

    # Integer divmod gives exact (row, col) grid coordinates without the
    # float round-trip of int(i / length).
    rows, cols = np.divmod(np.arange(num_patches), length)
    # Broadcast to all (i, j) pairs at once instead of an O(n^2) Python loop.
    d_row = rows[:, None] - rows[None, :]
    d_col = cols[:, None] - cols[None, :]
    # sqrt of the integer sum of squares: same values as np.linalg.norm of
    # the 2-vector, computed in float64.
    return patch_size * np.sqrt(d_row ** 2 + d_col ** 2)
def compute_mean_attention_dist(patch_size, attention_weights):
    """Compute the mean attention distance of self-attention weights.

    For each query patch, the key-patch distances are weighted by the
    attention weights and summed over keys. These per-query values are then
    averaged over all query patches. When each row of weights sums to 1
    (presumably softmax output — confirm at call site), the per-query sum is
    the average attended distance for that token.

    Args:
        patch_size: Multiplier applied to grid-unit distances (side length of
            one patch).
        attention_weights: Array whose last two axes are
            ``(num_patches, num_patches)``, read as (query, key). Patches are
            laid out row-major on a square grid. Any leading axes are
            preserved.

    Returns:
        Array of shape ``attention_weights.shape[:-2]`` with the mean
        attention distance for each leading index.

    Raises:
        ValueError: If the last two axes are missing or differ in size, or
            their size is not a perfect square.
    """
    import math

    import numpy as np

    shape = attention_weights.shape
    if len(shape) < 2 or shape[-2] != shape[-1]:
        raise ValueError(
            f"attention_weights must end in two equal axes, got shape {tuple(shape)}"
        )
    num_patches = shape[-1]
    # Exact integer square root; avoids float rounding of np.sqrt.
    length = math.isqrt(num_patches)
    # Explicit raise instead of assert, so the check survives `python -O`.
    if length * length != num_patches:
        raise ValueError(f"Num patches is not perfect square: {num_patches}")

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
    # The 2-D (num_patches, num_patches) matrix broadcasts against any number
    # of leading axes without inserting extra singleton dimensions.
    weighted = attention_weights * distance_matrix
    per_query = np.sum(weighted, axis=-1)  # weighted distance per query token
    return np.mean(per_query, axis=-1)  # average across all query tokens
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment