Function to align any number of word2vec models using Procrustes matrix alignment.
# Code originally ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <wleif@stanford.edu>.
import gensim
import numpy as np


def align_gensim_models(models, words=None):
    """
    Returns the aligned/intersected models from a list of gensim word2vec models.
    Generalized from the original two-model intersection to any number of models.
    Written for the gensim 3.x KeyedVectors API (model.wv.vocab, model.wv.index2word,
    model.wv.vectors_norm); gensim 4.0 removed model.wv.vocab, so this function
    will not run on gensim 4 without porting.
    Call model.init_sims() on each model before passing it in, so that
    model.wv.vectors_norm is populated.
    This function only intersects and reorders the vocabularies; the Procrustes
    rotation itself is a separate step applied to the row-aligned arrays it returns.
    ##############################################
    ORIGINAL DESCRIPTION
    ##############################################
    Only the shared vocabulary between them is kept.
    If 'words' is set (as list or set), then the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (=sum of counts from both m1 and m2).
    These indices correspond to the new syn0 and syn0norm objects in both gensim models:
        -- so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0
        -- you can find the index of any word on the .index2word list: model.index2word.index(word) => 2
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """
    # Get the vocab for each model
    vocabs = [set(m.wv.vocab.keys()) for m in models]
    # Find the common vocabulary. set.intersection builds a new set, so the
    # in-place &= below never modifies one of the per-model vocab sets.
    common_vocab = set.intersection(*vocabs)
    if words:
        common_vocab &= set(words)
    # If no alignment is necessary because every vocab already equals the
    # common vocab and every model stores its words in the same row order...
    # This was generalized from:
    # if not vocab_m1-common_vocab and not vocab_m2-common_vocab and not vocab_m3-common_vocab:
    #     return (m1,m2,m3)
    # Identical vocab sets can still be indexed differently, in which case row i
    # would not be the same word in every model, so the row order is checked too.
    if all(not vocab - common_vocab for vocab in vocabs) and all(
        m.wv.index2word == models[0].wv.index2word for m in models
    ):
        print("All identical!")
        return models
    # Otherwise sort by frequency (summed across all models)
    common_vocab = list(common_vocab)
    common_vocab.sort(key=lambda w: sum(m.wv.vocab[w].count for m in models), reverse=True)
    # Then for each model...
    for m in models:
        if getattr(m.wv, "vectors_norm", None) is None:
            raise ValueError("Run model.init_sims() on every model before aligning.")
        # Replace old vectors_norm array with new one (with common vocab)
        indices = [m.wv.vocab[w].index for w in common_vocab]
        old_arr = m.wv.vectors_norm
        new_arr = old_arr[indices]  # fancy indexing copies the rows in the new order
        # As in the original port, the raw vectors are also replaced by the
        # normalized ones ('vectors' is the attribute behind the deprecated 'syn0' alias)
        m.wv.vectors_norm = m.wv.vectors = new_arr
        # Replace old vocab dictionary with new one (with common vocab)
        # and old index2word with new one
        m.wv.index2word = list(common_vocab)
        old_vocab = m.wv.vocab
        new_vocab = {}
        for new_index, word in enumerate(common_vocab):
            old_vocab_obj = old_vocab[word]
            new_vocab[word] = gensim.models.word2vec.Vocab(index=new_index, count=old_vocab_obj.count)
        m.wv.vocab = new_vocab
    return models
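The intersection step can be tried without gensim 3.x installed. The sketch below reproduces the same logic on plain Python data; the dict layout (`index2word`, `counts`, `vectors` keys) and the name `intersect_and_reindex` are hypothetical stand-ins for a gensim model, not part of gensim or HistWords. It keeps only the shared vocabulary, optionally intersects it with `words`, sorts by summed count, and reorders each vector array so that row i is the same word in every model.

```python
import numpy as np


def intersect_and_reindex(models, words=None):
    """Restrict every model to the shared vocabulary, in one common row order.

    Each model is a plain dict (a hypothetical stand-in for a gensim model):
      "index2word": list of words; row i of "vectors" belongs to index2word[i]
      "counts":     dict mapping word -> corpus count
      "vectors":    2-D NumPy array with one row per word
    """
    # Shared vocabulary across all models, optionally restricted to `words`
    common = set.intersection(*(set(m["index2word"]) for m in models))
    if words:
        common &= set(words)
    # Descending summed count; ties broken alphabetically so the order is deterministic
    order = sorted(common, key=lambda w: (-sum(m["counts"][w] for m in models), w))
    aligned = []
    for m in models:
        row_of = {w: i for i, w in enumerate(m["index2word"])}
        rows = [row_of[w] for w in order]
        aligned.append({
            "index2word": list(order),
            "counts": {w: m["counts"][w] for w in order},
            # Fancy indexing copies the selected rows in the new order
            "vectors": m["vectors"][rows],
        })
    return aligned


# Usage with two toy "models" that share only 'the' and 'dog'
m1 = {"index2word": ["the", "cat", "dog"],
      "counts": {"the": 10, "cat": 5, "dog": 3},
      "vectors": np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])}
m2 = {"index2word": ["dog", "the", "bird"],
      "counts": {"dog": 7, "the": 4, "bird": 1},
      "vectors": np.array([[2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])}
a1, a2 = intersect_and_reindex([m1, m2])
print(a1["index2word"])  # ['the', 'dog']  (summed counts 14 and 10)
print(a2["vectors"])     # rows for 'the' then 'dog'
```

Unlike the gensim function, this sketch breaks ties in summed count alphabetically. The gensim version leaves tied words in set-iteration order, which is harmless because every model receives the same order within one call.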
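The description at the top mentions Procrustes alignment, but `align_gensim_models` only makes the models' rows correspond; it does not rotate any vectors. Below is a minimal NumPy sketch of the orthogonal Procrustes step, assuming two row-aligned arrays such as `models[0].wv.vectors_norm` and `models[1].wv.vectors_norm` after the function has run. The helper name `orthogonal_procrustes_align` is hypothetical. It finds the orthogonal matrix R minimizing ||other @ R - base||_F from the SVD of `other.T @ base`.

```python
import numpy as np


def orthogonal_procrustes_align(base_vecs, other_vecs):
    """Rotate other_vecs onto base_vecs with the best-fitting orthogonal matrix.

    Both arrays must be row-aligned (row i is the same word in each), e.g. the
    vectors_norm arrays of two models after align_gensim_models() has run.
    Returns (aligned_other, rotation).
    """
    if base_vecs.shape != other_vecs.shape:
        raise ValueError("base_vecs and other_vecs must have the same shape")
    # SVD of the d x d cross-product matrix: other^T @ base = U S V^T
    u, _, vt = np.linalg.svd(other_vecs.T @ base_vecs)
    # R = U V^T is the orthogonal matrix minimizing ||other_vecs @ R - base_vecs||_F
    rotation = u @ vt
    return other_vecs @ rotation, rotation


# Usage: recover a known rotation from synthetic data
rng = np.random.default_rng(0)
base = rng.normal(size=(50, 5))
q, _ = np.linalg.qr(rng.normal(size=(5, 5)))  # random 5 x 5 orthogonal matrix
other = base @ q.T                             # base rotated out of alignment
aligned, rotation = orthogonal_procrustes_align(base, other)
print(np.allclose(aligned, base))  # True
```

If SciPy is available, `scipy.linalg.orthogonal_procrustes(other, base)` returns the same rotation together with a scale term. Writing the result back into a gensim 3.x model would presumably follow the convention of the function above (assigning `aligned` to both `wv.vectors_norm` and `wv.vectors`); that step is an assumption, not something this gist shows.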