Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Created March 30, 2020 15:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrdrozdov/1815cbd096a77f5e10f20479044eeb69 to your computer and use it in GitHub Desktop.
Save mrdrozdov/1815cbd096a77f5e10f20479044eeb69 to your computer and use it in GitHub Desktop.
context_insensitive_word_embeddings.py
import os
import hashlib
from allennlp.commands.elmo import ElmoEmbedder
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper
import numpy as np
def save_elmo_cache(path, vectors):
    """Persist an embedding matrix to disk in NumPy ``.npy`` format.

    Args:
        path: destination file. Per the ``numpy.save`` documentation, a
            ``.npy`` suffix is appended when *path* does not already end
            with one.
        vectors: array to store.
    """
    np.save(file=path, arr=vectors)
def load_elmo_cache(path):
    """Read an embedding matrix previously written by ``np.save``.

    Args:
        path: path of the ``.npy`` file to read.

    Returns:
        The array stored in the file.
    """
    return np.load(path)
def hash_vocab(tokens, version='v1.0.0'):
    """Return a SHA-256 hex digest identifying an ordered vocabulary.

    The digest is used as a cache-file key, so distinct token lists must not
    produce the same value. Each piece (the version tag, then every token) is
    UTF-8 encoded and prefixed with its byte length before being hashed.
    Plain concatenation would make e.g. ``['ab', 'c']`` and ``['a', 'bc']``,
    or ``[]`` and ``['']``, collide.

    Args:
        tokens: iterable of ``str`` tokens; order matters.
        version: tag mixed into the hash so the key format can be versioned.

    Returns:
        64-character lowercase hexadecimal string.
    """
    m = hashlib.sha256()

    def _feed(text):
        # str.encode(...) (rather than text.encode) keeps the original
        # TypeError for non-str tokens.
        data = str.encode(text, 'utf-8')
        # The length prefix makes piece boundaries unambiguous.
        m.update(str(len(data)).encode('ascii') + b':')
        m.update(data)

    _feed(version)
    for w in tokens:
        _feed(w)
    return m.hexdigest()
def _to_numpy(tensor):
    """Convert a (possibly CUDA, possibly grad-tracking) torch tensor to numpy."""
    # Tensor.numpy() only works on CPU tensors that do not require grad.
    return tensor.detach().cpu().numpy()


def _elmo_cache_path(cache_dir, tokens, weights_path, options_path):
    """Return the cache file path for this vocabulary AND this ELMo model.

    The model file locations are mixed into the key so that running a
    different ELMo model over the same vocabulary does not reuse the other
    model's cached vectors. (This makes older vocab-only cache files
    unreachable; they are simply recomputed once.)
    """
    parts = [hash_vocab(tokens), str(options_path), str(weights_path)]
    # NUL cannot appear in a hex digest and is not valid in file paths, so it
    # separates the parts unambiguously.
    key = hashlib.sha256('\0'.join(parts).encode('utf-8')).hexdigest()
    return os.path.join(cache_dir, 'elmo_{}.npy'.format(key))


def context_insensitive_character_embeddings(weights_path, options_path, tokens, cuda=False, cache_dir=None):
    """Compute context-insensitive ELMo vectors for a fixed vocabulary.

    Calls ``create_cached_cnn_embeddings`` on the ELMo bi-LM, so each token
    gets one vector that does not depend on surrounding context.

    Args:
        weights_path: ELMo weight file passed to ``ElmoEmbedder``.
        options_path: ELMo options file passed to ``ElmoEmbedder``.
        tokens: sequence of vocabulary strings. It must start with
            ``ELMoCharacterMapper.bos_token``,
            ``ELMoCharacterMapper.eos_token`` and ``'_PAD_'``, in that order.
        cuda: if True, run on CUDA device 0; otherwise on the CPU.
        cache_dir: optional directory in which the result is cached as a
            ``.npy`` file. It is created if missing.

    Returns:
        numpy array. Rows 0 and 1 are the BOS and EOS vectors. The remaining
        rows are the word-embedding matrix built from ``tokens[2:]``,
        presumably in the same order (this relies on allennlp's
        ``create_cached_cnn_embeddings``; confirm against your version).

    Raises:
        ValueError: if ``tokens`` does not start with the required special
            tokens.
    """
    # Validate before the expensive model load. A plain ``assert`` is
    # stripped under ``python -O``, and the tokens[2:] slice below would then
    # silently misalign rows with tokens.
    expected_prefix = [ELMoCharacterMapper.bos_token, ELMoCharacterMapper.eos_token, '_PAD_']
    actual_prefix = list(tokens[:3])
    if actual_prefix != expected_prefix:
        raise ValueError('tokens must start with {!r}, got {!r}'.format(expected_prefix, actual_prefix))

    cache_path = None
    if cache_dir is not None:
        cache_path = _elmo_cache_path(cache_dir, tokens, weights_path, options_path)
        if os.path.exists(cache_path):
            print('Loading cached elmo vectors: {}'.format(cache_path))
            return load_elmo_cache(cache_path)

    device = 0 if cuda else -1
    elmo = ElmoEmbedder(options_file=options_path, weight_file=weights_path, cuda_device=device)
    bilm = elmo.elmo_bilm

    # BOS/EOS are supplied by the library, so only the tokens from '_PAD_'
    # onward are passed in.
    bilm.create_cached_cnn_embeddings(tokens[2:])
    bos_vector = _to_numpy(bilm._bos_embedding).reshape(1, -1)
    eos_vector = _to_numpy(bilm._eos_embedding).reshape(1, -1)
    word_vectors = _to_numpy(bilm._word_embedding.weight)
    vectors = np.concatenate([bos_vector, eos_vector, word_vectors])

    if cache_path is not None:
        # np.save does not create missing directories.
        os.makedirs(cache_dir, exist_ok=True)
        print('Saving cached elmo vectors: {}'.format(cache_path))
        save_elmo_cache(cache_path, vectors)
    return vectors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment