GloVe embeddings in PyTorch
@jojonki · Created August 9, 2018
import numpy as np
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, vocab_size, embd_size, pre_embd_w=None):
        super(Model, self).__init__()
        # padding_idx=0 reserves index 0 for the pad token (no gradient update for that row)
        self.embd = nn.Embedding(vocab_size, embd_size, padding_idx=0)
        if pre_embd_w is not None:
            print('pre embedding weight is set')
            self.embd.weight = nn.Parameter(pre_embd_w, requires_grad=True)
            # self.embd.weight = nn.Parameter(pre_embd_w, requires_grad=False)  # freeze the parameters
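            # Note (not in the original gist): recent PyTorch versions also offer a
            # one-liner equivalent to the frozen variant above:
            #   self.embd = nn.Embedding.from_pretrained(pre_embd_w, freeze=True)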

def load_glove_weights(glove_fpath, embd_size, vocab_size, word_index):
    # parse the GloVe text file: each line is a word followed by its vector
    embeddings_index = {}
    with open(glove_fpath, encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.array(values[1:], dtype='float32')
            embeddings_index[word] = vector
    print('Found {} word vectors in glove.'.format(len(embeddings_index)))

    # build a (vocab_size, embd_size) matrix whose rows are aligned with word_index
    embedding_matrix = np.zeros((vocab_size, embd_size))
    print('embed_matrix.shape', embedding_matrix.shape)
    found_ct = 0
    for word, i in word_index.items():
        embedding_vector = embeddings_index.get(word)
        # words not found in the embedding index stay all-zeros
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector
            found_ct += 1
    print('{} words are found in glove'.format(found_ct))

    return embedding_matrix
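
# The save_pickle helper referenced below is not defined in the original gist;
# a minimal sketch (assuming plain pickle serialization) could look like this:
import pickle

def save_pickle(obj, fpath):
    with open(fpath, 'wb') as f:
        pickle.dump(obj, f)

def load_pickle(fpath):
    with open(fpath, 'rb') as f:
        return pickle.load(f)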

# dummy parameters
embd_size = 200
vocabs = sorted(['apple', 'banana', 'lemon'])
w2i = {w: i for i, w in enumerate(vocabs)}  # note: index 0 is treated as padding by the model
vocab_size = len(vocabs)

print('Loading pre-embedding weights...')
# download the vectors from https://nlp.stanford.edu/projects/glove/
g_weights = load_glove_weights('./dict/glove.twitter.27B.200d.txt', embd_size, vocab_size, w2i)
print('Loaded!')
glove_embd = torch.from_numpy(g_weights).type(torch.FloatTensor)
# save_pickle(glove_embd, './dict/glove_embd.pickle')  # loading the weights from a pickle is much faster

pre_embd = glove_embd
m = Model(vocab_size, embd_size, pre_embd)
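
As a quick sanity check (not part of the original gist), looking up one of the dummy words should return its GloVe row:

idx = torch.LongTensor([w2i['banana']])
emb = m.embd(idx)  # shape: (1, embd_size)
print(emb.shape)
print(torch.equal(emb[0].detach(), glove_embd[w2i['banana']]))  # True if 'banana' was found in GloVe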