@narendraprasath
Created May 14, 2020 13:28
import os

import numpy as np
import gensim.downloader as api
from gensim.models import KeyedVectors


class Embeddings:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.__load_model__()

    def __load_model__(self):
        # other pre-trained options: "glove-wiki-gigaword-100", "word2vec-google-news-300"
        model_name = 'glove-twitter-25'
        if not os.path.exists(self.model_path + model_name):
            print("Downloading model")
            self.model = api.load(model_name)
            self.model.save(self.model_path + model_name)
        else:
            print("Loading model from Drive")
            self.model = KeyedVectors.load(self.model_path + model_name)

    def get_oov_from_model(self, document_vocabulary):
        ## collect the words that are not available in the pre-trained model
        print("The below words are not found in our pre-trained model")
        words = []
        for word in set(document_vocabulary):
            if word not in self.model:
                words.append(word)
        print(words)

    def get_sentence_embeddings(self, data_df, column_name):
        sentence_embeddings_list = []
        for sentence in data_df[column_name]:
            # start from a zero vector and accumulate the word vectors
            sentence_embeddings = np.zeros(self.model.vector_size)
            try:
                tokens = sentence.split(" ")
                ## get the word embedding; out-of-vocabulary words contribute zeros
                for word in tokens:
                    if word in self.model:
                        word_embedding = self.model[word]
                    else:
                        word_embedding = np.zeros(self.model.vector_size)
                    sentence_embeddings = sentence_embeddings + word_embedding
                ## uncomment to average instead of sum the word vectors
                #sentence_embeddings = sentence_embeddings / len(tokens)
                sentence_embeddings_list.append(sentence_embeddings.reshape(1, -1))
            except Exception as e:
                print(e)
        return sentence_embeddings_list
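
A minimal usage sketch, assuming gensim and pandas are installed; the DataFrame, its "text" column, and the "./models/" directory are illustrative names, not part of the original gist:

import pandas as pd

# hypothetical input data
df = pd.DataFrame({"text": ["this is a test", "word embeddings are useful"]})

# model_path should be an existing, writable directory ending with a path separator;
# the first run downloads glove-twitter-25 via gensim's downloader and caches it there
embedder = Embeddings(model_path="./models/")

# report words missing from the pre-trained vocabulary
vocabulary = [word for sentence in df["text"] for word in sentence.split(" ")]
embedder.get_oov_from_model(vocabulary)

# one (1, vector_size) array per row of the DataFrame
vectors = embedder.get_sentence_embeddings(df, "text")
print(len(vectors), vectors[0].shape)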