"""Create an interactive UMAP embedding display for a word2vec model with a simple script.

The script was originally written by Peter Baumgartner
(https://gist.github.com/pmbaumgartner/adb33aa486b77ab58eb3df265393195d)
and then modified by Lynn Cherny to accept a corpus file and any gensim
word2vec model file, and to build (or read) a word-counts file before the
UMAP display.

The counts are used to focus on the most common words; more frequent words
are shown in lighter colors in the UMAP display Peter made.

NOTE: pip install umap-learn, not umap; the import style below
(``import umap.umap_ as umap``) works around a broken umap install issue.
"""
from collections import Counter
import html
import os

import gensim
from numpy import log10
import pandas as pd
import plotly
import plotly.graph_objs as go
import umap.umap_ as umap
# --- Configuration -------------------------------------------------------
# word2vec model, stored in the (non-binary) text format
model = 'gutenberg_fairyfolk_model.txt'
# the whole corpus, concatenated into a single text file
corpus = 'all_fairy_books.txt'
# maximum number of words shown in the interactive plot
count_cutoff = 10000
# Word-counts file ("word<TAB>count" per line). Set this to None to have the
# script build one from the corpus with make_counts_file().
counts_file = corpus + '_counts.txt'
# HTML file the interactive plot is written to
output_html_filename = 'w2v-umap-fairy.html'
# Words to drop while counting (only used when the counts file is built by
# make_counts_file). Example:
#   ["", "of", "the", "in", "a", "an", "to", "with", "is", "was", "as", "for",
#    "that", "which", "and", "And", "have", "be", "from", "or", "are"]
stoplist = []
def make_counts_file(corpus, stopwords=None, cutoff=None, encoding="utf-8"):
    """Count word frequencies in *corpus* and write the most common to disk.

    The corpus is read line by line and each line is split on runs of
    whitespace. Words found in *stopwords* are skipped. The *cutoff* most
    common words are written, most frequent first, to ``corpus +
    "_counts.txt"`` as ``word<TAB>count`` lines.

    Args:
        corpus: Path to the plain-text corpus file.
        stopwords: Iterable of words to ignore. Defaults to the module-level
            ``stoplist``.
        cutoff: Number of words to keep. Defaults to the module-level
            ``count_cutoff``.
        encoding: Text encoding used to read the corpus and write the counts
            file. UTF-8 matches gensim's default when loading text models, so
            non-ASCII words can match the model vocabulary.

    Returns:
        The path of the counts file that was written.
    """
    if stopwords is None:
        stopwords = stoplist
    if cutoff is None:
        cutoff = count_cutoff
    # A set makes the per-word membership test O(1) instead of O(len(list)).
    stop = set(stopwords)

    wordcounter = Counter()
    with open(corpus, encoding=encoding) as handle:
        for line in handle:
            # split() with no argument collapses repeated whitespace and never
            # yields "" or tokens containing "\r"/"\t". Splitting on " " kept
            # such tokens, and a tab inside a token corrupted the
            # tab-separated counts file written below.
            wordcounter.update(word for word in line.split() if word not in stop)

    filename = corpus + "_counts.txt"
    with open(filename, "w", encoding=encoding) as handle:
        for word, count in wordcounter.most_common(cutoff):
            handle.write("{}\t{}\n".format(word, count))
    return filename
def read_counts_file(r, sep="\t", encoding="utf-8"):
    """Yield ``[word, count]`` string pairs from a counts file.

    Each non-blank line is expected to hold a word and its count separated by
    *sep* (the format written by ``make_counts_file``). The line ending is
    stripped, so the count field can be passed straight to ``int``.

    Args:
        r: Path to the counts file.
        sep: Field separator between the word and its count.
        encoding: Text encoding of the counts file.

    Yields:
        A two-element list ``[word, count]`` of strings per non-blank line.
    """
    # The with-block closes the file once the generator is exhausted or closed;
    # the previous bare open() left the handle to the garbage collector.
    with open(r, encoding=encoding) as handle:
        for line in handle:
            line = line.rstrip("\r\n")
            if not line:
                # Skip blank lines (e.g. a trailing empty line) that would
                # otherwise break the caller's (word, count) unpacking.
                continue
            # The count is always the last field, so split from the right once;
            # a separator inside the word then stays part of the word.
            yield line.rsplit(sep, 1)
def build_tooltip(row):
    """Return the HTML hover text for one word of the embedding plot.

    Args:
        row: Mapping (a DataFrame row when used via ``DataFrame.apply(...,
            axis=1)``) with ``'word'``, ``'count'`` and ``'log_count'`` keys.

    Returns:
        A string such as ``'<b>Word:</b> cat<br><b>Count:</b> 1,234<br>'
        '<b>Magnitude:</b> 3'``. The count gets thousands separators and the
        magnitude is ``log_count`` rounded to the nearest integer.
    """
    # Plotly interprets hover text as (a subset of) HTML, so a corpus token
    # such as "<b>" or "a&b" must be escaped to be shown literally.
    word = html.escape(str(row['word']), quote=False)
    parts = [
        '<b>Word:</b> ', word,
        '<br>',
        '<b>Count:</b> ', '{:,}'.format(row['count']),
        '<br>',
        '<b>Magnitude:</b> ', str(round(row['log_count'])),
    ]
    return ''.join(parts)
# Should be wrapped in a main() that takes command-line arguments; for now the
# pipeline runs at module level.
w2v_model = gensim.models.KeyedVectors.load_word2vec_format(model, binary=False)

# gensim 4.0 removed KeyedVectors.vocab (accessing it raises AttributeError)
# in favour of the key_to_index dict; support both API generations.
try:
    vocabulary = set(w2v_model.key_to_index)
except AttributeError:
    vocabulary = set(w2v_model.vocab)

# Build the counts file when none is configured or the configured path does
# not exist yet, so the script creates it "if needed" as the config promises.
if not counts_file or not os.path.exists(counts_file):
    counts_file = make_counts_file(corpus)

# Keep words the model has vectors for, in counts-file order, up to
# count_cutoff of them. Assumes the counts file lists the most frequent words
# first, as make_counts_file writes it.
relevant_words = [
    (word, count)
    for (word, count) in read_counts_file(counts_file)
    if word in vocabulary
][:count_cutoff]
if not relevant_words:
    raise SystemExit(
        "No word in {} has a vector in {}; nothing to plot.".format(counts_file, model)
    )

# Stack the selected words' vectors (one row per word) for the UMAP reducer.
model_reduced = w2v_model[[w[0] for w in relevant_words]]
# Project the word vectors down to 2-D. Cosine distance is used as the UMAP
# metric, and the fixed random_state keeps the layout the same between runs.
reducer = umap.UMAP(
    metric='cosine',
    n_neighbors=15,
    min_dist=0.05,
    random_state=42,
)
embedding = reducer.fit_transform(model_reduced)
# One row per plotted word: its 2-D coordinates (c1, c2), the word itself,
# its corpus count, log10 of that count (drives the marker colour) and the
# HTML tooltip text.
d = (
    pd.DataFrame(embedding, columns=['c1', 'c2'])
    .assign(word=[pair[0] for pair in relevant_words])
    .assign(count=[int(pair[1]) for pair in relevant_words])
    .assign(log_count=lambda frame: frame['count'].apply(log10))
)
d['tooltip'] = d.apply(build_tooltip, axis=1)
# WebGL-rendered scatter (Scattergl) of the 2-D embedding. Marker colour
# follows log_count on the Viridis scale, and hovering shows the tooltip.
trace = go.Scattergl(
    x=d['c1'],
    y=d['c2'],
    name='Embedding',
    mode='markers',
    marker={
        'color': d['log_count'],
        'colorscale': 'Viridis',
        'size': 6,
        'line': {'width': 0.5},
        'opacity': 0.75,
    },
    text=d['tooltip'],
)
# Title names the corpus. Both axes hide their zero line (UMAP coordinates
# have no intrinsic meaning), and hover picks the closest point.
layout = {
    'title': "Word2Vec 2D UMAP Embeddings for " + corpus,
    'yaxis': {'zeroline': False},
    'xaxis': {'zeroline': False},
    'hovermode': 'closest',
}
fig = go.Figure(data=[trace], layout=layout)

# Saves the plot to output_html_filename and opens it in a browser.
chart = plotly.offline.plot(fig, filename=output_html_filename)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment