Skip to content

Instantly share code, notes, and snippets.

@howard-haowen
Last active April 4, 2024 16:41
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save howard-haowen/83874770957f15b84bd069dce0ce6303 to your computer and use it in GitHub Desktop.
Save howard-haowen/83874770957f15b84bd069dce0ce6303 to your computer and use it in GitHub Desktop.
Text similarity search using sentence_transformers and faiss
# Reference: kstathou/acl-search-engine
# !pip install faiss-cpu --no-cache
# !pip install sentence_transformers
import faiss
import numpy as np
import pandas as pd
import pickle
import torch
from sentence_transformers import SentenceTransformer, util
from pathlib import Path
# Name of the sentence_transformers model used to embed both the corpus and the
# queries. The name indicates a multilingual XLM-R based model (not DistilBERT);
# any other model supported by sentence_transformers can be substituted here.
MODEL_NAME = 'stsb-xlm-r-multilingual'

# Instantiate the sentence-level encoder shared by index_corpus and vector_search.
model = SentenceTransformer(MODEL_NAME)
# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model = model.to(torch.device("cuda"))
print(model.device)
def index_corpus(df, col="text", path="faiss_index.pickle"):
    """Embed a DataFrame column and build a FAISS index keyed by the DataFrame index.

    Each text in ``df[col]`` is encoded with the module-level ``model`` and
    added to an exact (brute-force) L2 index wrapped in ``IndexIDMap``. The
    IDs stored in the index are the labels of ``df.index``, so search hits
    can be mapped back to rows with ``df.loc``. The serialized index is also
    pickled to ``path``.

    Args:
        df (pandas.DataFrame): Corpus. Its index must hold integer labels,
            because FAISS IDs are 64-bit integers.
        col (str): Column containing the texts. Defaults to "text".
        path (str or pathlib.Path): File the pickled, serialized index is
            written to. Defaults to "faiss_index.pickle" (working directory).

    Returns:
        faiss.IndexIDMap: The populated index.

    Raises:
        TypeError: If ``df.index`` does not hold integer labels.
        ValueError: If ``df[col]`` is empty.
    """
    # FAISS only accepts int64 IDs; fail early with a clear message otherwise.
    if not pd.api.types.is_integer_dtype(df.index):
        raise TypeError("df.index must contain integer labels to be used as FAISS IDs")

    corpus = df[col].to_list()
    if not corpus:
        # With no embeddings there is no dimension to size the index with.
        raise ValueError(f"column {col!r} is empty; nothing to index")

    # Compute embeddings for all sentences in the corpus; FAISS requires float32.
    embeddings = np.asarray(
        model.encode(corpus, show_progress_bar=True), dtype="float32"
    )
    # Cast so int32/unsigned integer indexes are also accepted by add_with_ids.
    ids = np.asarray(df.index.values, dtype="int64")

    # Exact L2 index, wrapped in IndexIDMap so the DataFrame labels become the IDs.
    index = faiss.IndexIDMap(faiss.IndexFlatL2(embeddings.shape[1]))
    index.add_with_ids(embeddings, ids)

    # Serialise the index and store it as a pickle.
    with Path(path).open("wb") as f:
        pickle.dump(faiss.serialize_index(index), f)
    return index
def vector_search(query, num_results=10, *, faiss_index=None, corpus_df=None):
    """Embed a query with the module-level ``model`` and return the nearest corpus rows.

    Args:
        query (str or iterable of str): One query text, or several query texts.
            A single string is encoded as one query (not character by character).
        num_results (int): Number of nearest neighbours to retrieve per query.
        faiss_index: FAISS index to search. Defaults to the module-level
            ``index`` (e.g. the object returned by ``index_corpus``).
        corpus_df (pandas.DataFrame): DataFrame whose index labels are the IDs
            stored in ``faiss_index``. Defaults to the module-level ``df``.

    Returns:
        pandas.DataFrame: A copy of the matching rows of ``corpus_df`` (up to
        ``num_results`` rows per query, in the order FAISS returns them) with
        an added ``L2_distance`` column. NOTE: ``IndexFlatL2`` reports
        *squared* L2 distances, so that is what this column holds.
    """
    # Fall back to the notebook-style globals the original version relied on.
    if faiss_index is None:
        faiss_index = index
    if corpus_df is None:
        corpus_df = df

    # list("some text") would split a string into characters; wrap it instead.
    queries = [query] if isinstance(query, str) else list(query)
    if not queries:
        empty = corpus_df.iloc[:0].copy()
        empty["L2_distance"] = []
        return empty

    vectors = np.asarray(model.encode(queries), dtype="float32")
    distances, ids = faiss_index.search(vectors, k=num_results)

    ids = ids.ravel()
    distances = distances.ravel()
    # FAISS pads with ID -1 when fewer than k vectors are available.
    found = ids != -1
    # .copy() so adding a column does not write into a view of corpus_df.
    results = corpus_df.loc[ids[found], :].copy()
    results['L2_distance'] = distances[found].tolist()
    return results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment