Skip to content

Instantly share code, notes, and snippets.

@arnicas
Created July 13, 2019 13:13
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save arnicas/e78a03ac16433664b40c8e8b3dfc23f7 to your computer and use it in GitHub Desktop.
Save arnicas/e78a03ac16433664b40c8e8b3dfc23f7 to your computer and use it in GitHub Desktop.
Create an interactive UMAP embedding display for a word2vec model with a simple script
"""
Script originally sourced from Peter Baumgartner
here: https://gist.github.com/pmbaumgartner/adb33aa486b77ab58eb3df265393195d
and then modified by Lynn Cherny to allow a corpus file,
any gensim w2v model file, and make or read a counts file before the
UMAP display.
The counts are used to focus on the most common words, and more
frequent words show as lighter colors in the UMAP display Peter made.
NOTE: Install the umap-learn package (pip install umap-learn), not umap; the `import umap.umap_ as umap` form used below works around a bad install/umap issue.
"""
from collections import Counter
import gensim
from numpy import log10
import pandas as pd
import plotly
import plotly.graph_objs as go
import umap.umap_ as umap
# --- Inputs ---------------------------------------------------------------
# Text corpus, concatenated into a single file.
corpus = "all_fairy_books.txt"
# word2vec model, stored in the plain-text word2vec format.
model = "gutenberg_fairyfolk_model.txt"
# Tab-separated word-counts file. Optional: when this is None (or empty),
# the counts file is generated from the corpus before plotting.
# counts_file = None
counts_file = "all_fairy_books.txt_counts.txt"

# --- Word selection -------------------------------------------------------
# Maximum number of words to display in the interactive plot.
count_cutoff = 10000
# Words to leave out when counting. Example list to start from:
# stoplist = ["", "of", "the", "in", "a", "an", "to", "with", "is", "was", "as", "for", "that", "which", "and", "And", "have", "be", "from", "or", "are"]
stoplist = []

# --- Output ---------------------------------------------------------------
# HTML file the interactive plot is written to.
output_html_filename = "w2v-umap-fairy.html"
def make_counts_file(corpus, stop_words=None, max_words=None):
    """Count word frequencies in a corpus file and save the top words.

    The corpus is read line by line and split on runs of whitespace, so
    repeated spaces or tabs never produce empty "words" and no word can
    contain the tab used as the output separator. The most frequent words
    are written, one per line, as ``word<TAB>count`` to
    ``corpus + "_counts.txt"``.

    Args:
        corpus: Path of the text file to count.
        stop_words: Optional iterable of words to leave out of the counts.
            Defaults to the module-level ``stoplist``.
        max_words: Optional maximum number of words to write. Defaults to
            the module-level ``count_cutoff``.

    Returns:
        The path of the counts file that was written.
    """
    # Build the exclusion set once: set membership is O(1) per token,
    # whereas a list would be scanned for every word in the corpus.
    excluded = set(stoplist if stop_words is None else stop_words)
    limit = count_cutoff if max_words is None else max_words

    wordcounter = Counter()
    with open(corpus) as handle:
        for line in handle:
            # split() with no argument splits on any whitespace and drops
            # empty strings, unlike split(" ").
            wordcounter.update(
                word for word in line.split() if word not in excluded
            )

    filename = corpus + "_counts.txt"
    with open(filename, "w") as handle:
        handle.writelines(
            word + "\t" + str(count) + "\n"
            for word, count in wordcounter.most_common(limit)
        )
    return filename
def read_counts_file(r, sep="\t"):
    """Yield the fields of each non-blank line of a delimited counts file.

    Args:
        r: Path of the counts file to read.
        sep: Field separator; a tab by default.

    Yields:
        A list with the line's fields, split on ``sep``. The line's
        trailing newline is removed, so the last field is clean.
    """
    # The context manager closes the file once the generator is exhausted
    # or closed, instead of leaving the handle to garbage collection.
    with open(r) as handle:
        for line in handle:
            line = line.rstrip("\n")
            if not line:
                # A blank line has no fields; skipping it keeps callers
                # that unpack (word, count) from failing on it.
                continue
            yield line.split(sep)
def build_tooltip(row):
    """Return the HTML hover text for one word.

    Args:
        row: Mapping-like object (for example a DataFrame row) providing
            the keys ``'word'``, ``'count'`` and ``'log_count'``.

    Returns:
        A string with the word, its count formatted with thousands
        separators, and ``log_count`` rounded to the nearest integer,
        separated by ``<br>`` tags.
    """
    word = row['word']
    count_text = format(row['count'], ',')
    magnitude_text = str(round(row['log_count']))
    pieces = (
        '<b>Word:</b> ', word,
        '<br>',
        '<b>Count:</b> ', count_text,
        '<br>',
        '<b>Magnitude:</b> ', magnitude_text,
    )
    return ''.join(pieces)
# Script body: load the vectors, keep the most frequent in-vocabulary words,
# project them to 2D with UMAP and write an interactive Plotly scatter plot.
# should wrap this in a main and pass in args, but:
w2v_model = gensim.models.KeyedVectors.load_word2vec_format(model, binary=False)

# gensim 4 replaced KeyedVectors.vocab with key_to_index and raises
# AttributeError on the old name; gensim 3 only has vocab. Support both.
try:
    vocabulary = set(w2v_model.key_to_index)
except AttributeError:
    vocabulary = set(w2v_model.vocab)

# Build the counts file from the corpus when none was configured.
if not counts_file:
    counts_file = make_counts_file(corpus)

# Keep (word, count) pairs, in file order, for words the model knows about,
# capped at count_cutoff entries.
relevant_words = [
    (word, count)
    for (word, count) in read_counts_file(counts_file)
    if word in vocabulary
][:count_cutoff]
if not relevant_words:
    # Fail with a readable message instead of an obscure error from the
    # vector lookup or from UMAP further down.
    raise ValueError(
        "No words from " + str(counts_file)
        + " were found in the vocabulary of " + str(model)
    )
words = [word for word, _ in relevant_words]

# One vector per selected word, reduced to two dimensions.
model_reduced = w2v_model[words]
reducer = umap.UMAP(metric='cosine', n_neighbors=15, min_dist=0.05, random_state=42)
embedding = reducer.fit_transform(model_reduced)

d = pd.DataFrame(embedding, columns=['c1', 'c2'])
d['word'] = words
# int() tolerates surrounding whitespace such as a trailing newline.
d['count'] = [int(count) for _, count in relevant_words]
# log10 of the count drives the marker color and the tooltip "Magnitude".
d['log_count'] = log10(d['count'])
d['tooltip'] = d.apply(build_tooltip, axis=1)

trace = go.Scattergl(
    x=d['c1'],
    y=d['c2'],
    name='Embedding',
    mode='markers',
    marker=dict(
        color=d['log_count'],
        colorscale='Viridis',
        size=6,
        line=dict(
            width=0.5,
        ),
        opacity=0.75,
    ),
    text=d['tooltip'],
)
layout = dict(
    title="Word2Vec 2D UMAP Embeddings for " + corpus,
    yaxis=dict(zeroline=False),
    xaxis=dict(zeroline=False),
    hovermode='closest',
)
fig = go.Figure(data=[trace], layout=layout)
# Will open in browser and save the file offline:
chart = plotly.offline.plot(fig, filename=output_html_filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment