Skip to content

Instantly share code, notes, and snippets.

@daa233
Last active May 3, 2019 07:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save daa233/3b9ce5d34709326418203183951704a7 to your computer and use it in GitHub Desktop.
Save daa233/3b9ce5d34709326418203183951704a7 to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
def draw_tsne(X, y, save_fig=None, show=False):
    """Project ``X`` to 2-D with t-SNE and scatter-plot it, one colour per label.

    Uses ``sklearn.manifold.TSNE`` with 2 components, ``random_state=0`` and
    ``verbose=1``.
    Reference: https://www.scipy-lectures.org/packages/scikit-learn/auto_examples/plot_tsne.html

    :param X: data to be projected; passed unchanged to ``TSNE.fit_transform``,
        one row per sample.
    :param y: data labels, one per row of ``X``. Any label values that
        ``np.unique`` can sort are accepted (they need not be ``0..k-1``).
    :param save_fig: if truthy, the path the figure is saved to.
    :param show: if True, call ``plt.show()`` after plotting.
    :return: None
    :raises ValueError: if ``X`` and ``y`` have different lengths.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )
    # An array (not a list) is required so that ``y == label`` below is an
    # element-wise boolean mask rather than a single ``False``.
    y = np.asarray(y)

    t_sne = TSNE(n_components=2, random_state=0, verbose=1)
    X_2d = t_sne.fit_transform(X)

    # Iterate over the labels actually present (sorted by np.unique) instead of
    # range(num_ids), so labels such as [1, 2, 5] are all plotted.
    labels = np.unique(y)
    num_ids = len(labels)

    plt.figure(figsize=(6, 5))
    # distinct color map, https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
    # 'hsv' is cyclic (red at both ends), so one extra colour is sampled and
    # left unused to keep the first and last classes distinguishable.
    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib removed.
    cmap = plt.get_cmap('hsv', num_ids + 1)
    for idx, label in enumerate(labels):
        mask = y == label
        # ``color=`` rather than ``c=``: an RGBA tuple passed as ``c`` is
        # value-mapped through a colormap when the class has exactly 4 points.
        plt.scatter(X_2d[mask, 0], X_2d[mask, 1], color=cmap(idx), s=3,
                    label=str(label))
    plt.legend()
    if save_fig:
        plt.savefig(save_fig)
    if show:
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment