Creating a plot showing the attention/context of a semantic search by calculating the cosine similarity between token embeddings
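The snippet assumes that `valid_sentences`, `tokenizer`, `encoded_input`, and `model_output` already exist from an earlier step in the gist. A minimal setup sketch, assuming a Hugging Face RoBERTa-large encoder (the model name, placeholder sentences, and the no_grad block are assumptions, not part of the original):

# Hypothetical setup (assumption): encode the sentences with a 1024-dim transformer encoder
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("roberta-large")  # assumed model name
model = AutoModel.from_pretrained("roberta-large")

valid_sentences = ["example query sentence", "example match sentence"]  # placeholder data
encoded_input = tokenizer(valid_sentences, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    model_output = model(**encoded_input)  # model_output[0]: last hidden state, shape (batch, seq_len, 1024)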
# For each sentence, store a list of token embeddings; i.e. a 1024-dimensional vector for each token
token_embeddings = []
for i, sentence in enumerate(valid_sentences):
    tokens = tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][i])
    embeddings = model_output[0][i]
    token_embeddings.append(
        [{"token": token, "embedding": embedding.detach().numpy()} for token, embedding in zip(tokens, embeddings)]
    )
from sklearn.preprocessing import normalize

def get_token_embeddings(embeddings_word):
    """Returns a list of tokens and a matrix of L2-normalized embeddings"""
    tokens, embeddings = [], []
    for word in embeddings_word:
        # Skip special tokens and strip the 'Ġ' whitespace marker used by BPE tokenizers
        if word['token'] not in ['<s>', '<pad>', '</pad>', '</s>']:
            tokens.append(word['token'].replace('Ġ', ''))
            embeddings.append(word['embedding'])
    norm_embeddings = normalize(embeddings, norm='l2')
    return tokens, norm_embeddings
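Because `normalize(..., norm='l2')` scales every embedding to unit length, the dot product of two normalized embeddings equals their cosine similarity. A quick sanity check with toy vectors (the values are made up for illustration):

import numpy as np
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity

vectors = np.array([[1.0, 2.0, 3.0], [0.5, 0.1, 0.9]])  # toy vectors, not real embeddings
unit_vectors = normalize(vectors, norm='l2')
# Dot products of unit vectors match the cosine similarities of the raw vectors
assert np.allclose(unit_vectors @ unit_vectors.T, cosine_similarity(vectors))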
# Get tokens & token embeddings for both query & search match
query_tokens, query_token_embeddings = get_token_embeddings(token_embeddings[QUERY_ID])
match_tokens, match_token_embeddings = get_token_embeddings(token_embeddings[MATCH_ID])
# Calculate cosine similarity between all pairs of token embeddings;
# since the embeddings are L2-normalized, the dot product equals the cosine similarity
attention = query_token_embeddings @ match_token_embeddings.T
# Plot the attention matrix with the tokens on x and y axes
plot_attention(match_tokens, query_tokens, attention)
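`plot_attention` is not defined in this snippet and presumably lives elsewhere in the gist. A minimal sketch of what such a helper could look like with matplotlib (figure size, colormap, and label rotation are assumptions, not the original implementation):

import matplotlib.pyplot as plt

def plot_attention(x_tokens, y_tokens, attention):
    # Heatmap of the token-to-token similarity matrix: rows = y_tokens, columns = x_tokens
    fig, ax = plt.subplots(figsize=(0.5 * len(x_tokens) + 2, 0.5 * len(y_tokens) + 2))
    im = ax.imshow(attention, cmap="viridis")
    ax.set_xticks(range(len(x_tokens)))
    ax.set_xticklabels(x_tokens, rotation=90)
    ax.set_yticks(range(len(y_tokens)))
    ax.set_yticklabels(y_tokens)
    fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()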