Last active
April 16, 2018 02:47
-
-
Save hanneshapke/11a7ec5a236e41b74aaae74c74a8856e to your computer and use it in GitHub Desktop.
get_token_indices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_token_indices(model, layer_name, threshold, matrix, y_labels):
    """Map above-threshold heatmap positions of a layer back to input positions.

    Computes a heatmap for ``layer_name`` via ``get_heatmap``. For every heatmap
    position whose value is strictly greater than ``threshold``, it marks a
    small window of input positions around the corresponding scaled position.

    Args:
        model: Passed through to ``get_heatmap`` and ``get_conv_layer``.
        layer_name: Name of the layer whose heatmap is inspected.
        threshold: Heatmap values strictly greater than this are kept.
        matrix: Passed to ``get_heatmap``. ``matrix.shape[1]`` is used as the
            input length (presumably the number of tokens per sample;
            TODO confirm against callers).
        y_labels: Passed through to ``get_heatmap``.

    Returns:
        dict mapping input position -> the largest heatmap value whose window
        covered that position. Positions are clamped to
        ``[0, matrix.shape[1] - 1]``.
    """
    heatmap = get_heatmap(model=model, layer_name=layer_name, matrix=matrix, y_labels=y_labels)
    _, output_dim = get_conv_layer(model, layer_name)
    input_len = matrix.shape[1]

    # Depending on the ratio between the input length and the layer output
    # length, estimate how many original tokens contributed to one output position.
    dim_ratio = input_len / output_dim
    window_size = 1 if dim_ratio < 1.5 else 2

    indices = {}
    for i in np.where(heatmap > threshold)[0].tolist():
        # Scale first, then round. Truncating the ratio (int(dim_ratio)) before
        # multiplying drifts the mapped position further from the true one as i grows.
        scaled_index = int(round(i * dim_ratio))
        # Clamp the window to valid input positions. Negative keys would
        # silently wrap around when callers use them as Python indices.
        start = max(scaled_index - window_size, 0)
        stop = min(scaled_index + window_size, input_len - 1)
        value = heatmap[i]
        for ind in range(start, stop + 1):
            # Keep the strongest activation seen for each input position.
            if ind not in indices or indices[ind] < value:
                indices[ind] = value
    return indices
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment