Last active
May 24, 2018 22:01
-
-
Save hanneshapke/0cf0605a8c8be83ce74239e133c4e52a to your computer and use it in GitHub Desktop.
Mimicking Gensim's KeyedVectors class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bz2 | |
import numpy as np | |
import pickle | |
from django.conf import settings | |
from django_redis import get_redis_connection | |
from gensim.models.keyedvectors import KeyedVectors | |
from .constants import GOOGLE_WORD2VEC_MODEL_NAME | |
from .redis import load_word2vec_model_into_redis, query_redis | |
class RedisKeyedVectors(KeyedVectors):
    """
    Imitate gensim's ``KeyedVectors``, but fetch the vectors from a redis db
    instead of holding them in memory.

    Every lookup goes through the ``word2vec`` redis connection alias. The
    loading and saving entry points inherited from ``KeyedVectors`` are
    overridden to raise ``NotImplementedError``.
    """

    def __init__(self, key=GOOGLE_WORD2VEC_MODEL_NAME):
        """
        :param key: prefix of the redis keys that belong to the word vector model
        """
        # NOTE(review): the parent ``__init__`` is not called here (it was not
        # called originally either) -- confirm this is intentional for the
        # gensim version in use.
        self.key = key
        self.rs = get_redis_connection(alias='word2vec')
        self.syn0 = []
        self.syn0norm = None
        self.index2word = []
        # Pass the instance key so the scan covers this model's prefix and not
        # always the default one. The count itself is not stored.
        self.check_vocab_len(key=key)

    @classmethod
    def check_vocab_len(cls, key=GOOGLE_WORD2VEC_MODEL_NAME, **kwargs):
        """
        Count the redis keys that start with ``key``.

        :param key: prefix of the redis keys to count
        :param kwargs: accepted for call compatibility, unused
        :returns: number of keys matching ``key + "*"``
        """
        rs = get_redis_connection(alias='word2vec')
        # Count while iterating instead of building a list of every matching key.
        return sum(1 for _ in rs.scan_iter(key + "*"))

    @classmethod
    def load_word2vec_format(cls, **kwargs):
        """ Not supported: always raises ``NotImplementedError``. """
        raise NotImplementedError("You can't load a word model that way. It needs to pre-loaded into redis")

    def save(self, *args, **kwargs):
        """ Not supported: always raises ``NotImplementedError``. """
        raise NotImplementedError("You can't write back to Redis that way.")

    def save_word2vec_format(self, **kwargs):
        """ Not supported: always raises ``NotImplementedError``. """
        raise NotImplementedError("You can't write back to Redis that way.")

    def word_vec(self, word, **kwargs):
        """
        This method is mimicking the word_vec method from the Gensim KeyedVector class. Instead of
        looking it up from an in memory dict, it
        - requests the value from the redis instance via ``query_redis``
        - decompresses it
        - and finally unpickles it

        :param word: string
        :param kwargs: accepted for call compatibility, unused
        :returns: the unpickled value (presumably a numpy array of the dim of the
            word vector model -- verify against the loader), or ``None`` when
            the lookup raises ``TypeError``
        """
        # NOTE(review): ``self.key`` is not passed to ``query_redis``; confirm
        # that helper applies the right key prefix for non-default models.
        # SECURITY: ``pickle.loads`` executes arbitrary code on malicious input;
        # only use this against a redis instance whose content is trusted.
        try:
            return pickle.loads(bz2.decompress(query_redis(self.rs, word)))
        except TypeError:
            # ``bz2.decompress`` raises TypeError for a non-bytes payload such
            # as ``None`` (presumably what a missing word yields).
            return None

    def __getitem__(self, words):
        """
        Return the vector for a single word, or a vstack for multiple words.

        :param words: a single string or an iterable of strings
        :returns: result of ``word_vec`` for a string (may be ``None``),
            otherwise ``np.vstack`` of the individual vectors
        :raises KeyError: if any word of a multi-word lookup has no vector
        """
        if isinstance(words, str):
            # allow calls like trained_model['Chief Executive Officer']
            return self.word_vec(words)

        vectors = []
        missing = []
        for word in words:
            vector = self.word_vec(word)
            if vector is None:
                missing.append(word)
            else:
                vectors.append(vector)
        if missing:
            # A ``None`` entry cannot be stacked into a matrix; name the
            # offending words instead of failing inside numpy.
            raise KeyError("words not available in redis: %r" % (missing,))
        return np.vstack(vectors)

    def __contains__(self, word):
        """ build in method to quickly check whether a word is available in redis """
        # ``bool`` normalises the client's return value for direct calls.
        return bool(self.rs.exists(self.key + word))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment