GloVe embeddings in PyTorch
import numpy as np
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, vocab_size, embd_size, pre_embd_w=None):
        super(Model, self).__init__()
        self.embd = nn.Embedding(vocab_size, embd_size, padding_idx=0)
        if pre_embd_w is not None:
            print('pre embedding weight is set')
            self.embd.weight = nn.Parameter(pre_embd_w, requires_grad=True)
            # self.embd.weight = nn.Parameter(pre_embd_w, requires_grad=False)  # freeze the parameters
            # alternative: self.embd = nn.Embedding.from_pretrained(pre_embd_w, freeze=False)

    def forward(self, x):
        return self.embd(x)  # (batch, seq_len) -> (batch, seq_len, embd_size)
def load_glove_weights(glove_fpath, embd_size, vocab_size, word_index):
    """Build a (vocab_size, embd_size) weight matrix from a GloVe text file."""
    embeddings_index = {}
    with open(glove_fpath, encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.array(values[1:], dtype='float32')
            embeddings_index[word] = vector
    print('Found {} word vectors in glove.'.format(len(embeddings_index)))

    embedding_matrix = np.zeros((vocab_size, embd_size))
    print('embed_matrix.shape', embedding_matrix.shape)
    found_ct = 0
    for word, i in word_index.items():
        embedding_vector = embeddings_index.get(word)
        # words not found in the embedding index stay all-zeros
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector
            found_ct += 1
    print('{} words are found in glove'.format(found_ct))
    return embedding_matrix
# dummy parameters
embd_size = 200
vocabs = sorted(['apple', 'banana', 'lemon'])
w2i = {w: i + 1 for i, w in enumerate(vocabs)}  # reserve index 0 for padding (matches padding_idx=0)
vocab_size = len(vocabs) + 1  # +1 for the padding index

print('Loading pre-embedding weights...')
# download the vectors from https://nlp.stanford.edu/projects/glove/
g_weights = load_glove_weights('./dict/glove.twitter.27B.200d.txt', embd_size, vocab_size, w2i)
print('Loaded!')
glove_embd = torch.from_numpy(g_weights).type(torch.FloatTensor)
# save_pickle(glove_embd, './dict/glove_embd.pickle')  # loading the weights from a pickle is much faster

pre_embd = glove_embd
m = Model(vocab_size, embd_size, pre_embd)
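
As a quick sanity check (a minimal sketch, not part of the original gist; the word 'apple' and the shapes are illustrative), you can embed a single-token batch through the model and confirm it matches the corresponding row of the GloVe matrix:

# Sanity check (illustrative): look up one word through the model's
# embedding layer and compare it to the raw GloVe row for the same word.
idx = torch.LongTensor([[w2i['apple']]])  # shape: (1, 1)
out = m(idx)                              # shape: (1, 1, embd_size)
print(out.shape)
assert torch.allclose(out[0, 0], glove_embd[w2i['apple']])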