Skip to content

Instantly share code, notes, and snippets.

@epeters3
Last active February 3, 2020 19:11
Show Gist options
  • Save epeters3/437c5607703da7f03c681583ff2fe1ea to your computer and use it in GitHub Desktop.
Save epeters3/437c5607703da7f03c681583ff2fe1ea to your computer and use it in GitHub Desktop.
Fast Querying of Word Embeddings Using `sklearn.neighbors.BallTree`
import io
import typing as t
import pickle
import os
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import BallTree
# The pre-trained word vectors used by this gist can be downloaded from
# https://fasttext.cc/docs/en/english-vectors.html. Download the
# `wiki-news-300d-1M.vec.zip` archive and extract `wiki-news-300d-1M.vec`
# into the working directory.
class EmbeddingSpace:
    """Word vectors loaded from a text ``.vec`` file, indexed for neighbour lookup.

    The file must start with a header line holding two integers (vector count
    and dimensionality), followed by one line per word of the form
    ``word v1 v2 ... vd`` with single-space separators.

    Attributes:
        n: Number of vectors actually loaded.
        d: Dimensionality read from the header.
        word2index: Maps each word to its row in ``vecs``.
        index2word: Maps each row index back to its word.
        vecs: ``(n, d)`` array holding the loaded vectors.
        tree: ``BallTree`` built over ``vecs``.
    """

    def __init__(self, fname: str, *, limit: t.Optional[int] = None) -> None:
        """Load vectors from ``fname`` and build the ball tree.

        Args:
            fname: Path to the text embedding file.
            limit: If given, load at most this many vectors from the top of
                the file.

        Raises:
            ValueError: If ``limit`` is smaller than 1.
        """
        if limit is not None and limit < 1:
            raise ValueError(f"limit must be a positive integer, got {limit}")

        self.word2index: t.Dict[str, int] = {}
        self.index2word: t.Dict[int, str] = {}

        # `with` guarantees the file is closed even if parsing fails midway.
        with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
            self.n, self.d = map(int, fin.readline().split())
            if limit is not None:
                # Never allocate more rows than the header says exist.
                self.n = min(limit, self.n)
            self.vecs = np.zeros((self.n, self.d))

            loaded = 0
            with tqdm(total=self.n) as pbar:
                for line in fin:
                    if loaded >= self.n:
                        break
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    word, values = self._parse_line(line)
                    self.word2index[word] = loaded
                    self.index2word[loaded] = word
                    self.vecs[loaded] = values
                    loaded += 1
                    pbar.update(1)

        if loaded < self.n:
            # The file held fewer rows than declared: drop the unused zero
            # rows so the tree never returns an index without a word.
            self.vecs = self.vecs[:loaded]
            self.n = loaded

        self.tree = BallTree(self.vecs)

    @staticmethod
    def _parse_line(line: str) -> t.Tuple[str, t.List[float]]:
        """Split one embedding line into its word and its float components."""
        tokens = line.rstrip().split(' ')
        return tokens[0], [float(token) for token in tokens[1:]]

    def get_vec(self, word: str) -> np.ndarray:
        """Return the stored vector for ``word``.

        Raises:
            ValueError: If ``word`` was not loaded.
        """
        self._validate_word(word)
        return self.vecs[self.word2index[word]]

    def get_k_nearest(self, word: t.Union[str, np.ndarray], k: int) -> t.List[str]:
        """Return the ``k`` stored words nearest to ``word``.

        Args:
            word: Either a loaded word (its vector is looked up) or a raw
                vector with ``d`` components.
            k: Number of neighbours to return.

        Returns:
            The neighbouring words in the order reported by the tree query.
            When ``word`` is a stored word its own vector is part of the
            index, so it is normally among the results.

        Raises:
            ValueError: If ``word`` is a string that was not loaded.
        """
        vec = self.get_vec(word) if isinstance(word, str) else word
        # The tree expects a 2-D batch of query points; send a batch of one.
        query = np.reshape(vec, (1, self.d))
        indices = self.tree.query(query, k, return_distance=False)[0]
        return [self.index2word[i] for i in indices]

    def _validate_word(self, word: str) -> None:
        """Raise ``ValueError`` unless ``word`` is in the vocabulary."""
        if word not in self.word2index:
            raise ValueError(f"'{word}' is not a valid word")
if __name__ == "__main__":
    # Build the embedding space once and cache it as a pickle beside the
    # vectors; later runs reload the cache unless REFRESH is set.
    VECS_NAME = "wiki-news-300d-1M"
    REFRESH = False
    pkl_path = f"{VECS_NAME}.pkl"
    if not REFRESH and os.path.exists(pkl_path):
        # Use already built version.
        # pickle.load needs a binary file object, not a path string.
        # NOTE: unpickling executes arbitrary code; only load cache files
        # that this script wrote itself.
        with open(pkl_path, "rb") as f:
            space = pickle.load(f)
    else:
        space = EmbeddingSpace(f"{VECS_NAME}.vec")
        with open(pkl_path, "wb") as f:
            # The full table is several gigabytes; protocol 4+ is required
            # for objects that large, so request the newest protocol.
            pickle.dump(space, f, protocol=pickle.HIGHEST_PROTOCOL)
    print("**** Done. The `space` object is now available.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment