FAISS_semantic_search_app
# reference https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet
"""
A semantic search tool
"""
import pickle
from pathlib import Path
import faiss
import torch
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModel
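
# Dependencies (assumed pip install names, not pinned by the gist):
#   pip install faiss-cpu torch transformers beautifulsoup4
# (faiss-gpu can stand in for faiss-cpu when a CUDA build is available)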
class SemanticEmbedding:
    "A semantic embedding object to get the word embeddings"

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        "Object initialization"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, model_output, attention_mask):
        """
        Mean Pooling - take the attention mask into account for correct averaging.
        Mean pooling is usually used to build a vector for a whole sentence,
        but it works just as well here, where the input is a single word.
        """
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_exp = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_exp, 1) / torch.clamp(input_mask_exp.sum(1),
                                                                             min=1e-9)

    def get_embedding(self, word):
        "Create word embeddings"
        encoded_input = self.tokenizer(word, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        word_embedding = self.mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings so inner-product search behaves like cosine similarity
        word_embedding = torch.nn.functional.normalize(word_embedding, p=2, dim=1)
        return word_embedding.detach().numpy()
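
# Usage sketch (hypothetical names; the first call downloads the model
# from the Hugging Face Hub):
#   embedder = SemanticEmbedding()
#   vec = embedder.get_embedding("python")   # vec.shape == (1, 768)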
class FaissIdxObject:
    "A FAISS object to create, add docs to, search and save an index"

    def __init__(self, dim=768):
        "Object initialization"
        self.dim = dim
        self.ctr = 0

    def create_index(self):
        "Create a new index"
        # IndexFlatIP ranks by inner product; with L2-normalized
        # embeddings this is equivalent to cosine similarity
        return faiss.IndexFlatIP(self.dim)

    @staticmethod
    def get_index(index_name):
        "Get the index"
        try:
            return faiss.read_index(index_name)
        except (FileNotFoundError, RuntimeError) as err:
            # FAISS raises RuntimeError when it cannot open the file;
            # re-raise a proper exception rather than a plain string
            raise FileNotFoundError(f"Unable to find {index_name}, does the file exist?") from err

    @staticmethod
    def add_doc_to_index(index, embedded_document_text):
        "Add doc to index"
        index.add(embedded_document_text)

    @staticmethod
    def search_index(embedded_query, index, doc_map, k=5, return_scores=False):
        "Search through the index"
        # D holds the similarity scores, I the matching row ids
        D, I = index.search(embedded_query, k)
        if return_scores:
            value = [{doc_map[idx]: str(score)} for idx, score in zip(I[0], D[0]) if idx in doc_map]
        else:
            value = [doc_map[idx] for idx in I[0] if idx in doc_map]
        return value

    @staticmethod
    def save_index(index, index_name):
        "Save the index to local disk"
        faiss.write_index(index, index_name)
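
# Usage sketch (hypothetical values; assumes `vec` is a (1, 768) numpy
# array from SemanticEmbedding.get_embedding):
#   faiss_obj = FaissIdxObject()
#   index = faiss_obj.create_index()
#   FaissIdxObject.add_doc_to_index(index, vec)
#   FaissIdxObject.search_index(vec, index, doc_map={0: "python"}, k=1)
#   # -> ['python']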
class PickleObject:
    "A pickle object to save and read the human-readable dataset"

    def create_dict(self):
        "Create a new dict"
        return {}

    @staticmethod
    def get_pickle(pickle_name):
        "Get the local pickle file"
        try:
            with open(pickle_name, 'rb') as pickled_file:
                return pickle.load(pickled_file)
        except FileNotFoundError as err:
            # re-raise a proper exception rather than a plain string
            raise FileNotFoundError(f"Unable to find {pickle_name}, does the file exist?") from err

    @staticmethod
    def add_doc_to_pickle(pickle_dict, counter, doc):
        "Add entry to the pickle"
        pickle_dict[counter] = doc

    @staticmethod
    def save_pickle(pickle_file, pickle_name):
        "Save the pickle file to local disk"
        with open(pickle_name, 'wb') as pf:
            pickle.dump(pickle_file, pf, protocol=pickle.HIGHEST_PROTOCOL)
class XMLReader:
    "An XML object to read the values from an XML file"

    @staticmethod
    def read_from_file(xml_file, html_property):
        "Read from XML file"
        # html.parser lower-cases attribute names, so match that here
        html_property = html_property.lower()
        with open(xml_file, 'r') as xmlfile:
            xml = xmlfile.read()
        soup = BeautifulSoup(xml, "html.parser")
        rows = soup.find_all('row')
        return [row[html_property] for row in rows]
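
# The reader expects self-closing <row> elements carrying the target
# attribute, as in the Stack Exchange data-dump layout. An illustrative
# Tags.xml in that shape (sample values, not real data):
#   <tags>
#     <row Id="1" TagName="python" />
#     <row Id="2" TagName="javascript" />
#   </tags>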
if __name__ == '__main__':
    embedder = SemanticEmbedding()
    if not Path('Tags.index').is_file() or not Path('Tags.pickle').is_file():
        faiss_obj = FaissIdxObject()
        pickle_obj = PickleObject()
        xml_reader = XMLReader()
        faiss_index = faiss_obj.create_index()
        doc_dict = pickle_obj.create_dict()
        input_rows = xml_reader.read_from_file(xml_file='Tags.xml',
                                               html_property='TagName')
        COUNTER = 0
        for row in input_rows:
            embedded_content = embedder.get_embedding(row)
            faiss_obj.add_doc_to_index(index=faiss_index,
                                       embedded_document_text=embedded_content)
            pickle_obj.add_doc_to_pickle(pickle_dict=doc_dict,
                                         counter=COUNTER,
                                         doc=row)
            COUNTER += 1
        faiss_obj.save_index(index=faiss_index, index_name='Tags.index')
        pickle_obj.save_pickle(pickle_file=doc_dict, pickle_name='Tags.pickle')
    else:
        faiss_index = FaissIdxObject.get_index(index_name='Tags.index')
        doc_dict = PickleObject.get_pickle(pickle_name='Tags.pickle')
    while True:
        tech = input("\nEnter a tech: ")
        if tech == "exit":
            break
        if tech.strip() == "":
            continue
        embedded_input = embedder.get_embedding(tech)
        output = FaissIdxObject.search_index(embedded_query=embedded_input,
                                             index=faiss_index,
                                             doc_map=doc_dict,
                                             k=10,
                                             return_scores=True)
        print(output)
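
# To run (assumes Tags.xml is present in the working directory; the first
# run builds and saves the index, later runs reload it):
#   $ python faiss_semantic_search_app.py   # hypothetical file name
#   Enter a tech: <query>    # prints the k=10 nearest tags with scores
#   Enter a tech: exit       # quits the loop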