Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active December 6, 2018 12:54
Show Gist options
  • Save vadimkantorov/1ff9ceb82c0d06365f0a21d44e870d98 to your computer and use it in GitHub Desktop.
Save vadimkantorov/1ff9ceb82c0d06365f0a21d44e870d98 to your computer and use it in GitHub Desktop.
Example of interactive SVG scatter plot (with image thumbnails) produced by running t-SNE on MNIST
import base64
import random
import cv2
import torch
import torchvision
def _encode_thumbnail(img):
    """Return *img* JPEG-compressed and base64-encoded as an ASCII ``str``.

    ``img`` is treated as a channel-first tensor: it is scaled by 255,
    converted to 8-bit and permuted to channel-last before encoding.
    # NOTE(review): assumes values are in [0, 1] -- confirm against the caller.
    """
    # Image encoders expect 8-bit pixels in a contiguous buffer. The
    # [..., ::-1] channel flip produces a negatively-strided view, so copy it.
    pixels = img.mul(255).round().clamp(0, 255).byte().permute(1, 2, 0).numpy()[..., ::-1].copy()
    encoded = cv2.imencode('.jpg', pixels)[1]
    # b64encode returns bytes; formatting bytes into a str would embed the
    # b'...' repr in the data URI and the preview image could never load.
    return base64.b64encode(encoded.tobytes()).decode('ascii')


def svg(points, labels, thumbnails = None, legend_size = 1e-1, legend_font_size = 5e-2, circle_radius = 5e-3):
    """Render a scatter plot as a self-contained, interactive SVG string.

    Every point becomes a ``<circle>`` coloured by its label. Hovering a
    circle copies its thumbnail into the ``preview`` image and its label into
    the ``label`` text element.

    Args:
        points: 2-D torch tensor; columns 0 and 1 are used as x and y. Each
            column is min-max normalised so the plot fills the unit viewBox.
        labels: one hashable, sortable label per point; distinct labels get
            evenly spaced hues.
        thumbnails: one image tensor per point, or ``None`` to emit circles
            without a preview image.
        legend_size: side of the preview image, in viewBox units.
        legend_font_size: font size of the label text, in viewBox units.
        circle_radius: radius of every circle, in viewBox units.

    Returns:
        The SVG document as a ``str``.
    """
    import html

    if len(points) > 0:
        low = points.min(0)[0]
        span = points.max(0)[0] - low
        # A column whose values are all equal has zero span; dividing by it
        # would turn every coordinate in that column into NaN.
        span[span == 0] = 1
        normalized = (points - low) / span
        # Plain Python floats format predictably inside the SVG attributes.
        xs = normalized[:, 0].tolist()
        ys = normalized[:, 1].tolist()
    else:
        xs, ys = [], []

    class_index = sorted(set(labels))
    hue_of = {label: 360.0 * i / len(class_index) for i, label in enumerate(class_index)}
    colors = [hue_of[label] for label in labels]

    if thumbnails is None:
        thumbnails_base64 = [''] * len(labels)
    else:
        thumbnails_base64 = [_encode_thumbnail(img) for img in thumbnails]

    circle = '''<circle cx="{}" cy="{}" title="{}" fill="hsl({}, 50%, 50%)" r="{}" desc="data:image/jpeg;base64,{}" onmouseover="evt.target.ownerDocument.getElementById('preview').setAttribute('href', evt.target.getAttribute('desc')); evt.target.ownerDocument.getElementById('label').textContent = evt.target.getAttribute('title');" />'''
    circles = ''.join(
        # Labels land inside a double-quoted attribute, so escape them.
        circle.format(x, y, html.escape('{}'.format(label)), hue, circle_radius, thumbnail)
        for x, y, label, hue, thumbnail in zip(xs, ys, labels, colors, thumbnails_base64)
    )

    return '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1">' + circles + \
        '''<image id="preview" x="0" y="{legend_size}" width="{legend_size}" height="{legend_size}" />
<text id="label" x="0" y="{legend_size}" font-size="{legend_font_size}" />
</svg>'''.format(legend_size = legend_size, legend_font_size = legend_font_size)
#def scatter(points, labels, thumbnails = None, legend_size = 1e-1, legend_font_size = 5e-2, circle_radius = 5e-3, size = 512):
# points = (points - points.min(0)[0]) / (points.max(0)[0] - points.min(0)[0])
# class_index = sorted(set(labels))
# class_colors = [360.0 * i / len(class_index) for i in range(len(class_index))]
# colors = [class_colors[class_index.index(label)] for label in labels]
# thumbnails_base64 = [base64.b64encode(cv2.imencode('.jpg', np.asarray(img)[..., ::-1])[1]).decode('ascii') for img in thumbnails] if thumbnails is not None else [''] * len(points)
# return '<svg viewBox="0 0 1 1" width="{0}" height="{0}">'.format(size) + \
# ''.join(map('''<circle cx="{}" cy="{}" title="{}" fill="hsl({}, 50%, 50%)" r="{}" desc="data:image/jpeg;base64,{}" onmouseover="evt.target.ownerDocument.getElementById('preview').setAttribute('href', evt.target.getAttribute('desc')); evt.target.ownerDocument.getElementById('label').textContent = evt.target.getAttribute('title');" />'''.format, points[:, 0], points[:, 1], labels, colors, [circle_radius] * len(points), thumbnails_base64)) + \
# '''<image id="preview" x="0" y="{legend_size}" width="{legend_size}" height="{legend_size}" />
# <text id="label" x="0" y="{legend_size}" font-size="{legend_font_size}" />
# </svg>'''.format(legend_size = legend_size, legend_font_size = legend_font_size)
def tsne(features):
    """Embed *features* with scikit-learn's t-SNE and return a torch tensor.

    Args:
        features: torch tensor with one row per sample.

    Returns:
        Tensor built from the array that ``TSNE().fit_transform`` returns
        (default scikit-learn settings are used).
    """
    # TODO: reimplement following https://github.com/cemoody/topicsne/blob/master/tsne.py and https://nlml.github.io/in-raw-numpy/in-raw-numpy-t-sne/
    import sklearn.manifold
    # detach()/cpu() make the conversion work for tensors that track
    # gradients or live on an accelerator; both are no-ops otherwise.
    return torch.from_numpy(sklearn.manifold.TSNE().fit_transform(features.detach().cpu().numpy()))
# Load the MNIST training split as (image tensor, label) pairs.
dataset = torchvision.datasets.MNIST('./MNIST', train = True, download = True, transform = torchvision.transforms.ToTensor())
thumbnails, labels = zip(*dataset)
# One flattened row per image.
features = torch.cat(thumbnails).view(len(thumbnails), -1)
# random.sample needs a sequence, and zip() is a lazy iterator on Python 3,
# so materialise the triples first. Cap k so a smaller dataset cannot raise.
samples = list(zip(features, labels, thumbnails))
features, labels, thumbnails = zip(*random.sample(samples, k = min(10000, len(samples))))
features = torch.stack(features)
embeddings = tsne(features)
# The context manager guarantees the file is flushed and closed.
with open('tsne.svg', 'w') as output:
    output.write(svg(embeddings, labels, thumbnails))
@sagarchaturvedi1
Copy link

sagarchaturvedi1 commented Aug 6, 2018

Hi, I tried running this code and the images are not displayed when the svg file is opened in the browser. I have attached a snapshot here.
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment