Last active
September 18, 2023 08:48
-
-
Save shivahari/bb944467a9ca46d653041279464cf2c0 to your computer and use it in GitHub Desktop.
FAISS_semantic_search_app
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# reference https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet
"""
A semantic search tool
"""
import pickle | |
from pathlib import Path | |
import faiss | |
import torch | |
from bs4 import BeautifulSoup | |
from transformers import AutoTokenizer, AutoModel | |
class SemanticEmbedding:
    """Turn text into L2-normalised embedding vectors with a pretrained transformer."""

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        """Load the tokenizer and the model published under ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, model_output, attention_mask):
        """
        Average the token embeddings, counting only tokens enabled in the mask.

        Mean pooling is mainly a way to build one vector for a whole sentence,
        but it works just as well here, where the input is a single word.

        ``model_output[0]`` must hold the per-token embeddings and
        ``attention_mask`` the matching mask (non-zero for tokens to keep).
        """
        embeddings = model_output[0]
        # Broadcast the mask over the embedding axis so padded tokens contribute zero.
        mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
        weighted_sum = torch.sum(embeddings * mask, 1)
        # The clamp keeps the division defined when the mask is entirely zero.
        token_count = torch.clamp(mask.sum(1), min=1e-9)
        return weighted_sum / token_count

    def get_embedding(self, word):
        """
        Embed ``word`` and return the result as a NumPy array.

        The text is tokenized, run through the model without gradient tracking,
        mean-pooled over its tokens and finally scaled to unit L2 norm along
        dimension 1.
        """
        tokens = self.tokenizer(word, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            output = self.model(**tokens)
        pooled = self.mean_pooling(output, tokens['attention_mask'])
        unit_length = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return unit_length.detach().numpy()
class FaissIdxObject:
    """Helpers to create, load, fill, search and save a FAISS index."""

    def __init__(self, dim=768):
        """
        Remember the vector dimensionality used for new indexes.

        ``dim`` must equal the length of the vectors later added to the index
        (768 presumably matches the default embedding model - confirm if the
        model changes).
        """
        self.dim = dim
        # Kept for backward compatibility; nothing in this class reads it.
        self.ctr = 0

    def create_index(self):
        """Return a new, empty flat inner-product index of ``self.dim`` dimensions."""
        return faiss.IndexFlatIP(self.dim)

    @staticmethod
    def get_index(index_name):
        """
        Load the index stored at ``index_name``.

        Raises:
            FileNotFoundError: if ``index_name`` is not an existing file.
        """
        # Check up front so a missing file always surfaces as FileNotFoundError.
        # The previous code raised a plain string, which is itself a TypeError.
        if not Path(index_name).is_file():
            raise FileNotFoundError(f"Unable to find {index_name}, does the file exist?")
        return faiss.read_index(index_name)

    @staticmethod
    def add_doc_to_index(index, embedded_document_text):
        """Append the embedded document vector(s) to ``index``."""
        index.add(embedded_document_text)

    @staticmethod
    def search_index(embedded_query, index, doc_map, k=5, return_scores=False):
        """
        Return the documents behind the ``k`` nearest neighbours of the query.

        Only the first row of the search result is used, so ``embedded_query``
        is treated as a single query. Ids that are not keys of ``doc_map``
        are skipped (e.g. placeholder ids when fewer than ``k`` hits exist).

        Returns:
            A list of documents, or - when ``return_scores`` is true - a list
            of one-entry dicts mapping each document to its score as a string.
        """
        scores, ids = index.search(embedded_query, k)
        if return_scores:
            return [{doc_map[idx]: str(score)}
                    for idx, score in zip(ids[0], scores[0]) if idx in doc_map]
        return [doc_map[idx] for idx in ids[0] if idx in doc_map]

    @staticmethod
    def save_index(index, index_name):
        """Write ``index`` to the file ``index_name``; errors propagate to the caller."""
        faiss.write_index(index, index_name)
class PickleObject:
    """Helpers to build, load and save the id -> document dict as a pickle file."""

    def create_dict(self):
        """Return a new, empty dict."""
        return {}

    @staticmethod
    def get_pickle(pickle_name):
        """
        Load and return the object pickled in the file ``pickle_name``.

        Raises:
            FileNotFoundError: if ``pickle_name`` does not exist.
        """
        # SECURITY: unpickling can execute arbitrary code; only load files
        # this application wrote itself, never untrusted input.
        try:
            with open(pickle_name, 'rb') as pickled_file:
                return pickle.load(pickled_file)
        except FileNotFoundError as err:
            # Raise a real exception (the old code raised a plain string, which
            # is a TypeError) and chain the original error as its cause.
            raise FileNotFoundError(
                f"Unable to find {pickle_name}, does the file exist?"
            ) from err

    @staticmethod
    def add_doc_to_pickle(pickle_dict, counter, doc):
        """Store ``doc`` in ``pickle_dict`` under the key ``counter``."""
        pickle_dict[counter] = doc

    @staticmethod
    def save_pickle(pickle_file, pickle_name):
        """Pickle ``pickle_file`` into the file ``pickle_name``; errors propagate."""
        with open(pickle_name, 'wb') as target:
            pickle.dump(pickle_file, target, protocol=pickle.HIGHEST_PROTOCOL)
class XMLReader:
    """Extract one attribute from every ``<row>`` element of an XML file."""

    @staticmethod
    def read_from_file(xml_file, html_property):
        """
        Return the ``html_property`` attribute of each ``row`` element in ``xml_file``.

        The attribute name is lower-cased before the lookup (presumably because
        the HTML parser lower-cases attribute names - confirm against the
        BeautifulSoup docs). A row lacking the attribute raises ``KeyError``.
        """
        attribute = html_property.lower()
        with open(xml_file, 'r') as handle:
            markup = handle.read()
        parsed = BeautifulSoup(markup, "html.parser")
        return [element[attribute] for element in parsed.find_all('row')]
if __name__ == '__main__':
    # Interactive demo: build (or load) the tag index, then answer queries
    # typed on stdin until the user enters "exit".
    embedder = SemanticEmbedding()

    if Path('Tags.index').is_file() and Path('Tags.pickle').is_file():
        # Both artefacts already exist: reuse them instead of re-embedding.
        faiss_index = FaissIdxObject.get_index(index_name='Tags.index')
        doc_dict = PickleObject.get_pickle(pickle_name='Tags.pickle')
    else:
        # At least one artefact is missing: rebuild both from Tags.xml.
        faiss_obj = FaissIdxObject()
        pickle_obj = PickleObject()
        xml_reader = XMLReader()

        faiss_index = faiss_obj.create_index()
        doc_dict = pickle_obj.create_dict()
        input_rows = xml_reader.read_from_file(xml_file='Tags.xml',
                                               html_property='TagName')

        # The running position doubles as the document id, so the index row
        # and the dict key stay aligned.
        for doc_id, tag in enumerate(input_rows):
            faiss_obj.add_doc_to_index(index=faiss_index,
                                       embedded_document_text=embedder.get_embedding(tag))
            pickle_obj.add_doc_to_pickle(pickle_dict=doc_dict,
                                         counter=doc_id,
                                         doc=tag)

        faiss_obj.save_index(index=faiss_index, index_name='Tags.index')
        pickle_obj.save_pickle(pickle_file=doc_dict, pickle_name='Tags.pickle')

    while True:
        tech = input("\nEnter a tech: ")
        if tech == "exit":
            break
        if not tech.strip():
            # Blank input: prompt again without searching.
            continue
        output = FaissIdxObject.search_index(embedded_query=embedder.get_embedding(tech),
                                             index=faiss_index,
                                             doc_map=doc_dict,
                                             k=10,
                                             return_scores=True)
        print(output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment