Creating a plot that shows the attention/context of a semantic search by computing the cosine similarity between token embeddings
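The snippet below assumes `valid_sentences`, `tokenizer`, `encoded_input`, `model_output`, `QUERY_ID`, and `MATCH_ID` already exist. Here is a minimal setup sketch, assuming a Hugging Face transformers encoder; the `roberta-large` checkpoint, the example sentences, and the index values are assumptions (any encoder whose hidden size matches the 1024-dimensional vectors mentioned below would do):

# Hypothetical setup, not part of the original gist: load an encoder and
# run the sentences through it to obtain per-token embeddings.
import torch
from transformers import AutoTokenizer, AutoModel

valid_sentences = [
    "The cat sat on the mat.",      # query sentence (assumed example)
    "A kitten rested on the rug.",  # search match (assumed example)
]
QUERY_ID, MATCH_ID = 0, 1  # indices into valid_sentences (assumed)

# 'roberta-large' is an assumption; its hidden size (1024) matches the
# "1024-dimensional vector" mentioned in the comment below.
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModel.from_pretrained("roberta-large")

encoded_input = tokenizer(valid_sentences, padding=True, return_tensors="pt")
with torch.no_grad():
    model_output = model(**encoded_input)  # model_output[0] is the last hidden state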
from sklearn.preprocessing import normalize

# For each sentence, store a list of token embeddings; i.e. a 1024-dimensional vector for each token
token_embeddings = []
for i, sentence in enumerate(valid_sentences):
    tokens = tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][i])
    embeddings = model_output[0][i]
    token_embeddings.append(
        [{"token": token, "embedding": embedding.detach().numpy()}
         for token, embedding in zip(tokens, embeddings)]
    )
def get_token_embeddings(embeddings_word):
    """Returns a list of tokens and a matrix of L2-normalized embeddings"""
    tokens, embeddings = [], []
    for word in embeddings_word:
        # Skip special tokens; keep only real word pieces
        if word['token'] not in ['<s>', '<pad>', '</pad>', '</s>']:
            # 'Ġ' is the BPE marker for a leading space; strip it for display
            tokens.append(word['token'].replace('Ġ', ''))
            embeddings.append(word['embedding'])
    norm_embeddings = normalize(embeddings, norm='l2')
    return tokens, norm_embeddings
# Get tokens & token embeddings for both query & search match
query_tokens, query_token_embeddings = get_token_embeddings(token_embeddings[QUERY_ID])
match_tokens, match_token_embeddings = get_token_embeddings(token_embeddings[MATCH_ID])

# Because the embeddings are L2-normalized, the dot product between any two
# token vectors equals their cosine similarity
attention = query_token_embeddings @ match_token_embeddings.T

# Plot the attention matrix with the tokens on x and y axes
plot_attention(match_tokens, query_tokens, attention)
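`plot_attention` is not defined in the gist itself. Below is a minimal sketch of a matching helper, assuming matplotlib; only the signature follows the call above, and the heatmap styling is an assumption:

# Hypothetical helper, not part of the original gist: heatmap of the
# token-by-token cosine-similarity matrix.
import matplotlib.pyplot as plt

def plot_attention(x_tokens, y_tokens, attention):
    """Plot the (len(y_tokens), len(x_tokens)) similarity matrix as a heatmap"""
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(attention, cmap="viridis")
    ax.set_xticks(range(len(x_tokens)))
    ax.set_xticklabels(x_tokens, rotation=90)
    ax.set_yticks(range(len(y_tokens)))
    ax.set_yticklabels(y_tokens)
    fig.colorbar(im, ax=ax, label="cosine similarity")
    plt.tight_layout()
    plt.show()

Since the embeddings are unit-length, every cell lies in [-1, 1], and brighter cells mark query/match token pairs whose contextual meanings are most similar.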