Created
May 3, 2017 12:04
-
-
Save pminervini/bcea6b41d7f213759e180545455df7e3 to your computer and use it in GitHub Desktop.
Loading pre-trained embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import numpy as np | |
import logging | |
logger = logging.getLogger(__name__) | |
def load_glove(stream, words=None):
    """
    Loads GloVe word embeddings from a text stream.

    Each line is expected to look like ``token f1 f2 ... fn``. Lines whose
    vector fields fail to parse are logged and skipped.

    :param stream: An opened stream to the GloVe file.
    :param words: Words in the existing vocabulary; when given, only these
        words are kept.
    :return: dict {word: embedding}
    """
    embeddings = {}
    for line_no, raw_line in enumerate(stream):
        fields = raw_line.split(' ')
        token = fields[0]
        # Keep the token only when no vocabulary filter was given,
        # or when it belongs to the existing vocabulary.
        if words is not None and token not in words:
            continue
        try:
            embeddings[token] = [float(value) for value in fields[1:]]
        except ValueError:
            # Malformed vector fields: log the offending line and move on.
            logger.error('{}\t{}\t{}'.format(line_no, token, str(fields)))
    return embeddings
def load_word2vec(stream, words=None):
    """
    Loads word2vec word embeddings from a binary stream.

    The expected layout is the word2vec C-tool binary format: a text header
    line ``<vocab size> <vector size>``, then for each entry the word bytes,
    a space, ``vector_size`` little-endian float32 values, and a terminating
    newline.

    :param stream: An opened binary stream to the word2vec file.
    :param words: Words in the existing vocabulary; when given, only these
        words are kept.
    :return: dict {word: embedding}
    """
    word_to_embedding = {}
    # Header line: "<vocab size> <vector size>".
    vec_n, vec_size = map(int, stream.readline().split())
    byte_size = vec_size * 4  # each component is a 4-byte float32
    for n in range(vec_n):
        word = b''
        while True:
            c = stream.read(1)
            if c == b' ' or c == b'':
                # b'' means EOF (truncated file): stop instead of looping forever.
                break
            if c != b'\n':
                # The format terminates each vector with '\n'; without this
                # skip, every word after the first gains a leading newline.
                word += c
        word = word.decode('utf-8')
        # np.frombuffer replaces np.fromstring, which is deprecated (and
        # removed for binary input in modern NumPy).
        vector = np.frombuffer(stream.read(byte_size), dtype=np.float32)
        if words is None or word in words:
            word_to_embedding[word] = vector.tolist()
    return word_to_embedding
# Sample usage:
def main():
    # NOTE(review): this sample depends on names defined elsewhere in the
    # project (os, glove_path, qs_tokenizer, vocab_size, tqdm, session,
    # assign_word_embedding, word_idx_ph, word_embedding_ph) — it does not
    # run standalone; `import os` and `from tqdm import tqdm` would also be
    # needed at the top of this file. Verify against the real caller.
    if os.path.isfile(glove_path):
        from derte.io.embeddings import load_glove
        logger.info('Initialising the embeddings with GloVe vectors ..')
        # Keep only words whose tokenizer index fits inside the embedding
        # matrix (indices >= vocab_size have no row to fill).
        word_set = {w for w, w_idx in qs_tokenizer.word_index.items()
                    if w_idx < vocab_size}
        with open(glove_path, 'r') as stream:
            word_to_embedding = load_glove(stream=stream, words=word_set)
        # Copy each pre-trained vector into the embedding variable one row
        # at a time via the assign op and its two placeholders — presumably
        # a TensorFlow session; confirm against the `assign` graph.
        for word in tqdm(word_to_embedding):
            word_idx, word_embedding = qs_tokenizer.word_index[word], word_to_embedding[word]
            session.run(assign_word_embedding, feed_dict={
                word_idx_ph: word_idx,
                word_embedding_ph: word_embedding
            })
        logger.info('Done!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Forgot this — here is the `assign` graph: