Skip to content

Instantly share code, notes, and snippets.

@talhaanwarch
Created June 12, 2021 17:00
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save talhaanwarch/f7cb08037ff59bdcd85df57c61ee94b4 to your computer and use it in GitHub Desktop.
Save talhaanwarch/f7cb08037ff59bdcd85df57c61ee94b4 to your computer and use it in GitHub Desktop.
Calculate sentence similarity using a pretrained BERT model
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 11 18:58:05 2021
# CALCULATE SENTENCE SIMILARITY
@author: TAC
"""
import torch#pytorch
from transformers import AutoTokenizer, AutoModel#for embeddings
from sklearn.metrics.pairwise import cosine_similarity#for similarity
# Fetch the pretrained uncased BERT tokenizer and encoder by checkpoint name.
# output_hidden_states=True asks the encoder to also return the hidden
# states of every layer (exposed as `.hidden_states` on its outputs).
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)
def get_embeddings(text, token_length, *, hf_tokenizer=None, hf_model=None):
    """Return a mean-pooled sentence embedding for *text*.

    The text is tokenized to exactly ``token_length`` tokens (truncated, or
    padded with ``padding='max_length'``), run through the encoder, and the
    final-layer token vectors are averaged. Only positions whose attention
    mask is 1 contribute to the average, so padding tokens do not dilute the
    embedding of short sentences.

    Args:
        text: sentence to embed.
        token_length: number of tokens to truncate/pad the input to.
        hf_tokenizer: tokenizer to use instead of the module-level
            ``tokenizer`` (optional, keyword-only).
        hf_model: encoder to use instead of the module-level ``model``
            (optional, keyword-only).

    Returns:
        A NumPy array with a single row (batch of one), e.g. shape
        ``(1, hidden_size)`` for a standard Hugging Face encoder, ready to pass
        to ``sklearn.metrics.pairwise.cosine_similarity``.
    """
    # Conditional expressions only evaluate the chosen branch, so the module
    # globals are looked up only when no override is supplied.
    tok = tokenizer if hf_tokenizer is None else hf_tokenizer
    encoder = model if hf_model is None else hf_model

    tokens = tok(text, max_length=token_length, padding='max_length', truncation=True)
    # unsqueeze(0) adds the batch dimension (a batch of one sentence).
    input_ids = torch.tensor(tokens.input_ids).unsqueeze(0)
    attention_mask = torch.tensor(tokens.attention_mask).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping (memory and time).
    with torch.no_grad():
        hidden = encoder(input_ids, attention_mask=attention_mask).last_hidden_state

    # Masked mean over the token axis: zero out padded positions, then divide
    # by the number of real tokens. clamp guards against an all-zero mask.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return (summed / counts).numpy()
def calculate_similarity(text1, text2, token_length=20, query=None):
    """Report which of two sentences is closer to a query sentence.

    Each sentence is embedded with ``get_embeddings`` and compared to the
    query embedding using cosine similarity. Both scores are printed,
    followed by a verdict line.

    Args:
        text1: first candidate sentence.
        text2: second candidate sentence.
        token_length: token length passed to ``get_embeddings``.
        query: sentence to compare against. When None (the default), it is
            read interactively from stdin.

    Returns:
        Tuple ``(sim1, sim2)``: cosine similarity of ``text1`` and ``text2``
        to the query.
    """
    if query is None:
        # Interactive fallback, read before any embeddings are computed.
        query = input('input your sentence \n')

    emb1 = get_embeddings(text1, token_length=token_length)
    emb2 = get_embeddings(text2, token_length=token_length)
    emb_query = get_embeddings(query, token_length=token_length)

    # cosine_similarity returns a (1, 1) matrix for two single-row inputs.
    sim1 = cosine_similarity(emb1, emb_query)[0][0]
    sim2 = cosine_similarity(emb2, emb_query)[0][0]
    print(sim1, sim2)

    if sim1 > sim2:
        print('sentence 1 is more similar to input sentence')
    elif sim2 > sim1:
        print('sentence 2 is more similar to input sentence')
    else:
        # Equal scores: do not arbitrarily declare sentence 2 the winner.
        print('sentence 1 and sentence 2 are equally similar to input sentence')
    return sim1, sim2
if __name__ == '__main__':
    # Demo: compare two fixed sentences against one typed at the prompt.
    # Guarded so that importing this module (e.g. to reuse get_embeddings)
    # does not block on input().
    text1 = 'Before viewing the output, let understand the parameters the tokenizer takes'
    text2 = 'if the token length is smaller than the token in a sentence then remove some of the tokens to make them equal in length'
    calculate_similarity(text1, text2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment