Reading GloVe Embeddings
import numpy as np


class Embedding(object):
    """Word-embedding lookup table that reserves id 0 for an unknown-word token."""

    def __init__(self, unk_token=None):
        self.unk_token = unk_token
        self.word2id = {unk_token: 0}
        self.id2word = [unk_token]
        # Holds vectors for real words only; vectors_to_numpy() later prepends
        # a zero vector for the unk token so ids and rows line up.
        self.vectors = []

    def __len__(self):
        return len(self.word2id)

    def __setitem__(self, word, vector):
        if word not in self.word2id:
            self.word2id[word] = len(self)
            self.id2word.append(word)
            self.vectors.append(vector)

    def __getitem__(self, word):
        # Unknown words fall back to id 0 (the unk token). Lookups are only
        # aligned after vectors_to_numpy() has been called.
        word_id = self.word2id.get(word, 0)
        return self.vectors[word_id]

    def __contains__(self, word):
        return word in self.word2id

    def words_to_ids(self, words):
        return [self.word2id.get(w, 0) for w in words]

    def words_to_vectors(self, words):
        # Relies on numpy fancy indexing, so call vectors_to_numpy() first.
        ids = self.words_to_ids(words)
        return self.vectors[ids]

    def vectors_to_numpy(self):
        # Prepend a zero vector for the unk token (id 0), then convert the
        # list of per-word vectors into a 2-D numpy array.
        self.vectors = np.zeros_like(self.vectors[:1]).tolist() + self.vectors
        self.vectors = np.asarray(self.vectors)
        assert len(self.vectors.shape) == 2, "Some error occurred in vectorization"


def read_embeddings(path, all_vocab):
    embedding = Embedding()
    with open(path, encoding="utf-8") as fp:
        for line in fp:
            # Standard GloVe lines are space-separated: the word, then its
            # vector components.
            parts = line.rstrip().split(" ")
            assert len(parts) > 1, f"Incorrect line: {parts}"
            word = parts[0]
            vector = [float(x) for x in parts[1:]]
            if word in all_vocab:
                embedding[word] = vector
            # len(embedding) also counts the unk token, so this stops once
            # every vocabulary word has been found.
            if len(embedding) > len(all_vocab):
                break
    embedding.vectors_to_numpy()
    return embedding
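A minimal usage sketch, assuming a standard space-separated GloVe file at a hypothetical local path (glove.6B.50d.txt) and a toy vocabulary:

# Hypothetical example: the file path and vocabulary below are assumptions.
vocab = {"the", "cat", "sat"}
emb = read_embeddings("glove.6B.50d.txt", vocab)

print(emb.words_to_ids(["the", "dog"]))        # out-of-vocab "dog" maps to id 0
matrix = emb.words_to_vectors(["the", "cat"])  # rows of the numpy vector table
print(matrix.shape)                            # (2, 50) for 50-d GloVe vectors

Reserving id 0 for the unk token means every out-of-vocabulary word shares a single all-zero row, so downstream lookups never need a special case for missing words.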