Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
def get_token_indices(model, layer_name, threshold, matrix, y_labels):
    """Map high-scoring positions of a layer's heatmap back to input token indices.

    Obtains a heatmap for ``layer_name`` via ``get_heatmap``, keeps the layer
    positions whose heatmap value is strictly greater than ``threshold``,
    rescales each kept position to the input length, and marks a small window
    of neighbouring input tokens around the rescaled position.

    Args:
        model: Model passed through to ``get_heatmap`` and ``get_conv_layer``.
        layer_name: Name of the layer to inspect.
        threshold: Heatmap values strictly greater than this are kept.
        matrix: Model input; ``matrix.shape[1]`` is used as the input length
            (presumably the token/sequence axis -- confirm against callers).
        y_labels: Passed through to ``get_heatmap``.

    Returns:
        dict mapping input token index -> the highest heatmap value among all
        kept layer positions whose window covers that token. Only indices in
        ``[0, matrix.shape[1])`` are produced.
    """
    # NOTE(review): assumes get_heatmap returns a 1-D array indexed by layer
    # output position, so heatmap[i] is a scalar -- confirm.
    heatmap = get_heatmap(model=model, layer_name=layer_name, matrix=matrix, y_labels=y_labels)
    _, output_dim = get_conv_layer(model, layer_name)
    input_len = matrix.shape[1]

    # Depending on the ratio between the input length and the layer output
    # length, estimate how many original tokens contributed to one layer output.
    dim_ratio = input_len / output_dim
    # Near 1:1 mapping -> narrow window; downsampled layer -> wider window.
    window_size = 1 if dim_ratio < 1.5 else 2

    indices = {}
    for i in np.where(heatmap > threshold)[0].tolist():
        # Multiply by the exact ratio before truncating. Truncating the ratio
        # first would make the mapped position drift further from the true
        # one as i grows, whenever the ratio is not an integer.
        scaled_index = int(i * dim_ratio)
        score = heatmap[i]
        # Clamp the window so that only valid token positions are emitted.
        lo = max(scaled_index - window_size, 0)
        hi = min(scaled_index + window_size, input_len - 1)
        for ind in range(lo, hi + 1):
            # Keep the strongest score seen for each token.
            if ind not in indices or indices[ind] < score:
                indices[ind] = score
    return indices
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.