Function to align any number of word2vec models using Procrustes matrix alignment.
```python
# Code originally ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <wleif@stanford.edu>.
from functools import reduce

import gensim


def align_gensim_models(models, words=None):
    """
    Returns the aligned/intersected models from a list of gensim word2vec models.
    Generalized from the original two-way (m1, m2) intersection to any number of models.
    Written against the gensim 3.x KeyedVectors API (wv.vocab, wv.index2word, wv.vectors_norm).
    gensim 4.0 removed wv.vocab and renamed index2word to index_to_key, so this will not run there unchanged.
    Before calling this, run 'model.init_sims()' on each model so that 'model.wv.vectors_norm' is populated.
    ##############################################
    ORIGINAL DESCRIPTION
    ##############################################
    Only the shared vocabulary between them is kept.
    If 'words' is set (as list or set), then the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (=sum of counts from both m1 and m2).
    These indices correspond to the new syn0 and syn0norm objects in both gensim models:
    -- so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0
    -- you can find the index of any word on the .index2word list: model.index2word.index(word) => 2
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """
    # Get the vocab for each model
    vocabs = [set(m.wv.vocab.keys()) for m in models]

    # Find the common vocabulary (intersection of every model's vocab set)
    common_vocab = reduce(lambda vocab1, vocab2: vocab1 & vocab2, vocabs)
    if words:
        common_vocab &= set(words)

    # If no alignment is necessary because the vocab is identical...
    # This was generalized from:
    # if not vocab_m1-common_vocab and not vocab_m2-common_vocab and not vocab_m3-common_vocab:
    #     return (m1,m2,m3)
    if all(not vocab - common_vocab for vocab in vocabs):
        print("All identical!")
        return models

    # Otherwise sort by frequency (summed across all models), most frequent first
    common_vocab = sorted(
        common_vocab,
        key=lambda w: sum(m.wv.vocab[w].count for m in models),
        reverse=True,
    )

    # Then for each model...
    for m in models:
        # Replace old vectors_norm array with new one (with common vocab).
        # Note: the raw vectors are overwritten with the normalized ones as well.
        indices = [m.wv.vocab[w].index for w in common_vocab]
        new_arr = m.wv.vectors_norm[indices]  # fancy indexing returns a copy of those rows
        m.wv.vectors_norm = m.wv.vectors = new_arr

        # Replace old vocab dictionary with new one (with common vocab)
        # and old index2word with new one (a separate copy of the list per model)
        m.wv.index2word = list(common_vocab)
        old_vocab = m.wv.vocab
        new_vocab = {}
        for new_index, word in enumerate(common_vocab):
            old_vocab_obj = old_vocab[word]
            new_vocab[word] = gensim.models.word2vec.Vocab(index=new_index, count=old_vocab_obj.count)
        m.wv.vocab = new_vocab

    return models
```
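The description mentions Procrustes alignment, but `align_gensim_models` only intersects and reindexes the vocabularies so that row i means the same word in every model; the rotation is a separate step. Below is a minimal sketch of the standard orthogonal Procrustes solution with NumPy, assuming both inputs are `(n_words, dim)` arrays whose rows already correspond (for example, the `wv.vectors_norm` arrays of two models after the intersection above). `procrustes_rotation` and `align_to_base` are hypothetical helper names, not part of gensim or of this gist.

```python
import numpy as np


def procrustes_rotation(base_vecs, other_vecs):
    """Return the orthogonal matrix R minimizing ||other_vecs @ R - base_vecs||_F.

    Both arguments are (n_words, dim) arrays whose rows refer to the same
    words in the same order.
    """
    # SVD of the cross-covariance matrix: M = other^T base = U S V^T
    m = other_vecs.T @ base_vecs
    u, _, vt = np.linalg.svd(m)
    # The optimal orthogonal map is U V^T
    return u @ vt


def align_to_base(base_vecs, other_vecs):
    """Rotate other_vecs into the coordinate system of base_vecs."""
    return other_vecs @ procrustes_rotation(base_vecs, other_vecs)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    base = rng.normal(size=(50, 10))
    # Build a random orthogonal matrix via QR and rotate the base vectors with it
    q, _ = np.linalg.qr(rng.normal(size=(10, 10)))
    rotated = base @ q
    aligned = align_to_base(base, rotated)
    print(np.allclose(aligned, base))  # True: the rotation is undone
```

With gensim 3.x models prepared as above, one might apply it as `m2.wv.vectors_norm = align_to_base(m1.wv.vectors_norm, m2.wv.vectors_norm)`. Because R is orthogonal, it preserves dot products and norms, so cosine similarities within each model are unchanged; only the coordinate system moves.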
@DapangLiu

commented Jun 6, 2019

So what is reduce() in line 31?

common_vocab = reduce((lambda vocab1,vocab2: vocab1&vocab2), vocabs)

Thank you!

@tangert

Owner Author

commented Jun 7, 2019

@DapangLiu It takes the intersection of all the vocabulary sets in vocabs (however many models you pass in). reduce comes from Python's functools module, and you can read more about it in the functools documentation!

Basically, it goes through the vocabs list and intersects the sets one after another with the & operator, so only words present in every model survive. Does that make sense?
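To make the explanation concrete, here is a small sketch with three made-up vocabulary sets (the words are hypothetical, chosen only for illustration). `reduce` applies the lambda cumulatively, so the result is `(vocabs[0] & vocabs[1]) & vocabs[2]`.

```python
from functools import reduce

# Three toy vocabularies standing in for set(m.wv.vocab.keys()) of three models
vocabs = [
    {"cat", "dog", "fish"},
    {"dog", "fish", "bird"},
    {"fish", "dog", "cow"},
]

# Cumulative intersection: ((vocabs[0] & vocabs[1]) & vocabs[2])
common_vocab = reduce(lambda vocab1, vocab2: vocab1 & vocab2, vocabs)
print(sorted(common_vocab))  # ['dog', 'fish']

# The same result without reduce, using set.intersection with unpacking
assert common_vocab == set.intersection(*vocabs)
```

`set.intersection(*vocabs)` is an equivalent one-liner; the `reduce` form just makes the pairwise folding explicit.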
