Function to align any number of word2vec models using Procrustes matrix alignment.
# Code originally ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <wleif@stanford.edu>.
from functools import reduce

import numpy as np
import gensim


def align_gensim_models(models, words=None):
    """
    Returns the aligned/intersected models from a list of gensim word2vec models.
    Generalized from the original two-way intersection as seen above.
    Also updated to work with the most recent version of gensim.
    Requires reduce from functools.

    Before running this, make sure you call 'model.init_sims()' on each model
    so its normalized vectors (vectors_norm) are populated.

    ##############################################
    ORIGINAL DESCRIPTION
    ##############################################

    Only the shared vocabulary between the models is kept.
    If 'words' is set (as a list or set), the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (= sum of counts across all models).
    These indices correspond to the new syn0 and syn0norm objects in all of the gensim models:
        -- so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0
        -- you can find the index of any word in the .index2word list: model.wv.index2word.index(word)
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """
    # Get the vocab for each model
    vocabs = [set(m.wv.vocab.keys()) for m in models]

    # Find the common vocabulary
    common_vocab = reduce(lambda vocab1, vocab2: vocab1 & vocab2, vocabs)
    if words:
        common_vocab &= set(words)

    # If no alignment is necessary because the vocabularies are identical, return early.
    # This was generalized from:
    #     if not vocab_m1 - common_vocab and not vocab_m2 - common_vocab and not vocab_m3 - common_vocab:
    #         return (m1, m2, m3)
    if all(not vocab - common_vocab for vocab in vocabs):
        print("All identical!")
        return models

    # Otherwise, sort the common vocabulary by frequency (summed across all models)
    common_vocab = list(common_vocab)
    common_vocab.sort(key=lambda w: sum(m.wv.vocab[w].count for m in models), reverse=True)

    # Then for each model...
    for m in models:
        # Replace the old vectors_norm array with one restricted to the common vocab
        indices = [m.wv.vocab[w].index for w in common_vocab]
        old_arr = m.wv.vectors_norm
        new_arr = np.array([old_arr[index] for index in indices])
        m.wv.vectors_norm = m.wv.syn0 = new_arr

        # Replace the old vocab dictionary and index2word list with new ones
        # built from the common vocab, preserving each word's count
        m.wv.index2word = common_vocab
        old_vocab = m.wv.vocab
        new_vocab = {}
        for new_index, word in enumerate(common_vocab):
            old_vocab_obj = old_vocab[word]
            new_vocab[word] = gensim.models.word2vec.Vocab(index=new_index, count=old_vocab_obj.count)
        m.wv.vocab = new_vocab

    return models
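
For convenience, a minimal usage sketch (the model file names below are hypothetical; as the docstring notes, init_sims() must be called on each model first):

# Minimal usage sketch; the model file names are hypothetical.
from gensim.models import Word2Vec

models = [Word2Vec.load(p) for p in ("m1990.model", "m2000.model", "m2010.model")]
for m in models:
    m.init_sims()  # populates vectors_norm, which the alignment relies on

aligned_models = align_gensim_models(models)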
@agi-templar

So what is reduce() in line 31?

common_vocab = reduce((lambda vocab1,vocab2: vocab1&vocab2), vocabs)

Thank you!

@tangert (Author) commented Jun 7, 2019

@DapangLiu It takes the intersection of all of the vocabulary sets in vocabs. You can read more about reduce in the Python functools docs!

Basically, it goes through each element in the vocabs array and takes the intersection of all of them with the & operator. Does that make sense?
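
For example, here's a toy version of the same pattern:

# Toy illustration: reduce folds the & (intersection) operator across the list.
from functools import reduce

vocabs = [{"cat", "dog", "fish"}, {"cat", "dog"}, {"bird", "cat", "dog"}]
common_vocab = reduce(lambda vocab1, vocab2: vocab1 & vocab2, vocabs)
print(common_vocab)  # {'cat', 'dog'}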

@shensimeteor

Thanks for sharing. It seems the code only does the vocabulary intersection, not the alignment. Did I miss anything?

@poornagurram

@shensimeteor The alignment here is the intersection itself: getting all the models onto a single common vocab.

@shikharsingla

shikharsingla commented Apr 9, 2021

https://gist.github.com/louridas/a3cdb1b109ac03a8e202f4b19c9335b3 uses numpy's linalg.svd to do the Procrustes rotation. Where is that part in your code? Sorry if I missed something.
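
For reference, the SVD-based rotation lives in HistWords' smart_procrustes_align_gensim and is not ported in this gist. A rough sketch of that step, assuming two models whose rows have already been aligned by the intersection function above:

# Sketch of the orthogonal Procrustes step from HistWords (not part of this gist).
# Assumes `base` and `other` have already been intersected so their rows align.
import numpy as np

def procrustes_align(base, other):
    base_vecs = base.wv.vectors_norm
    other_vecs = other.wv.vectors_norm
    # SVD of the cross-covariance matrix yields the orthogonal matrix R = U @ Vt
    # minimizing || other_vecs @ R - base_vecs ||_F.
    u, _, vt = np.linalg.svd(other_vecs.T @ base_vecs)
    ortho = u @ vt
    other.wv.vectors_norm = other.wv.syn0 = other_vecs @ ortho
    return other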
