Skip to content

Instantly share code, notes, and snippets.

@goncalossilva
Created February 2, 2023 10:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goncalossilva/d165c7e10a6df4eb7d3c6b8236570250 to your computer and use it in GitHub Desktop.
Save goncalossilva/d165c7e10a6df4eb7d3c6b8236570250 to your computer and use it in GitHub Desktop.
sentence-transformers-search.py
import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from string import punctuation
import gzip
import os
import sys
import re
import torch
# Sentence boundaries: terminal punctuation followed by whitespace, or a blank
# line (paragraph break). Single newlines are NOT boundaries, so hard-wrapped
# sentences stay whole.
_SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+|\n\s*\n')


def _split_passages(content):
    """Return the non-empty paragraphs and sentences of one Markdown document.

    Paragraphs are blocks separated by a blank line; sentences are produced by
    ``_SENTENCE_SPLIT_RE``. Every piece is whitespace-trimmed, and empty pieces
    and pieces starting with ``#`` (Markdown headings) are dropped.
    """
    paragraphs = [p.strip() for p in content.split('\n\n')]
    sentences = [s.strip() for s in _SENTENCE_SPLIT_RE.split(content)]
    # Headings are excluded: they make poor stand-alone answers.
    return [p for p in paragraphs + sentences if p and not p.startswith('#')]


def index_markdown_files(dir_path):
    """Build a semantic-search index over every ``.md`` file under *dir_path*.

    Each Markdown file contributes its paragraphs and its sentences as
    passages. Duplicates are removed, keeping first-seen order. The passages
    are then embedded with a bi-encoder.

    Args:
        dir_path: Root directory, walked recursively with ``os.walk``.

    Returns:
        A tuple ``(passages, bi_encoder, cross_encoder, corpus_embeddings)``.
        ``passages`` is the list of indexed strings, and row *i* of
        ``corpus_embeddings`` is the embedding of ``passages[i]``.

    Raises:
        ValueError: If no passages were found (no Markdown files, or only
            headings/blank content).
    """
    passages = []
    for root, _dirs, files in os.walk(dir_path):
        for name in files:
            print(".", end="", flush=True)  # progress: one dot per file visited
            if not name.endswith(".md"):
                continue
            # Explicit encoding so results don't depend on the platform locale.
            with open(os.path.join(root, name), "r", encoding="utf-8", errors="replace") as f:
                passages.extend(_split_passages(f.read()))

    # A one-sentence paragraph is produced both as a paragraph and as a
    # sentence; dedupe (order-preserving) to avoid embedding it twice.
    passages = list(dict.fromkeys(passages))
    print()  # end the line of progress dots
    print("Passages:", len(passages))
    if not passages:
        # Fail here with a clear message instead of crashing later in encode/search.
        raise ValueError(f"No Markdown passages found under {dir_path!r}")

    # Bi-encoder: embeds every passage once so queries can use semantic search.
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens.
    # Cross-encoder: re-ranks the bi-encoder's candidates for better quality.
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    # Encode all passages into our vector space.
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
    return (passages, bi_encoder, cross_encoder, corpus_embeddings)
def search_markdown_files(passages, bi_encoder, cross_encoder, corpus_embeddings, query, top_k=32):
    """Print, and return, the passage that best answers *query*.

    Retrieve-and-rerank:
    1. The bi-encoder embeds *query*, and ``util.semantic_search`` retrieves
       up to *top_k* candidate passages from *corpus_embeddings*.
    2. The cross-encoder scores each ``(query, passage)`` pair.
    3. The candidate with the highest cross-encoder score wins.

    Args:
        passages: Passage strings; index *i* matches row *i* of *corpus_embeddings*.
        bi_encoder: Model used to embed *query*.
        cross_encoder: Model used to re-rank the retrieved candidates.
        corpus_embeddings: Embeddings of *passages*.
        query: The question text.
        top_k: Number of candidates to retrieve with the bi-encoder.

    Returns:
        The best passage (with original newlines), or ``None`` if nothing was
        retrieved.
    """
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; we passed exactly one.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        print("(no matching passage)")
        return None

    # Score all retrieved passages with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    # Re-rank by cross-encoder score. Without this sort, hits[0] would still
    # be the bi-encoder's top hit and the cross-encoder scores would be unused.
    hits.sort(key=lambda hit: hit['cross-score'], reverse=True)
    best = passages[hits[0]['corpus_id']]
    print(best.replace("\n", " "))
    return best
# Pass root folder as arg, e.g., python sentence-transformers-search.py docs/
if __name__ == "__main__":
    if len(sys.argv) != 2:
        sys.exit(f"Usage: {sys.argv[0]} <markdown-root-dir>")
    (passages, bi_encoder, cross_encoder, corpus_embeddings) = index_markdown_files(sys.argv[1])
    # Interactive loop: one question per line; Ctrl-D or Ctrl-C quits.
    while True:
        try:
            query = input("Q: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of dumping a traceback.
            print()
            break
        if not query:
            continue  # don't spend two model passes on an empty question
        print("A: ", end="")
        search_markdown_files(passages, bi_encoder, cross_encoder, corpus_embeddings, query)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment