t-SNE visualization code
# Dinesh Jayaraman
# Based on code by
# Authors: Fabian Pedregosa <fabian.pedregosa@inria.fr>
#          Olivier Grisel <olivier.grisel@ensta.org>
#          Mathieu Blondel <mathieu@mblondel.org>
#          Gael Varoquaux
# License: BSD 3 clause (C) INRIA 2011
"""t-SNE visualization of learned features on ModelNet views."""
print(__doc__)
from time import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import offsetbox
from matplotlib.patches import Rectangle
from sklearn import (manifold, datasets, decomposition, ensemble,
                     discriminant_analysis, random_projection)
import scipy.misc
import matplotlib.patches as mpatches
imresize = scipy.misc.imresize  # unused below; scipy.misc.imresize is removed in recent SciPy releases
from collections import namedtuple
import h5py
import ipdb  # only needed for the commented-out debugging breakpoint below
num_samples = 10000   # number of examples to keep for the embedding
n_neighbors = 30      # only used in the output filenames below
dsname = 'ModelNet10'
datafield = 'view'
labelfield = 'labs'
# dsname = 'ModelNet30'; datafield = 'data'; labelfield = 'label'
data_h5filename = '%s_1view_train_torchfeed.h5' % dsname
feat_h5filename = '3000112_-%s_trainfeatures.h5' % dsname
if dsname == 'ModelNet10':
    classnames = [
        'bathtub',
        'bed',
        'chair',
        'desk',
        'dresser',
        'monitor',
        'night_stand',
        'sofa',
        'table',
        'toilet'
    ]
else:
    classnames = []
# ds = datasets.load_digits(n_class=6)
ds = namedtuple('dataset', ['data', 'target', 'images'])  # used as a simple attribute container
h5file = h5py.File(data_h5filename, 'r')
ds.images = np.squeeze(h5file[datafield][:]).transpose(0, 2, 1)  # N x 32 x 32 view images
ds.target = h5file[labelfield][:].reshape(-1)  # N class labels
h5file.close()
h5file = h5py.File(feat_h5filename, 'r')
ds.data = h5file['feat_ip1'][:]  # N x 256 feature vectors
lab = h5file['cls_labs'][:].reshape(-1)  # N
assert(np.array_equal(lab, ds.target))  # feature and image files must be aligned
h5file.close()
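# Optional fallback (a sketch, not part of the original pipeline): if the two
# HDF5 files above are unavailable, the sklearn digits dataset hinted at in the
# commented-out load_digits() call above can stand in for the same fields, e.g.:
#   digits = datasets.load_digits(n_class=6)
#   ds.images = digits.images                 # N x 8 x 8 thumbnails
#   ds.target = digits.target                 # N integer labels
#   ds.data = digits.data                     # N x 64 raw pixel features
#   classnames = [str(c) for c in range(6)]   # digit names for the legend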
# subsampling the data
ds.data = ds.data[:num_samples]
ds.target = ds.target[:num_samples].astype(int)
num_cls = np.max(ds.target)+1
ds.images = ds.images[:num_samples]
X = ds.data
y = ds.target
n_samples, n_features = X.shape
# ipdb.set_trace()
#----------------------------------------------------------------------
# Scale and visualize the embedding vectors
def plot_embedding(xy, title=None, show_images=True, show_points=True):
    assert(show_images or show_points)
    # rescale the 2-D embedding to the unit square
    xy_min, xy_max = np.min(xy, 0), np.max(xy, 0)
    xy = (xy - xy_min) / (xy_max - xy_min)
    if show_points:
        fig = plt.figure(figsize=(11, 6))
    else:
        fig = plt.figure(figsize=(12, 6))
    ax = plt.subplot(111)
    if show_points:
        seen_cls = np.zeros(num_cls, dtype='bool')  # unused
        # plot each sample as its (1-indexed) class id, colored by class
        for i in range(xy.shape[0]):
            cls = ds.target[i]
            plt.text(xy[i, 0], xy[i, 1], str(cls + 1),
                     color=plt.cm.tab10(cls / 10.),
                     fontdict={'weight': 'bold', 'size': 9})
        # legend mapping class ids to class names
        patches = []
        for clsno in range(num_cls):
            patches.append(mpatches.Patch(
                color=plt.cm.tab10(clsno / 10.),  # same color mapping as the points above
                label='%d: %s' % (clsno + 1, classnames[clsno])))
        lgd = ax.legend(handles=patches, fontsize=16, bbox_to_anchor=(1.3, 1))
    if show_images and hasattr(offsetbox, 'AnnotationBbox'):
        # only print thumbnails with matplotlib > 1.0
        shown_images = np.array([[1., 1.]])  # just something big
        for i in range(ds.data.shape[0]):
            dist = np.sum((xy[i] - shown_images) ** 2, 1)
            if np.min(dist) < 2e-3:
                # don't show points that are too close
                continue
            shown_images = np.r_[shown_images, [xy[i]]]
            imagebox = offsetbox.AnnotationBbox(
                offsetbox.OffsetImage(np.pad(ds.images[i], 1, 'constant'),
                                      cmap=plt.cm.gray),
                xy[i], frameon=False)
            ax.add_artist(imagebox)
        lgd = None
    plt.xticks([]), plt.yticks([])
    if title is not None:
        plt.title(title)
    return [fig, lgd]
# print(fig.get_size_inches())
#----------------------------------------------------------------------
# Plot images of the ds
img_h = ds.images.shape[2]
pad = 4
n_img_per_row = min(20, int(np.floor(np.sqrt(n_samples))))
img = np.zeros(((img_h+pad) * n_img_per_row, (img_h+pad) * n_img_per_row))
for i in range(n_img_per_row):
    ix = (img_h + pad) * i + pad // 2
    for j in range(n_img_per_row):
        iy = (img_h + pad) * j + pad // 2
        img[ix:ix + img_h, iy:iy + img_h] = ds.images[i * n_img_per_row + j]
plt.imshow(img, cmap=plt.cm.gray)
plt.xticks([])
plt.yticks([])
plt.title('A selection from the dataset')
#----------------------------------------------------------------------
# t-SNE embedding of the ds dataset
print("Computing t-SNE embedding")
tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
t0 = time()
X_tsne = tsne.fit_transform(X)
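# Optional (a sketch with a hypothetical cache filename): t-SNE is slow for
# 10k samples, so the embedding could be cached and reloaded on later runs:
#   np.save('%s_tsne_cache.npy' % dsname, X_tsne)
#   X_tsne = np.load('%s_tsne_cache.npy' % dsname)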
print("plotting")
fig1, lgd1 = plot_embedding(X_tsne, title=None,
                            show_images=False, show_points=True)
fig2, _ = plot_embedding(X_tsne, title=None,
                         show_images=True, show_points=False)
if n_samples > 0:
    print("saving figures")
    plt.figure(fig1.number)
    plt.savefig('/home/dineshj/Documents/Drafts/oneshot/figs/%s_tsne_noims-%dsamples-%dnbrs-wide.pdf'
                % (dsname, num_samples, n_neighbors),
                bbox_extra_artists=(lgd1,), bbox_inches='tight')
    plt.figure(fig2.number)
    plt.savefig('/home/dineshj/Documents/Drafts/oneshot/figs/%s_tsne-%dsamples-%dnbrs-wide.pdf'
                % (dsname, num_samples, n_neighbors),
                bbox_inches='tight')
print("%.2f s" % (time() - t0))
plt.show()