Skip to content

Instantly share code, notes, and snippets.

@MortisHuang
Last active February 21, 2019 13:05
Show Gist options
  • Save MortisHuang/ad1b6f995a46570f29dd58c008bcf25f to your computer and use it in GitHub Desktop.
Save MortisHuang/ad1b6f995a46570f29dd58c008bcf25f to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 21 12:53:15 2019
@author: Mortis
"""
import umap
import pandas as pd
import numpy as np
from sklearn.datasets import load_digits
# Load the digits dataset and compute a 2-D UMAP embedding of its feature
# vectors. A fixed random_state makes the layout reproducible between runs.
# (Earlier revisions also built an unused UMAP instance and an unused
# DataFrame here; both were overwritten before use and have been dropped.)
digits = load_digits()
reducer = umap.UMAP(random_state=42)
# fit_transform returns the embedding of the training data directly,
# instead of fitting and then re-projecting the same samples.
embedding = reducer.fit_transform(digits.data)
#%%
from io import BytesIO
from PIL import Image
import base64
def embeddable_image(data):
    """Encode a small grayscale image as a base64 PNG data URI.

    Pixel intensities are mapped with ``255 - 15 * value``, which inverts
    them so that high values render dark. The image is resized to 16x16
    with bicubic resampling and returned as a string that can be used
    directly as an HTML ``<img src=...>`` value.

    Parameters
    ----------
    data : array-like
        2-D array of pixel intensities. The ``15`` scale factor assumes
        values roughly in 0..17 (presumably the digits dataset's 0..16
        range — confirm). Out-of-range results are clipped to 0..255
        rather than wrapping around.

    Returns
    -------
    str
        A ``'data:image/png;base64,...'`` URI.
    """
    # Compute in float and clip before narrowing to uint8. Doing the math in
    # uint8 would silently wrap for inputs above 17 (e.g. 18 * 15 = 270).
    pixels = 255 - 15 * np.asarray(data, dtype=np.float64)
    img_data = np.clip(pixels, 0, 255).astype(np.uint8)
    # Pillow infers mode 'L' (8-bit grayscale) from a 2-D uint8 array.
    # Passing mode= explicitly is deprecated in recent Pillow releases.
    image = Image.fromarray(img_data).resize((16, 16), Image.BICUBIC)
    buffer = BytesIO()
    image.save(buffer, format='png')
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return 'data:image/png;base64,' + encoded
from bokeh.plotting import figure, show, output_file
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper
from bokeh.palettes import Spectral10
# NOTE(review): the title says "MINST" (sic), but the data comes from
# sklearn's load_digits, not MNIST. The string is kept as-is; confirm the
# intended wording.
output_file("umap.html", title="UMAP MINST Example")

# One row per sample, with three columns:
# - x, y: its UMAP coordinates;
# - digit: its label as a string, which the color mapper and the tooltip
#   both key on;
# - image: an inline PNG thumbnail shown in the hover tooltip.
digits_df = pd.DataFrame(embedding, columns=('x', 'y'))
digits_df['digit'] = [str(x) for x in digits.target]
digits_df['image'] = list(map(embeddable_image, digits.images))
datasource = ColumnDataSource(digits_df)

# Factors are generated as str(9 - x). Assuming target_names is 0..9, they
# run '9' down to '0', so Spectral10's first color is assigned to 9.
color_mapping = CategoricalColorMapper(
    factors=[str(9 - x) for x in digits.target_names],
    palette=Spectral10,
)

# width/height replace plot_width/plot_height, which were removed in
# Bokeh 3.0. Bokeh 2.x also accepts the short names.
plot_figure = figure(
    title='UMAP projection of the Digits dataset',
    width=600,
    height=600,
    tools='pan, wheel_zoom, reset',
)

# Hover tooltip: the sample's thumbnail next to its label.
plot_figure.add_tools(HoverTool(tooltips="""
<div>
<div>
<img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
</div>
<div>
<span style='font-size: 16px; color: #224499'>Digit:</span>
<span style='font-size: 18px'>@digit</span>
</div>
</div>
"""))

# scatter(marker='circle') replaces circle(..., size=...), which is
# deprecated as of Bokeh 3.4.
plot_figure.scatter(
    'x',
    'y',
    source=datasource,
    marker='circle',
    color=dict(field='digit', transform=color_mapping),
    line_alpha=0.6,
    fill_alpha=0.6,
    size=4,
)
show(plot_figure)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment