Skip to content

Instantly share code, notes, and snippets.

@RustingSword
Created April 23, 2018 06:09
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RustingSword/f9d78163e96554c6708d96d0a513de25 to your computer and use it in GitHub Desktop.
Save RustingSword/f9d78163e96554c6708d96d0a513de25 to your computer and use it in GitHub Desktop.
Plot a t-SNE embedding of word vectors generated by fastText. Requires Multicore-TSNE (https://github.com/DmitryUlyanov/Multicore-TSNE).
#!/usr/bin/env python
#! coding: utf8
from __future__ import print_function

import argparse
import heapq
import io

import numpy as np
from bokeh.models import ColumnDataSource, LabelSet, HoverTool
from bokeh.plotting import figure, show, output_file
from MulticoreTSNE import MulticoreTSNE as TSNE
from numpy.linalg import norm
# Command-line interface; the arguments are parsed inside plot_embedding().
parser = argparse.ArgumentParser()
parser.add_argument(
    '-vec',
    dest='vectorfile',
    required=True,
    help='Path of model.vec generated by fasttext',
)
parser.add_argument(
    '-p',
    dest='perplexity',
    type=int,
    default=20,
    help='Perplexity, usually between 20 to 50',
)
parser.add_argument(
    '-topK',
    dest='topK',
    type=int,
    default=500,
    help='show topK words with largest embedding norm',
)
def _read_vec_dim(fin, vectorfile):
    """Consume the ``<num_words> <dim>`` header line of an open ``.vec`` file; return ``dim``."""
    try:
        _, dim = fin.readline().split()
        return int(dim)
    except ValueError:
        # Covers an empty file as well as a header that is not two fields / not an int.
        raise ValueError('{}: expected a "<num_words> <dim>" header line'.format(vectorfile))


def _parse_vec_line(line, dim):
    """Split one ``.vec`` row into ``(word, float vector)``; return None for a blank row."""
    line = line.rstrip()
    if not line:
        return None
    # Split on single spaces from the right (assumes fastText's space-separated layout):
    # this keeps a word that itself contains non-ASCII whitespace such as U+3000 intact,
    # whereas str.split() with no argument would break it into several fields.
    fields = line.rsplit(' ', dim)
    if len(fields) != dim + 1:
        raise ValueError('expected a word followed by {} values, got: {!r}'.format(dim, line[:80]))
    return fields[0], np.array(fields[1:], dtype=float)


def load_data(vectorfile, topK=500):
    """Load the ``topK`` word vectors with the largest L2 norm from a fastText ``.vec`` file.

    Args:
        vectorfile: path to a UTF-8 text file whose first line is a
            ``<num_words> <dim>`` header, followed by one ``word v1 ... v_dim`` row per word.
        topK: how many words to keep.  ``None`` keeps every word; a negative value
            drops that many of the smallest-norm words (the same as slicing ``[:topK]``).

    Returns:
        ``(labels, embeddings)``: the kept words in descending-norm order (ties keep
        file order) and a float array of shape ``(len(labels), dim)`` holding their vectors.

    Raises:
        ValueError: if the header or a row is malformed, or the file holds no vectors.
    """
    # A bounded min-heap keeps memory at O(topK * dim) rather than holding every row;
    # None / negative topK need the full ranking, so the heap is then unbounded.
    capacity = topK if topK is not None and topK >= 0 else float('inf')
    heap = []  # entries: (norm, -row_index, word, vector); heap[0] is the weakest kept row
    n_words = 0
    with io.open(vectorfile, encoding='utf-8') as fin:
        dim = _read_vec_dim(fin, vectorfile)
        for line in fin:
            row = _parse_vec_line(line, dim)
            if row is None:
                continue  # tolerate blank lines, e.g. a trailing empty line
            # -row_index makes (norm, -row_index) unique: equal norms rank earlier rows
            # higher (like a stable reverse sort), and the vectors are never compared.
            entry = (norm(row[1]), -n_words) + row
            n_words += 1
            if len(heap) < capacity:
                heapq.heappush(heap, entry)
            elif heap and entry > heap[0]:
                heapq.heapreplace(heap, entry)
    if not n_words:
        raise ValueError('{}: no word vectors found'.format(vectorfile))
    print('Loaded {} words, embedding dim = {}'.format(n_words, dim))
    ranked = sorted(heap, reverse=True)[:topK]
    labels = [entry[2] for entry in ranked]
    embeddings = np.array([entry[3] for entry in ranked], dtype=float)
    return labels, embeddings.reshape(len(ranked), dim)
def plot_embedding():
    """Parse the command line, project the top-norm word vectors with t-SNE and plot them.

    Reads ``-vec`` / ``-p`` / ``-topK`` through the module-level ``parser``, loads the
    vectors with ``load_data``, runs MulticoreTSNE on them and displays a Bokeh scatter
    plot (via ``bokeh.plotting.show``) in which every point is labelled with its word
    and also shows the word in a hover tooltip.
    """
    r = parser.parse_args()
    for arg in vars(r):
        print('%-10s =' % arg, getattr(r, arg))
    print('Loading data...')
    words, embeddings = load_data(r.vectorfile, r.topK)
    tsne = TSNE(n_jobs=4, perplexity=r.perplexity)
    Y = tsne.fit_transform(embeddings)  # only columns 0 and 1 are plotted below
    source = ColumnDataSource(data=dict(x=Y[:, 0], y=Y[:, 1], words=words))
    hover = HoverTool(tooltips=[("word", "@words")])
    TOOLS = ['pan', 'wheel_zoom', 'reset', hover]
    # `width`/`height` replace `plot_width`/`plot_height`, and LabelSet no longer takes
    # `render_mode` (its old default was 'canvas'): the old spellings were removed in Bokeh 3.0.
    p = figure(title='t-SNE embedding of top ' + str(len(Y)) + ' words',
               width=1200, height=720, tools=TOOLS)
    p.scatter(x='x', y='y', size=8, source=source)
    labels = LabelSet(x='x', y='y', text='words', level='glyph',
                      x_offset=5, y_offset=5, source=source)
    p.add_layout(labels)
    show(p)
if __name__ == '__main__':
    # Run the command-line tool only when executed as a script, not on import.
    plot_embedding()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment