Skip to content

Instantly share code, notes, and snippets.

@duhaime
Created June 17, 2019 15:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save duhaime/597bd1c7dbe9115bba5b4bd42cddbcb9 to your computer and use it in GitHub Desktop.
Save duhaime/597bd1c7dbe9115bba5b4bd42cddbcb9 to your computer and use it in GitHub Desktop.
Find KNN for Gensim Model
import json
import os

import numpy as np
from annoy import AnnoyIndex
def build_ann_index(model):
    '''Build an Annoy approximate-nearest-neighbor index over a gensim
    model's word vectors, persist it to disk, and return it for querying.

    If `model.ann` / `idx_to_word.json` already exist on disk they are
    reused rather than rebuilt.

    Parameters:
        model: a trained gensim word-vector model exposing `model.wv`
               (assumed gensim <4.0, which has `wv.vocab` — TODO confirm)

    Returns:
        (ann, idx_to_word): the loaded AnnoyIndex and a dict mapping
        stringified row index -> word (string keys survive the JSON
        round-trip)
    '''
    words = list(model.wv.vocab.keys())  # list of strings, one per word
    # NOTE: maps str(index) -> word (the original comment had this backwards)
    idx_to_word = {str(idx): word for idx, word in enumerate(words)}
    dims = model.wv[words[0]].shape[0]  # number of dimensions in each input vector
    # build and persist the index only if it is not already on disk
    if not os.path.exists('model.ann'):
        ann = AnnoyIndex(dims)
        for i in idx_to_word:
            ann.add_item(int(i), model.wv[idx_to_word[i]])
        ann.build(10)  # number of 'trees' to build; more = better recall, bigger file
        ann.save('model.ann')
    if not os.path.exists('idx_to_word.json'):
        with open('idx_to_word.json', 'w') as out:
            json.dump(idx_to_word, out)
    # load the saved model
    ann = AnnoyIndex(dims)
    ann.load('model.ann')
    # fix: the original used json.load(open(...)), leaking the file handle
    with open('idx_to_word.json') as f:
        idx_to_word = json.load(f)
    # return the model for querying and the map from idx to word
    return ann, idx_to_word
def find_centroid(words):
    '''Return the centroid (element-wise mean) of the vectors for `words`.

    Words absent from the model vocabulary are skipped. Relies on a
    module-level gensim `model`.

    Raises:
        ValueError: if none of `words` are in the vocabulary
                    (np.vstack on an empty list).
    '''
    vecs = np.vstack([model.wv[w] for w in words if w in model.wv])
    # fix: a centroid averages over the ROWS (one per word). The original
    # divided the per-dimension sums by the number of dimensions
    # (vecs[0].shape[0]) instead of the number of vectors.
    return vecs.mean(axis=0)
def find_similar_by_vec(vec, n=50):
    '''Return up to `n` lowercased words most similar to query vector `vec`.

    Over-fetches 4*n candidates from the module-level Annoy index `ann`
    so that case-insensitive duplicates can be dropped while still
    (usually) yielding `n` distinct results.
    '''
    # 4*n candidates (the original's `n*2**2`)
    indices = ann.get_nns_by_vector(vec, n * 4, search_k=-1, include_distances=False)
    curated = []
    seen = set()  # lowercased words already emitted; O(1) membership
    for i in indices:
        word = idx_to_word[str(i)].lower()
        if word not in seen:
            seen.add(word)
            curated.append(word)
            # fix: stop once we have n results instead of scanning every
            # remaining candidate (original kept looping with a no-op test)
            if len(curated) == n:
                break
    return curated
def find_similar_by_words(words, n=50):
    '''Return up to `n` lowercased words most similar to the centroid of
    the vectors for `words`.

    The query words themselves are excluded from the results (the
    lowercased candidate is compared against `words` as given, matching
    the original behavior).
    '''
    centroid = find_centroid(words)
    # 4*n candidates (the original's `n*2**2`) to survive dedup/exclusion
    indices = ann.get_nns_by_vector(centroid, n * 4, search_k=-1, include_distances=False)
    query_words = set(words)  # fix: O(1) membership instead of O(len(words)) per candidate
    curated = []
    seen = set()  # lowercased words already emitted
    for i in indices:
        word = idx_to_word[str(i)].lower()
        if word not in query_words and word not in seen:
            seen.add(word)
            curated.append(word)
            # fix: stop once we have n results instead of scanning the rest
            if len(curated) == n:
                break
    return curated
# Prepare the module-level globals used by the find_similar_* helpers:
# `ann` is the Annoy index, `idx_to_word` maps stringified row index -> word.
# NOTE(review): assumes a trained gensim `model` is already defined in this
# scope before this line runs — confirm against the calling notebook/script.
ann, idx_to_word = build_ann_index(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment