Function to align any number of word2vec models using Procrustes matrix alignment.
# Code originally ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <wleif@stanford.edu>.
from functools import reduce

import numpy as np
import gensim


def align_gensim_models(models, words=None):
    """
    Returns the aligned/intersected models from a list of gensim word2vec models.
    Generalized from the original two-way intersection.
    Also updated to work with the most recent version of gensim.

    Before passing models in for alignment, run 'model.init_sims()' on each one
    so that 'vectors_norm' is populated.

    ##############################################
    ORIGINAL DESCRIPTION
    ##############################################
    Only the shared vocabulary between the models is kept.
    If 'words' is set (as a list or set), the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (= sum of counts across all models).
    These indices correspond to the new syn0 and syn0norm objects in all gensim models:
        -- so that Row 0 of m1.syn0 is the same word as Row 0 of m2.syn0
        -- you can find the index of any word on the .index2word list: model.index2word.index(word) => 2
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """
    # Get the vocab for each model
    vocabs = [set(m.wv.vocab.keys()) for m in models]

    # Find the common vocabulary
    common_vocab = reduce(lambda vocab1, vocab2: vocab1 & vocab2, vocabs)
    if words:
        common_vocab &= set(words)

    # If no alignment is necessary because the vocabularies are identical...
    # This was generalized from:
    #   if not vocab_m1-common_vocab and not vocab_m2-common_vocab and not vocab_m3-common_vocab:
    #       return (m1, m2, m3)
    if all(not vocab - common_vocab for vocab in vocabs):
        print("All identical!")
        return models

    # Otherwise sort by frequency (summed over all models)
    common_vocab = list(common_vocab)
    common_vocab.sort(key=lambda w: sum(m.wv.vocab[w].count for m in models), reverse=True)

    # Then for each model...
    for m in models:
        # Replace the old vectors_norm array with a new one restricted to the common vocab.
        # Note: this also overwrites the raw vectors (syn0) with the normalized ones,
        # as in the original HistWords code.
        indices = [m.wv.vocab[w].index for w in common_vocab]
        old_arr = m.wv.vectors_norm
        new_arr = np.array([old_arr[index] for index in indices])
        m.wv.vectors_norm = m.wv.syn0 = new_arr

        # Replace the old vocab dictionary and the old index2word list
        # with new ones built from the common vocab
        m.wv.index2word = common_vocab
        old_vocab = m.wv.vocab
        new_vocab = {}
        for new_index, word in enumerate(common_vocab):
            old_vocab_obj = old_vocab[word]
            new_vocab[word] = gensim.models.word2vec.Vocab(index=new_index, count=old_vocab_obj.count)
        m.wv.vocab = new_vocab

    return models
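
As the comments below point out, the function above only performs the vocabulary intersection and row reordering; the actual Procrustes rotation from HistWords is a separate step. Here is a minimal sketch of that rotation in pure NumPy, assuming the two embedding matrices have already been row-aligned (e.g. by `align_gensim_models`). The function name `procrustes_align` and its signature are illustrative, not part of the original gist:

```python
import numpy as np

def procrustes_align(base_embed, other_embed):
    """Rotate other_embed onto base_embed via orthogonal Procrustes.

    Both arrays must be row-aligned (row i = same word in both models).
    Returns other_embed rotated to minimize the Frobenius distance to base_embed.
    """
    # SVD of the cross-covariance matrix gives the optimal orthogonal map
    m = other_embed.T @ base_embed
    u, _, v_t = np.linalg.svd(m)
    ortho = u @ v_t  # orthogonal matrix minimizing ||other @ ortho - base||_F
    return other_embed @ ortho
```

Applying this pairwise against a chosen base model (after intersection) would give the full "align any number of models" pipeline the gist title describes.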
@shensimeteor

Thanks for sharing. It seems the code only does the intersection, not the alignment. Did I miss anything?

@poornagurram

@shensimeteor the alignment here is the intersection itself: getting all the models onto a single common vocabulary.
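
To make this point concrete, here is the row-reordering idea in isolation, stripped of the gensim machinery (pure NumPy; the toy vocabularies and names are illustrative):

```python
import numpy as np

# Toy "models": each maps words to row indices in its own vector matrix
vocab_a = {"cat": 0, "dog": 1, "fish": 2}
vocab_b = {"dog": 0, "bird": 1, "cat": 2}
vecs_a = np.arange(6, dtype=float).reshape(3, 2)
vecs_b = np.arange(6, dtype=float)[::-1].reshape(3, 2)

# Intersect the vocabularies and reindex both matrices in the same word order
common = sorted(set(vocab_a) & set(vocab_b))        # ['cat', 'dog']
a = vecs_a[[vocab_a[w] for w in common]]
b = vecs_b[[vocab_b[w] for w in common]]
# Now row i of a and row i of b refer to the same word,
# which is the precondition for any Procrustes rotation step.
```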

@shikharsingla

shikharsingla commented Apr 9, 2021

https://gist.github.com/louridas/a3cdb1b109ac03a8e202f4b19c9335b3 uses `linalg.svd` from NumPy to do the Procrustes rotation. Where is that part in your code? Sorry if I missed something.
