Last active
February 21, 2019 13:05
-
-
Save MortisHuang/ad1b6f995a46570f29dd58c008bcf25f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Thu Feb 21 12:53:15 2019 | |
@author: Mortis | |
""" | |
import umap | |
import pandas as pd | |
import numpy as np | |
from sklearn.datasets import load_digits | |
# Load the 8x8 handwritten-digits dataset (flattened samples in .data,
# labels in .target, 2-D images in .images).
digits = load_digits()

# random_state pins UMAP's random seed so the layout is repeatable run to run.
reducer = umap.UMAP(random_state=42)

# Fit on the pixel vectors and get their 2-D embedding in a single call;
# equivalent to fit() followed by transform() on the same training data.
embedding = reducer.fit_transform(digits.data)
#%% | |
from io import BytesIO | |
from PIL import Image | |
import base64 | |
def embeddable_image(data, size=(16, 16)):
    """Render a 2-D intensity array as a base64-encoded PNG data URI.

    Each pixel value ``v`` is mapped to the gray level ``255 - 15 * v``
    (higher values become darker). The image is resized with bicubic
    resampling and the PNG bytes are returned as a ``data:`` URI, which
    can be used directly as an ``<img src=...>`` in HTML (e.g. a hover
    tooltip).

    Parameters
    ----------
    data : numpy.ndarray
        2-D array of pixel intensities. Values are truncated to integers.
        The 15x scale presumably targets inputs in 0..16 (as for sklearn's
        digits images) -- TODO confirm for other callers. Levels that fall
        outside 0..255 are clipped instead of wrapping around.
    size : tuple of int, optional
        ``(width, height)`` of the rendered thumbnail; defaults to
        ``(16, 16)``.

    Returns
    -------
    str
        A ``'data:image/png;base64,...'`` string.
    """
    # Do the arithmetic in a wide signed dtype and clip. The previous
    # ``15 * data.astype(np.uint8)`` was evaluated in uint8 and silently
    # wrapped for inputs >= 18. Results are identical for inputs 0..17.
    levels = np.asarray(data).astype(np.int64)
    img_data = np.clip(255 - 15 * levels, 0, 255).astype(np.uint8)
    # A 2-D uint8 array is loaded as an 8-bit grayscale ('L') image, so the
    # explicit (deprecated in recent Pillow) ``mode='L'`` is not needed.
    image = Image.fromarray(img_data).resize(size, Image.BICUBIC)
    with BytesIO() as buffer:
        image.save(buffer, format='png')
        for_encoding = buffer.getvalue()
    return 'data:image/png;base64,' + base64.b64encode(for_encoding).decode()
from bokeh.plotting import figure, show, output_file | |
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper | |
from bokeh.palettes import Spectral10 | |
# Write the interactive plot to a standalone HTML file.
output_file("umap.html", title="UMAP MNIST Example")

# One row per sample: the 2-D embedding coordinates, the digit label as a
# string (so it can serve as a categorical colour factor), and a thumbnail
# data URI that the hover tooltip renders via '@image'.
digits_df = pd.DataFrame(embedding, columns=('x', 'y'))
digits_df['digit'] = [str(x) for x in digits.target]
digits_df['image'] = list(map(embeddable_image, digits.images))
datasource = ColumnDataSource(digits_df)

# Colour each point by its label. Factors are built as str(9 - x), i.e. in
# reverse order of target_names (assumed 0..9 -- TODO confirm), which
# reverses the palette-to-digit assignment.
color_mapping = CategoricalColorMapper(
    factors=[str(9 - x) for x in digits.target_names],
    palette=Spectral10,
)

# ``plot_width``/``plot_height`` were removed in Bokeh 3.0; ``width`` and
# ``height`` are the supported names (and also accepted by Bokeh 2.x).
plot_figure = figure(
    title='UMAP projection of the Digits dataset',
    width=600,
    height=600,
    tools='pan, wheel_zoom, reset',
)

# Tooltip shows the digit thumbnail next to its label.
plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <img src='@image' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
    <div>
        <span style='font-size: 16px; color: #224499'>Digit:</span>
        <span style='font-size: 18px'>@digit</span>
    </div>
</div>
"""))

# ``circle(size=...)`` is deprecated in newer Bokeh; ``scatter`` draws the
# same circle markers (its default marker) with a pixel ``size``.
plot_figure.scatter(
    'x',
    'y',
    source=datasource,
    color=dict(field='digit', transform=color_mapping),
    line_alpha=0.6,
    fill_alpha=0.6,
    size=4,
)
show(plot_figure)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment