@jeanCarloMachado
Last active November 29, 2021 16:27
A simple script to rank strings by their similarity to a query.
#!/usr/bin/env python
"""
BERT similarity script.

Save this file as bert_similarity.py and then run it like this:

    python bert_similarity.py rank --query "open source" "linux" "windows" "mac os"

And get a result like this:

    {'similarity': 0.5847581, 'text': 'linux'}
    {'similarity': 0.52185637, 'text': 'windows'}
    {'similarity': 0.5135084, 'text': 'mac os'}

Dependencies:

    pip install sentence-transformers fire scikit-learn
"""
import fire
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class BertSimilarity:
    def rank(self, query: str, *passed_docs):
        """
        The first string is the query, and all others are documents.
        Quote each document so the shell passes it as a single argument.
        """
        passed_docs = list(passed_docs)
        model = SentenceTransformer('bert-base-nli-mean-tokens')
        # Encode the query and the documents together; row 0 is the query.
        text_embeddings = model.encode([query] + passed_docs, batch_size=8)
        # Row 0 of the pairwise matrix holds the query's similarity to every
        # text; drop column 0, which is the query compared with itself.
        similarities = cosine_similarity(text_embeddings)[0][1:]
        result = [
            {"similarity": similarity, "text": passed_doc}
            for similarity, passed_doc in zip(similarities, passed_docs)
        ]
        # Sort the most similar documents first.
        return sorted(result, key=lambda x: x['similarity'], reverse=True)


if __name__ == '__main__':
    fire.Fire(BertSimilarity)
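
The ranking step in the script comes down to scoring each document vector against the query vector with cosine similarity and then sorting by that score. Here is a minimal sketch of just that step, using only the standard library. The helper names `cosine` and `rank_by_similarity` and the toy 2-d vectors are assumptions made up for illustration; they stand in for real embeddings and are not part of the script above.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_by_similarity(query_vec, docs):
    """Rank texts by similarity; docs maps each text to its vector."""
    scored = [
        {"similarity": cosine(query_vec, vec), "text": text}
        for text, vec in docs.items()
    ]
    # Most similar first, same ordering rule as the script.
    return sorted(scored, key=lambda x: x["similarity"], reverse=True)


# Toy vectors standing in for embeddings (hypothetical values).
toy_docs = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
ranked = rank_by_similarity([1.0, 0.0], toy_docs)
print([r["text"] for r in ranked])
```

Cosine similarity depends only on the angle between vectors, not their length, so a document whose vector points in the same direction as the query ranks first regardless of magnitude.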