Text similarity search using sentence_transformers and faiss
# Reference: kstathou/acl-search-engine
# !pip install faiss-cpu --no-cache
# !pip install sentence_transformers

import faiss
import numpy as np
import pandas as pd
import pickle
import torch
from sentence_transformers import SentenceTransformer, util
from pathlib import Path

# Instantiate a sentence-level model (any model supported by sentence_transformers works)
model = SentenceTransformer('stsb-xlm-r-multilingual')

# Move the model to the GPU if one is available
if torch.cuda.is_available():
    model = model.to(torch.device("cuda"))
print(model.device)


def index_corpus(df, col="text"):
    # Assume the corpus is stored in the 'text' column of the dataframe by default
    corpus = df[col].to_list()
    # Compute embeddings for all sentences in the corpus
    embeddings_corpus = model.encode(corpus, show_progress_bar=True)
    # Convert the embeddings to float32, the dtype faiss expects
    embeddings = np.array(embeddings_corpus).astype("float32")
    # Instantiate a flat L2-distance index with faiss
    index = faiss.IndexFlatL2(embeddings.shape[1])
    # Wrap the index in IndexIDMap so custom IDs can be attached to the vectors
    index = faiss.IndexIDMap(index)
    # Add the vectors with their IDs, set to the dataframe indexes
    index.add_with_ids(embeddings, df.index.values.astype("int64"))
    # Serialise the index and store it as a pickle
    with open("faiss_index.pickle", "wb") as f:
        pickle.dump(faiss.serialize_index(index), f)
    return index
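

# A minimal sketch of how the pickled index could be reloaded in a later session;
# the helper name load_index and the default path are illustrative only, reusing
# the "faiss_index.pickle" file name written by index_corpus() above.
def load_index(path="faiss_index.pickle"):
    with open(path, "rb") as f:
        # faiss.deserialize_index reverses faiss.serialize_index
        return faiss.deserialize_index(pickle.load(f))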

def vector_search(query, df, index, num_results=10):
    """Transforms the query into a vector with the sentence_transformers model
    and finds the most similar corpus entries using the FAISS index.

    Args:
        query (str): User query that should be at least a sentence long.
        df (pandas.DataFrame): Dataframe the index was built from.
        index (faiss.Index): Index returned by index_corpus().
        num_results (int): Number of results to return.

    Returns:
        results (pandas.DataFrame): The num_results rows of df closest to the
        query, with an extra 'L2_distance' column.
    """
    # Encode the query; wrap it in a list so it is treated as a single sentence
    vector = model.encode([query])
    # Search the index for the nearest neighbours of the query vector
    L2, ID = index.search(np.array(vector).astype("float32"), k=num_results)
    # Map the returned IDs back to rows of the dataframe
    index_list = ID.flatten().tolist()
    results = df.loc[index_list, :].copy()
    results['L2_distance'] = L2.flatten().tolist()
    return results
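

# Example usage, as a hedged sketch: the toy sentences, the "text" column name
# and the query below are made up purely for illustration.
if __name__ == "__main__":
    df = pd.DataFrame({"text": [
        "FAISS makes nearest-neighbour search fast.",
        "Sentence transformers map sentences to dense vectors.",
        "The weather is nice today.",
    ]})
    index = index_corpus(df)
    hits = vector_search("How do I find similar sentences quickly?", df, index, num_results=2)
    print(hits[["text", "L2_distance"]])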