@hans
Last active March 26, 2018 18:50
Generate embeddings for rare words in a document by averaging the embeddings of associated context words. Find nearest neighbors of these embeddings to evaluate their quality.
from collections import Counter, defaultdict
import itertools
import os
import random
import re
import numpy as np
EMBEDDING_FILE = "/u/nlp/data/depparser/nn/data/embeddings/en-cw.txt"
EMBEDDING_SERIALIZED = "embeddings.npz"
STOPWORDS_FILE = "/u/nlp/data/gaz/stopwords"
#CORPUS_FILE = "GENIA.raw.form.txt"
CORPUS_FILE = "train-wsj-0-18.raw"
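
# Tokens consisting entirely of punctuation are excluded from context windows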
PUNCT_REGEX = re.compile(r"^[[\]()-.,]+$")
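
# Parse a whitespace-separated embedding file (word followed by its vector
# components) into a word -> row-index dict and an embedding matrix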
def load_embeddings(f_stream):
    words = []
    embs = []
    for line in f_stream:
        fields = line.strip().split()
        word = fields[0]
        vals = [float(x) for x in fields[1:]]
        words.append(word)
        embs.append(np.array(vals))
    dict = {word: i for i, word in enumerate(words)}
    return dict, np.array(embs)
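
# Tokenize the corpus line by line (splitting off '.' and ',') and count
# token frequencies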
def load_corpus(f_stream):
    sentences = []
    freqs = Counter()
    for line in f_stream:
        tokens = line.strip().replace('.', ' .').replace(',', ' ,').split()
        for token in tokens:
            freqs[token] += 1
        sentences.append(tokens)
    return sentences, freqs
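
# Like load_corpus, but additionally record, for every token, the non-stopword,
# non-punctuation words that appear within `window` tokens of it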
def load_corpus_with_contexts(f_stream, window=5, stopwords=frozenset()):
    sentences = []
    freqs = Counter()
    contexts = defaultdict(list)
    for line in f_stream:
        tokens = line.strip().replace('.', ' .').replace(',', ' ,').split()
        for i, token in enumerate(tokens):
            freqs[token] += 1
            context = [ctx_word for ctx_word in tokens[max(0, i - window):i + window]
                       if ctx_word != token and ctx_word not in stopwords
                       and not PUNCT_REGEX.match(ctx_word)]
            contexts[token].extend(context)
        sentences.append(tokens)
    return sentences, freqs, contexts
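
# Collect context words for a single target word by scanning each sentence for
# its occurrences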
def get_contexts(word, sentences, window=5, stopwords=frozenset()):
    context_words = []
    for sentence in sentences:
        idxs = [i for i, other_word in enumerate(sentence) if word == other_word]
        for idx in idxs:
            idx_context = [ctx_word for ctx_word in sentence[max(0, idx - window):idx + window]
                           if ctx_word != word and ctx_word not in stopwords
                           and not PUNCT_REGEX.match(ctx_word)]
            context_words.extend(idx_context)
    return context_words
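
# Average the embeddings of the given words (lowercased); words missing from the
# embedding vocabulary are skipped but still count toward the denominator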
def avg_word_embeddings(words, dict, embs):
    avg = np.zeros(embs.shape[1])
    for word in words:
        word = word.lower()
        try:
            avg += embs[dict[word]]
        except KeyError:
            pass
    avg /= float(len(words))
    return avg
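
# Rank rows of `embeddings` by similarity to x: x is normalized here and the
# embedding rows are expected to be unit-normalized, so the negated dot product
# sorts by cosine similarity, most similar first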
def nearest_neighbors(x, embeddings, n=5):
    x /= np.linalg.norm(x)
    dists = -embeddings.dot(x)
    return np.argsort(dists)[:n]
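
# Map nearest-neighbor row indices back to words via a reversed dictionary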
def nearest_neighbor_words(x, dict, embs, rev_dict=None, n=5):
    if rev_dict is None:
        rev_dict = {v: k for k, v in dict.iteritems()}
    return [rev_dict[id] for id in nearest_neighbors(x, embs, n=n)]
def averaged_nn_simple(word_list, dict, embs, stopwords=frozenset()):
    """Generate embeddings for provided words by averaging embeddings
    of context words. Yields tuples of the form

        (word, freq, embedding)
    """
    with open(CORPUS_FILE, 'r') as corpus_f:
        sentences, word_freqs = load_corpus(corpus_f)

    if isinstance(word_list, int):
        word_list = random.sample(word_freqs.keys(), word_list)

    for word in word_list:
        ctxs = get_contexts(word, sentences, stopwords=stopwords)
        yield word, word_freqs[word], avg_word_embeddings(ctxs, dict, embs)
def averaged_nn_nn(word_list, dict, embs, freq_threshold=10, window=5, stopwords=frozenset()):
    """Generate embeddings for provided words by picking the word embedding
    of the known word whose context embedding best matches the context
    embedding of the provided word. Yields tuples of the form

        (word, freq, embedding)
    """
    with open(CORPUS_FILE, 'r') as corpus_f:
        sentences, word_freqs, contexts = load_corpus_with_contexts(corpus_f, window=window,
                                                                    stopwords=stopwords)

    if isinstance(word_list, int):
        word_list = random.sample(word_freqs.keys(), word_list)

    # Compute context embeddings for every corpus word type that has a word
    # embedding; `context_words[i]` corresponds to row i of `context_embeddings`
    context_words = [w for w in contexts if w.lower() in dict]
    context_embeddings = np.array([avg_word_embeddings(contexts[w], dict, embs)
                                   for w in context_words])
    # Normalize context embeddings
    context_embeddings /= np.linalg.norm(context_embeddings, axis=1)[:, np.newaxis]

    for word in word_list:
        # Compute context embedding of unknown word
        ctxs = contexts[word]
        unk_context_embedding = avg_word_embeddings(ctxs, dict, embs)
        unk_context_embedding /= np.linalg.norm(unk_context_embedding)
        # Find distance to context embeddings of known words
        dists = -context_embeddings.dot(unk_context_embedding)
        # Pick the nearest context whose word passes the frequency threshold
        best_idx = next(idx for idx in dists.argsort()
                        if word_freqs[context_words[idx]] > freq_threshold)
        # Return the word embedding associated with this context embedding
        emb = embs[dict[context_words[best_idx].lower()]]
        yield word, word_freqs[word], emb
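
# Driver: load (or build and cache) normalized embeddings, sample corpus word
# types, and print nearest neighbors of the simple-averaged, context-matched,
# and random embeddings for each sampled word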
if __name__ == '__main__':
    if os.path.exists(EMBEDDING_SERIALIZED):
        ser = np.load(EMBEDDING_SERIALIZED)
        dict, embs = ser['dict'][()], ser['embs']
    else:
        with open(EMBEDDING_FILE, 'r') as emb_f:
            dict, embs = load_embeddings(emb_f)

        # Normalize embeddings
        embs /= np.linalg.norm(embs, axis=1)[:, np.newaxis]
        np.savez(EMBEDDING_SERIALIZED, dict=dict, embs=embs)

    print 'Loaded %i embeddings' % len(dict)

    with open(STOPWORDS_FILE, 'r') as stopwords_f:
        stopwords = frozenset([x.strip() for x in stopwords_f.readlines()])

    # Number of words to fetch per trial
    n_words = 20
    # Number of neighbors to list
    n_list = 7

    dim = embs.shape[1]
    rev_dict = {v: k for k, v in dict.iteritems()}

    nn_simple_results = list(averaged_nn_simple(n_words, dict, embs, stopwords=stopwords))
    word_list = [word for word, _, _ in nn_simple_results]
    nn_nn_results = list(averaged_nn_nn(word_list, dict, embs, stopwords=stopwords))

    for nn_simple_result, nn_nn_result in zip(nn_simple_results, nn_nn_results):
        word, word_freq, simple_emb = nn_simple_result
        _, _, nn_emb = nn_nn_result

        # Print nearest neighbors of simple averaged embedding
        neighbors = nearest_neighbor_words(simple_emb, dict, embs, rev_dict=rev_dict, n=n_list)
        print '%30s %5d\t%s' % (word, word_freq, ' '.join(neighbors))

        # Print nearest neighbors of NN averaged embedding
        nn_neighbors = nearest_neighbor_words(nn_emb, dict, embs, rev_dict=rev_dict, n=n_list)
        print '%30s %5d\t%s' % (word, word_freq, ' '.join(nn_neighbors))

        # Print nearest neighbors of a random embedding
        rand_neighbors = nearest_neighbor_words(np.random.rand(dim), dict, embs, rev_dict=rev_dict, n=n_list)
        print '%30s %5s\t%s' % ('', '', ' '.join(rand_neighbors))
"""
Example output (run on PTB WSJ training data):
$ python avg_embeddings.py
Loaded 130000 embeddings
TRIMMING 1 ? happy jealous speechless depressed
mid-week windies spoils chanderpaul roode
Aug 41 cost reduced budget election vote
yankees metrodome canucks waratahs stormers
cultivating 1 a head easy constant power
cannot pre-election dampens mid-atlantic city-state
Knopf 4 book royal writing title reading
could seems appreciates dugouts seemed
Symbol 6 otc ; : ystem aggregator
dugouts brink mid-week hopes trumper
loyalty 28 brand market share potential buyer
dugouts mid-week persuasions deal-making post-cold
Hurwitz 1 president chief david general james
deserves appreciates emphasizes dictates materialises
plunging 7 sales money payment sale prices
pessimists nby clamour conferees doubters
intimidation 1 bombings killings arrests campaign assassination
saw emphasizes beyond exceeds ensures
Kelly/Varnell 1 west south east left come
dampens long-awaited post-war mid-week cross-media
police 59 security police emergency service aid
dampens chanderpaul cannot downplays execs
cost-efficiency 1 productivity flexibility growth consumption interdependence
mid-week dugouts blacklist dampens heightens
worriers 2 repressed refuting discussed heterodox applauded
bolls acpc clamour requirments fall-out
Money-fund 3 wages revenues assets incomes receipts
olein dampens stearin bradies bm&f
cocky 1 denying depriving accusing eliminating accepting
hurts dampens downplays scotched defers
Laurel 8 president bank agent governor party
deserves deserved downplays resounding exuded
minor-sport 1 games scores drivers students points
mid-week littlejohn dampens long-awaited dugouts
graphic 1 firms entrepreneurs photographers economists scientists
appreciates promises blacklist hustle imbroglio
campaign 108 hearing news special advertising public
mid-week reckons conferees loog dugouts
Manhattan 70 council center district school law
deserve deserves dampens enduring clear-cut
"""