@ogrisel
Created April 12, 2010 14:52
t-SNE wrapper to output SVG maps
*.pyc
mnist2500*
build/
pip-log.txt
text-documents/
"""Experimental script to semantically map text/PDF document using t-SNE
See: http://homepage.tudelft.nl/19j49/t-SNE.html
Related: http://github.com/turian/textSNE
You will need:
- lxml
- numpy
- nltk
- http://pypi.python.org/pypi/svg.charts/
(fix the setup.py to point to the correct readme file).
This file is release under the MIT license but the t-SNE code is for research
only usage (this the authors page for details).
"""
import tsne
from nltk.stem.porter import PorterStemmer
from nltk import word_tokenize
from nltk import sent_tokenize
from svg.charts.plot import Plot
import cssutils
import pkg_resources
import numpy as np
import os
from itertools import izip
from lxml import etree
def extract_text(pdf_folder, txt_folder):
    if not os.path.exists(txt_folder):
        os.makedirs(txt_folder)
    for pdf_filename in os.listdir(pdf_folder):
        basename, ext = os.path.splitext(pdf_filename)
        if ext.lower() != ".pdf":
            print "skipping", pdf_filename
            continue
        pdf_filepath = os.path.join(pdf_folder, pdf_filename)
        text_filepath = os.path.join(txt_folder, basename + ".txt")
        cmd = "pdftotext %s %s > /dev/null 2>&1" % (pdf_filepath, text_filepath)
        print cmd
        os.system(cmd)
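
# Example usage (hypothetical paths): convert a folder of PDFs into the plain
# text folder that the rest of the pipeline consumes. The pdftotext command
# line tool (poppler-utils) must be installed for this to do anything useful.
#
#   extract_text("/path/to/arxiv-pdf", "/path/to/arxiv-txt")
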
class HashingVectorizer(object):
    """Compute term frequency vectors using a hashed term space"""

    def __init__(self, dim=5000, stemmer=None, probes=3):
        self.dim = dim
        self.probes = probes
        self.stemmer = stemmer if stemmer is not None else PorterStemmer()

    def hash(self, term, probe=0):
        # hash the stemmed, lowercased term; the probe index shifts the hash
        # so that each term gets `probes` distinct slots in the vector
        h = hash(self.stemmer.stem(term.lower()))
        return abs(h + hash(probe * " ")) % self.dim

    def term_frequencies(self, files):
        """Tokenize documents, hash the terms and compute the term freqs"""
        if isinstance(files, basestring):
            folder = files
            files = os.listdir(folder)
            files.sort()
            files = [os.path.join(folder, f) for f in files]
        freqs = np.zeros((len(files), self.dim))
        for i, filepath in enumerate(files):
            print "analysing file %d/%d: %s" % (i + 1, len(files), filepath)
            sentences = sent_tokenize(file(filepath).read())
            for sentence in sentences:
                for term in word_tokenize(sentence):
                    # TODO: add support for co-occurrence tokens in a sentence
                    # window
                    for probe in xrange(self.probes):
                        freqs[i][self.hash(term, probe)] += 1.0
            freqs[i] /= freqs[i].sum()
        return freqs
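
# A minimal sketch of how the vectorizer is meant to be used (the folder path
# is hypothetical): each row of the returned array is the normalized, hashed
# bag-of-words profile of one text file, ready to be fed to tsne.tsne().
#
#   vectorizer = HashingVectorizer(dim=5000, probes=3)
#   freqs = vectorizer.term_frequencies("/path/to/arxiv-txt")
#   assert freqs.shape[1] == vectorizer.dim
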
class SemanticMap(Plot):
    """2D vector map of documents projected using the t-SNE algorithm"""

    draw_lines_between_points = False
    show_x_guidelines = False
    show_y_guidelines = False
    show_x_title = False
    show_y_title = False
    show_x_labels = False
    show_y_labels = False
    show_label_popup = True
    width = 1280
    height = 1280

    def __init__(self, positions, text_labels=None, categories=None, urls=None):
        self.x = positions[:, 0]
        self.y = positions[:, 1]
        self.positions = positions
        dx = self.x.max() - self.x.min()
        dy = self.y.max() - self.y.min()
        super(SemanticMap, self).__init__({
            'min_x_value': self.x.min() - 0.03 * dx,
            'min_y_value': self.y.min() - 0.03 * dy,
            'max_x_value': self.x.max(),
            'max_y_value': self.y.max(),
        })

        # split data by category
        if categories is None:
            categories = ["Default category"] * len(positions)
        if text_labels is None:
            text_labels = [None] * len(positions)
        if urls is None:
            urls = [None] * len(positions)
        categorized = {}
        self.point_labels = {}
        self.urls = {}
        for pos, category, label, url in izip(positions, categories,
                                              text_labels, urls):
            categorized.setdefault(category, []).append(pos[0])
            categorized.setdefault(category, []).append(pos[1])
            self.point_labels[tuple(pos)] = label
            self.urls[tuple(pos)] = url

        # add the data to the graph
        unique_categories = categorized.keys()
        unique_categories.sort()
        for category in unique_categories:
            self.add_data({'data': categorized[category], 'title': category})

    def draw_data_points(self, line, data_points, graph_points):
        if not self.show_data_points \
                and not self.show_data_values:
            return
        for ((dx, dy), (gx, gy)) in izip(data_points, graph_points):
            if self.show_data_points:
                etree.SubElement(self.graph, 'circle', {
                    'cx': str(gx),
                    'cy': str(gy),
                    'r': '3',
                    'class': 'dataPoint%(line)s' % vars()})
            if self.show_label_popup and self.point_labels[(dx, dy)]:
                self.add_popup(gx, gy, self.point_labels[(dx, dy)],
                               self.urls[(dx, dy)])

    def add_popup(self, x, y, label, url):
        "Adds pop-up point information to a graph."
        txt_width = len(label) * self.font_size * 0.6 + 10
        tx = x + [10, -10][int(x + txt_width > self.width)]
        anchor = ['start', 'end'][x + txt_width > self.width]
        style = 'fill: #000; text-anchor: %s;' % anchor
        id = 'label-%s-%s' % (x, y)
        t = etree.SubElement(self.foreground, 'text', {
            'x': str(tx),
            'y': str(y - self.font_size),
            'class': 'dataPointLabel',
            'visibility': 'hidden',
            'style': style,
            'id': id,
        })
        t.text = label

        # Note: prior to the etree conversion, this circle element was never
        # added to anything (now it's added to the foreground)
        visibility = "document.getElementById('%s').setAttribute('visibility', '%%s')" % id
        if url is not None:
            parent = etree.SubElement(self.foreground, 'a', {
                '{http://www.w3.org/1999/xlink}href': url,
            })
        else:
            parent = self.foreground
        t = etree.SubElement(parent, 'circle', {
            'cx': str(x),
            'cy': str(y),
            'r': '5',
            'style': 'opacity: 0;',
            'onmouseover': visibility % 'visible',
            'onmouseout': visibility % 'hidden',
        })

    def draw_graph(self):
        """Simple background without axis"""
        transform = 'translate (%s %s)' % (self.border_left, self.border_top)
        self.graph = etree.SubElement(self.root, 'g', transform=transform)
        etree.SubElement(self.graph, 'rect', {
            'x': '0',
            'y': '0',
            'width': str(self.graph_width),
            'height': str(self.graph_height),
            'class': 'graphBackground',
        })

    @staticmethod
    def load_resource_stylesheet(name, subs=dict()):
        css_stream = pkg_resources.resource_stream('doc_tsne', name)
        css_string = css_stream.read()
        css_string = css_string % subs
        sheet = cssutils.parseString(css_string)
        return sheet
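
# A minimal sketch of rendering a map on its own (random 2D positions stand
# in for a real t-SNE projection; burn() is the svg.charts call that returns
# the SVG markup as a string):
#
#   positions = np.random.randn(10, 2)
#   m = SemanticMap(positions, categories=["a"] * 5 + ["b"] * 5,
#                   text_labels=["doc %d" % i for i in range(10)])
#   file("/tmp/map.svg", "w").write(m.burn())
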
if __name__ == "__main__":
    folder = "/home/ogrisel/Desktop/arxiv-txt"
    files = os.listdir(folder)
    files.sort()
    filepaths = [os.path.join(folder, f) for f in files]
    labels = [f.rsplit("-", 1)[0] for f in files]
    int_labels = [abs(hash(l)) for l in labels]
    #freqs = HashingVectorizer().term_frequencies(filepaths)
    #projected = tsne.tsne(freqs, perplexity=20.0)
    #m = SemanticMap(projected, categories=labels)
    #file(r'/tmp/out.svg', 'w').write(m.burn())
    #plt.scatter(projected[:,0], projected[:,1], 20.0, int_labels)
/*
$Id: graph.css 81 2009-09-01 02:04:44Z jaraco $
Base styles for svg.charts.Graph
*/
.svgBackground{
fill:#ffffff;
}
.graphBackground{
fill:#f0f0f0;
}
/* graphs titles */
.mainTitle{
text-anchor: middle;
fill: #000000;
font-size: %(title_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.subTitle{
text-anchor: middle;
fill: #999999;
font-size: %(subtitle_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.axis{
stroke: #000000;
stroke-width: 1px;
}
.guideLines{
stroke: #666666;
stroke-width: 1px;
stroke-dasharray: 5,5;
}
.xAxisLabels{
text-anchor: middle;
fill: #000000;
font-size: %(x_label_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.yAxisLabels{
text-anchor: end;
fill: #000000;
font-size: %(y_label_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.xAxisTitle{
text-anchor: middle;
fill: #ff0000;
font-size: %(x_title_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.yAxisTitle{
fill: #ff0000;
text-anchor: middle;
font-size: %(y_title_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.dataPointLabel{
fill: #000000;
text-anchor:middle;
font-size: 10px;
font-family: "Arial", sans-serif;
font-weight: normal;
}
.staggerGuideLine{
fill: none;
stroke: #000000;
stroke-width: 0.5px;
}
.keyText{
fill: #000000;
text-anchor:start;
font-size: %(key_font_size)dpx;
font-family: "Arial", sans-serif;
font-weight: normal;
}
/*
$Id: plot.css 81 2009-09-01 02:04:44Z jaraco $
default line styles
*/
.line1{
fill: none;
stroke: #ff0000;
stroke-width: 1px;
}
.line2{
fill: none;
stroke: #0000ff;
stroke-width: 1px;
}
.line3{
fill: none;
stroke: #00ff00;
stroke-width: 1px;
}
.line4{
fill: none;
stroke: #ffcc00;
stroke-width: 1px;
}
.line5{
fill: none;
stroke: #00ccff;
stroke-width: 1px;
}
.line6{
fill: none;
stroke: #ff00ff;
stroke-width: 1px;
}
.line7{
fill: none;
stroke: #00ffff;
stroke-width: 1px;
}
.line8{
fill: none;
stroke: #ffff00;
stroke-width: 1px;
}
.line9{
fill: none;
stroke: #cc6666;
stroke-width: 1px;
}
.line10{
fill: none;
stroke: #663399;
stroke-width: 1px;
}
.line11{
fill: none;
stroke: #339900;
stroke-width: 1px;
}
.line12{
fill: none;
stroke: #9966FF;
stroke-width: 1px;
}
/* default fill styles */
.fill1{
fill: #cc0000;
fill-opacity: 0.2;
stroke: none;
}
.fill2{
fill: #0000cc;
fill-opacity: 0.2;
stroke: none;
}
.fill3{
fill: #00cc00;
fill-opacity: 0.2;
stroke: none;
}
.fill4{
fill: #ffcc00;
fill-opacity: 0.2;
stroke: none;
}
.fill5{
fill: #00ccff;
fill-opacity: 0.2;
stroke: none;
}
.fill6{
fill: #ff00ff;
fill-opacity: 0.2;
stroke: none;
}
.fill7{
fill: #00ffff;
fill-opacity: 0.2;
stroke: none;
}
.fill8{
fill: #ffff00;
fill-opacity: 0.2;
stroke: none;
}
.fill9{
fill: #cc6666;
fill-opacity: 0.2;
stroke: none;
}
.fill10{
fill: #663399;
fill-opacity: 0.2;
stroke: none;
}
.fill11{
fill: #339900;
fill-opacity: 0.2;
stroke: none;
}
.fill12{
fill: #9966FF;
fill-opacity: 0.2;
stroke: none;
}
/* default line styles */
.key1,.dataPoint1{
fill: #ff0000;
stroke: none;
stroke-width: 1px;
}
.key2,.dataPoint2{
fill: #0000ff;
stroke: none;
stroke-width: 1px;
}
.key3,.dataPoint3{
fill: #00ff00;
stroke: none;
stroke-width: 1px;
}
.key4,.dataPoint4{
fill: #ffcc00;
stroke: none;
stroke-width: 1px;
}
.key5,.dataPoint5{
fill: #00ccff;
stroke: none;
stroke-width: 1px;
}
.key6,.dataPoint6{
fill: #ff00ff;
stroke: none;
stroke-width: 1px;
}
.key7,.dataPoint7{
fill: #00ffff;
stroke: none;
stroke-width: 1px;
}
.key8,.dataPoint8{
fill: #ffff00;
stroke: none;
stroke-width: 1px;
}
.key9,.dataPoint9{
fill: #cc6666;
stroke: none;
stroke-width: 1px;
}
.key10,.dataPoint10{
fill: #663399;
stroke: none;
stroke-width: 1px;
}
.key11,.dataPoint11{
fill: #ff0000;
stroke: none;
stroke-width: 1px;
}
.key12,.dataPoint12{
fill: #0000ff;
stroke: none;
stroke-width: 1px;
}
.key13,.dataPoint13{
fill: #00ff00;
stroke: none;
stroke-width: 1px;
}
.key14,.dataPoint14{
fill: #ffcc00;
stroke: none;
stroke-width: 1px;
}
.key15,.dataPoint15{
fill: #00ccff;
stroke: none;
stroke-width: 1px;
}
.key16,.dataPoint16{
fill: #ff00ff;
stroke: none;
stroke-width: 1px;
}
.key17,.dataPoint17{
fill: #00ffff;
stroke: none;
stroke-width: 1px;
}
.key18,.dataPoint18{
fill: #ffff00;
stroke: none;
stroke-width: 1px;
}
.key19,.dataPoint19{
fill: #cc6666;
stroke: none;
stroke-width: 1px;
}
.key20,.dataPoint20{
fill: #663399;
stroke: none;
stroke-width: 1px;
}
.key21,.dataPoint21{
fill: #ff0000;
stroke: none;
stroke-width: 1px;
}
.key22,.dataPoint22{
fill: #0000ff;
stroke: none;
stroke-width: 1px;
}
.key23,.dataPoint23{
fill: #00ff00;
stroke: none;
stroke-width: 1px;
}
.key24,.dataPoint24{
fill: #ffcc00;
stroke: none;
stroke-width: 1px;
}
.key25,.dataPoint25{
fill: #00ccff;
stroke: none;
stroke-width: 1px;
}
.key26,.dataPoint26{
fill: #ff00ff;
stroke: none;
stroke-width: 1px;
}
.key27,.dataPoint27{
fill: #00ffff;
stroke: none;
stroke-width: 1px;
}
.key28,.dataPoint28{
fill: #ffff00;
stroke: none;
stroke-width: 1px;
}
.key29,.dataPoint29{
fill: #cc6666;
stroke: none;
stroke-width: 1px;
}
.key30,.dataPoint30{
fill: #663399;
stroke: none;
stroke-width: 1px;
}
.key31,.dataPoint31{
fill: #ff0000;
stroke: none;
stroke-width: 1px;
}
.key32,.dataPoint32{
fill: #0000ff;
stroke: none;
stroke-width: 1px;
}
.key33,.dataPoint33{
fill: #00ff00;
stroke: none;
stroke-width: 1px;
}
.key34,.dataPoint34{
fill: #ffcc00;
stroke: none;
stroke-width: 1px;
}
.key35,.dataPoint35{
fill: #00ccff;
stroke: none;
stroke-width: 1px;
}
.key36,.dataPoint36{
fill: #ff00ff;
stroke: none;
stroke-width: 1px;
}
.key37,.dataPoint37{
fill: #00ffff;
stroke: none;
stroke-width: 1px;
}
.key38,.dataPoint38{
fill: #ffff00;
stroke: none;
stroke-width: 1px;
}
.key39,.dataPoint39{
fill: #cc6666;
stroke: none;
stroke-width: 1px;
}
.key40,.dataPoint40{
fill: #663399;
stroke: none;
stroke-width: 1px;
}
.key41,.dataPoint41{
fill: #ff0000;
stroke: none;
stroke-width: 1px;
}
.key42,.dataPoint42{
fill: #0000ff;
stroke: none;
stroke-width: 1px;
}
.key43,.dataPoint43{
fill: #00ff00;
stroke: none;
stroke-width: 1px;
}
.key44,.dataPoint44{
fill: #ffcc00;
stroke: none;
stroke-width: 1px;
}
.key45,.dataPoint45{
fill: #00ccff;
stroke: none;
stroke-width: 1px;
}
.key46,.dataPoint46{
fill: #ff00ff;
stroke: none;
stroke-width: 1px;
}
.key47,.dataPoint47{
fill: #00ffff;
stroke: none;
stroke-width: 1px;
}
.key48,.dataPoint48{
fill: #ffff00;
stroke: none;
stroke-width: 1px;
}
.key49,.dataPoint49{
fill: #cc6666;
stroke: none;
stroke-width: 1px;
}
.constantLine{
color: navy;
stroke: navy;
stroke-width: 1px;
stroke-dasharray: 9,1,1;
}
#
# tsne.py
#
# Implementation of t-SNE in Python. The implementation was tested on Python 2.5.1, and it requires a working
# installation of NumPy. The implementation comes with an example on the MNIST dataset. In order to plot the
# results of this example, a working installation of matplotlib is required.
# The example can be run by executing: ipython tsne.py -pylab
#
#
# Created by Laurens van der Maaten on 20-12-08.
# Copyright (c) 2008 Tilburg University. All rights reserved.
#
# Modified by Joseph Turian:
# * Use psyco if available.
# * Added parameter use_pca, with default False. NB this changes the default behavior.
# TODO:
# * Make tsne.pca == calc_tsne.PCA
# Modified by Olivier Grisel:
# * Make it possible to ctrl-C to early stop
# * Cosmetic fixes
#
import numpy as Math
import pylab as Plot
import sys
try:
    import psyco
    psyco.full()
    print >> sys.stderr, "psyco is usable!"
except ImportError:
    print >> sys.stderr, "No psyco"
def Hbeta(D = Math.array([]), beta = 1.0):
    """Compute the perplexity and the P-row for a specific value of the precision of a Gaussian distribution."""

    # Compute P-row and corresponding perplexity
    P = Math.exp(-D.copy() * beta)
    sumP = sum(P)
    H = Math.log(sumP) + beta * Math.sum(D * P) / sumP
    P = P / sumP
    return H, P
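
# Note on Hbeta: H is the Shannon entropy (in nats) of the row distribution
# P_j|i proportional to exp(-beta * D_j), so its perplexity is exp(H). x2p
# below tunes beta for each point until H matches log(perplexity), i.e. until
# the row reaches the requested perplexity.
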
def x2p(X = Math.array([]), tol = 1e-5, perplexity = 30.0):
    """Performs a binary search to get P-values in such a way that each conditional Gaussian has the same perplexity."""

    # Initialize some variables
    print "Computing pairwise distances..."
    (n, d) = X.shape
    sum_X = Math.sum(Math.square(X), 1)
    D = Math.add(Math.add(-2 * Math.dot(X, X.T), sum_X).T, sum_X)
    P = Math.zeros((n, n))
    beta = Math.ones((n, 1))
    logU = Math.log(perplexity)

    # Loop over all datapoints
    for i in range(n):

        # Print progress
        if i % 500 == 0:
            print "Computing P-values for point ", i, " of ", n, "..."

        # Compute the Gaussian kernel and entropy for the current precision
        betamin = -Math.inf
        betamax = Math.inf
        Di = D[i, Math.concatenate((Math.r_[0:i], Math.r_[i+1:n]))]
        (H, thisP) = Hbeta(Di, beta[i])

        # Evaluate whether the perplexity is within tolerance
        Hdiff = H - logU
        tries = 0
        while Math.abs(Hdiff) > tol and tries < 50:

            # If not, increase or decrease precision
            if Hdiff > 0:
                betamin = beta[i]
                if betamax == Math.inf or betamax == -Math.inf:
                    beta[i] = beta[i] * 2
                else:
                    beta[i] = (beta[i] + betamax) / 2
            else:
                betamax = beta[i]
                if betamin == Math.inf or betamin == -Math.inf:
                    beta[i] = beta[i] / 2
                else:
                    beta[i] = (beta[i] + betamin) / 2

            # Recompute the values
            (H, thisP) = Hbeta(Di, beta[i])
            Hdiff = H - logU
            tries = tries + 1

        # Set the final row of P
        P[i, Math.concatenate((Math.r_[0:i], Math.r_[i+1:n]))] = thisP

    # Return final P-matrix
    print "Mean value of sigma: ", Math.mean(Math.sqrt(1 / beta))
    return P
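
# A quick shape-level sanity check for x2p on synthetic data (values are
# random, the call itself is just illustrative): the result is an n x n
# matrix of conditional probabilities, one row per datapoint.
#
#   X = Math.random.randn(100, 10)
#   P = x2p(X, tol=1e-5, perplexity=15.0)
#   assert P.shape == (100, 100)
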
def pca(X = Math.array([]), no_dims = 50):
    """Runs PCA on the NxD array X in order to reduce its dimensionality to no_dims dimensions."""

    print "Preprocessing the data using PCA..."
    (n, d) = X.shape
    X = X - Math.tile(Math.mean(X, 0), (n, 1))
    (l, M) = Math.linalg.eig(Math.dot(X.T, X))
    Y = Math.dot(X, M[:,0:no_dims])
    return Y
def tsne(X = Math.array([]), no_dims = 2, initial_dims = 50, perplexity = 30.0, use_pca=False):
    """Runs t-SNE on the dataset in the NxD array X to reduce its dimensionality to no_dims dimensions.
    The syntax of the function is Y = tsne.tsne(X, no_dims, perplexity), where X is an NxD NumPy array."""

    # Check inputs
    if X.dtype != "float64":
        print "Error: array X should have type float64."
        return -1
    #if no_dims.__class__ != "<type 'int'>": # doesn't work yet!
    #    print "Error: number of dimensions should be an integer."
    #    return -1

    # Initialize variables
    if use_pca:
        X = pca(X, initial_dims)
    (n, d) = X.shape
    max_iter = 5000
    initial_momentum = 0.5
    final_momentum = 0.8
    eta = 500
    min_gain = 0.01
    Y = Math.random.randn(n, no_dims)
    dY = Math.zeros((n, no_dims))
    iY = Math.zeros((n, no_dims))
    gains = Math.ones((n, no_dims))

    # Compute P-values
    P = x2p(X, 1e-5, perplexity)
    P = P + Math.transpose(P)
    P = P / Math.sum(P)
    P = P * 4  # early exaggeration
    P = Math.maximum(P, 1e-12)

    try:
        # Run iterations
        for iter in range(max_iter):

            # Compute pairwise affinities
            sum_Y = Math.sum(Math.square(Y), 1)
            num = 1 / (1 + Math.add(Math.add(-2 * Math.dot(Y, Y.T), sum_Y).T, sum_Y))
            num[range(n), range(n)] = 0
            Q = num / Math.sum(num)
            Q = Math.maximum(Q, 1e-12)

            # Compute gradient
            PQ = P - Q
            for i in range(n):
                dY[i,:] = Math.sum(Math.tile(PQ[:,i] * num[:,i], (no_dims, 1)).T * (Y[i,:] - Y), 0)

            # Perform the update
            if iter < 20:
                momentum = initial_momentum
            else:
                momentum = final_momentum
            gains = (gains + 0.2) * ((dY > 0) != (iY > 0)) + (gains * 0.8) * ((dY > 0) == (iY > 0))
            gains[gains < min_gain] = min_gain
            iY = momentum * iY - eta * (gains * dY)
            Y = Y + iY
            Y = Y - Math.tile(Math.mean(Y, 0), (n, 1))

            # Compute current value of cost function
            if (iter + 1) % 100 == 0:
                C = Math.sum(P * Math.log(P / Q))
                print "Iteration ", (iter + 1), ": error is ", C

            # Stop lying about P-values
            if iter == 100:
                P = P / 4
    except KeyboardInterrupt:
        print >> sys.stderr, "early stopping by user"

    # Return solution
    return Y
if __name__ == "__main__":
    print "Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset."
    print "Running example on 2,500 MNIST digits..."
    X = Math.loadtxt("mnist2500_X.txt")
    labels = Math.loadtxt("mnist2500_labels.txt")
    Y = tsne(X, 2, 50, 20.0)
    Plot.scatter(Y[:,0], Y[:,1], 20, labels)
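    # Outside a pylab-enabled interactive session (the header above suggests
    # "ipython tsne.py -pylab"), the figure will likely only appear after an
    # explicit call such as:
    #Plot.show()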