Last active
February 3, 2020 19:11
-
-
Save epeters3/437c5607703da7f03c681583ff2fe1ea to your computer and use it in GitHub Desktop.
Fast Querying of Word Embeddings Using `sklearn.neighbors.BallTree`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
import typing as t | |
import pickle | |
import os | |
import numpy as np | |
from tqdm import tqdm | |
from sklearn.neighbors import BallTree | |
# The embedded word vectors that work with this gist can be downloaded at
# https://fasttext.cc/docs/en/english-vectors.html. It's the
# `wiki-news-300d-1M.vec.zip` embeddings.
class EmbeddingSpace:
    """An in-memory word-embedding space with fast nearest-neighbor queries.

    Loads a fastText-style ``.vec`` file (first line ``"<n> <d>"``, then one
    ``"<word> <v1> ... <vd>"`` line per word) and indexes the vectors in a
    ``sklearn.neighbors.BallTree`` for fast k-nearest-neighbor lookup.
    """

    def __init__(self, fname: str, *, limit: t.Optional[int] = None) -> None:
        """Load up to ``limit`` vectors from ``fname`` and build the BallTree.

        :param fname: Path to the ``.vec`` embeddings file.
        :param limit: Optional cap on how many vectors to load. Clamped to
            the file's declared vector count so ``self.vecs`` never contains
            unfilled zero rows.
        """
        # `with` guarantees the file handle is closed even if parsing fails
        # (the original leaked the handle).
        with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
            # First line of the .vec format declares "<num_vectors> <dim>".
            self.n, self.d = map(int, fin.readline().split())
            if limit is not None:
                # Clamp: a limit larger than the file would leave zero rows
                # in `self.vecs` and holes in the index maps.
                self.n = min(limit, self.n)
            self.word2index: t.Dict[str, int] = {}
            self.index2word: t.Dict[int, str] = {}
            self.vecs = np.zeros((self.n, self.d))
            with tqdm(total=self.n) as pbar:
                for i, line in enumerate(fin):
                    tokens = line.rstrip().split(' ')
                    word = tokens[0]
                    self.word2index[word] = i
                    self.index2word[i] = word
                    self.vecs[i] = [float(token) for token in tokens[1:]]
                    # Count the row BEFORE the early exit, so the final
                    # iteration is reflected in the progress bar (the
                    # original skipped it by breaking first).
                    pbar.update(1)
                    if i >= self.n - 1:
                        break
        self.tree = BallTree(self.vecs)

    def get_vec(self, word: str) -> np.ndarray:
        """Return the embedding vector for ``word``.

        :raises ValueError: If ``word`` is not in the vocabulary.
        """
        self._validate_word(word)
        i = self.word2index[word]
        return self.vecs[i]

    def get_k_nearest(self, word: t.Union[str, np.ndarray], k: int) -> t.List[str]:
        """Return the ``k`` vocabulary words nearest to ``word``.

        :param word: A vocabulary word, or a raw embedding vector of
            length ``self.d``.
        :param k: Number of neighbors to return. Note: when ``word`` is a
            vocabulary word, the word itself is typically its own nearest
            neighbor and will appear in the results.
        :raises ValueError: If ``word`` is a string not in the vocabulary.
        """
        if isinstance(word, str):
            vec = self.get_vec(word)
        else:
            vec = word
        # BallTree.query expects a 2-D array of query points; take row 0
        # of the returned indices since we query a single point.
        indices = self.tree.query(np.reshape(vec, (1, self.d)), k, return_distance=False)[0]
        return [self.index2word[i] for i in indices]

    def _validate_word(self, word: str) -> None:
        """Raise ``ValueError`` if ``word`` is not in the vocabulary."""
        if word not in self.word2index:
            raise ValueError(f"'{word}' is not a valid word")
if __name__ == "__main__":
    VECS_NAME = "wiki-news-300d-1M"
    REFRESH = False  # Set True to force a rebuild from the raw .vec file.
    pkl_path = f"{VECS_NAME}.pkl"
    if not REFRESH and os.path.exists(pkl_path):
        # Reuse the cached, already-built space. pickle.load requires an
        # open binary file object — the original passed the path string,
        # which raises TypeError on every cache hit.
        with open(pkl_path, "rb") as f:
            space = pickle.load(f)
    else:
        # Build from scratch and cache the result for next time.
        space = EmbeddingSpace(f"{VECS_NAME}.vec")
        with open(pkl_path, "wb") as f:
            pickle.dump(space, f)
    print("**** Done. The `space` object is now available.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment