Skip to content

Instantly share code, notes, and snippets.

@xhluca
Created March 11, 2021 19:21
Show Gist options
  • Save xhluca/157920dd54a3b959f3a5ad6097803f48 to your computer and use it in GitHub Desktop.
Save xhluca/157920dd54a3b959f3a5ad6097803f48 to your computer and use it in GitHub Desktop.
import json
from typing import List
# scikit-learn is an optional dependency that is only needed for the toy
# models below.  Catch ImportError specifically: a bare ``except:`` would also
# intercept KeyboardInterrupt/SystemExit and disguise unrelated failures as a
# missing package.
try:
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
except ImportError as err:
    error_msg = (
        "Couldn't import scikit-learn. To use the toy models, you will need to "
        "install it with `pip install scikit-learn`."
    )
    # ImportError is still an Exception, so existing ``except Exception``
    # handlers keep working; chaining preserves the underlying cause.
    raise ImportError(error_msg) from err
import numpy as np
class SearchEngine:
    """Toy passage retriever built on TF-IDF features reduced with truncated SVD.

    Call :meth:`build_knowledge_base` once with the passages to index, then
    use :meth:`retrieve` (or :meth:`retrieve_idx`) to rank those passages
    against a query by cosine similarity in the reduced space.
    """

    def __init__(self, n_components: int = 300):
        """Create the unfitted TF-IDF vectorizer and SVD projection.

        Args:
            n_components: Number of SVD components used for the passage and
                query encodings. Defaults to 300, the value that used to be
                hard-coded. NOTE(review): small corpora with a tiny
                vocabulary may need a lower value -- confirm against the
                installed scikit-learn version.
        """
        self.svd = TruncatedSVD(n_components=n_components)
        self.vectorizer = TfidfVectorizer(max_df=0.9, min_df=1)

    def build_knowledge_base(self, passages: List[dict]):
        """Fit the vectorizer and SVD on ``passages`` and cache their encodings.

        Args:
            passages: Records to index. Each one must provide a ``"content"``
                entry holding the text that gets vectorized.
        """
        # Stored as arrays so that retrieve() can select rows with an index list.
        self.passages = np.array(passages)
        self.contents = np.array([p["content"] for p in self.passages])

        content_tfidf = self.vectorizer.fit_transform(self.contents)
        self.content_encs = self.svd.fit_transform(content_tfidf)

    def retrieve_idx(self, query: str, k: int = 10) -> List[int]:
        """Return the positions of the ``k`` passages most similar to ``query``.

        Args:
            query: Free-text query, encoded with the fitted vectorizer and SVD.
            k: Maximum number of positions to return.

        Returns:
            Positions into the indexed passages, ordered from most to least
            similar. Fewer than ``k`` are returned when fewer are indexed.
        """
        query_tfidf = self.vectorizer.transform([query])
        query_enc = self.svd.transform(query_tfidf)

        # A single query yields a (1, n_passages) matrix. Take its only row
        # explicitly: squeeze() would also collapse the passage axis to a
        # 0-d array when exactly one passage is indexed.
        sim_scores = cosine_similarity(query_enc, self.content_encs)[0]

        best_idx = sim_scores.argsort()[::-1][:k].tolist()
        return best_idx

    def retrieve(self, query: str, k: int = 10) -> List[dict]:
        """Return the ``k`` indexed passages most similar to ``query``.

        Args:
            query: Free-text query.
            k: Maximum number of passages to return.

        Returns:
            The matching passage records, best match first.
        """
        best_idx = self.retrieve_idx(query, k)
        best_candidates = self.passages[best_idx].tolist()
        return best_candidates
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment