Skip to content

Instantly share code, notes, and snippets.

@techbless
Created October 11, 2023 06:28
Show Gist options
  • Save techbless/7dff78302ad36efc624e0b23289eb5e6 to your computer and use it in GitHub Desktop.
Save techbless/7dff78302ad36efc624e0b23289eb5e6 to your computer and use it in GitHub Desktop.
t-SNE MNIST code
import keras
from keras.datasets import mnist
from sklearn import preprocessing
import numpy as np
# Load MNIST; only the training split is used below.
(train_xs, train_ys), (test_xs, test_ys) = mnist.load_data()
# Each image is flattened into one feature vector of height * width pixels
# (assumes train_xs is (n_images, height, width), as the indexing below implies).
dim_x = train_xs.shape[1] * train_xs.shape[2]
dim_y = 10  # number of digit classes, used as the palette size when plotting
train_xs = train_xs.reshape(train_xs.shape[0], dim_x).astype(np.float32)
# MinMaxScaler rescales every pixel column independently to [0, 1],
# fitted on the training data.
scaler = preprocessing.MinMaxScaler().fit(train_xs)
train_xs = scaler.transform(train_xs)
print(train_xs.shape)
print(train_ys.shape)
# Draw 10,000 *distinct* training examples for t-SNE. Sampling without
# replacement avoids duplicate points, which np.random.randint would produce.
ridx = np.random.choice(train_xs.shape[0], size=10000, replace=False)
np_train_xs = train_xs[ridx, :]
np_train_ys = train_ys[ridx]
print(np_train_xs.shape)
print(np_train_ys.shape)
import sklearn
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
# Global plotting theme for every figure below:
#   - a dark background with grid lines;
#   - the "muted" colour cycle;
#   - notebook-sized text scaled up 1.5x, with thicker default lines.
sns.set_style(style='darkgrid')
sns.set_palette(palette='muted')
sns.set_context(
    context="notebook",
    font_scale=1.5,
    rc={"lines.linewidth": 2.5},
)
def draw_scatter(x, n_class, colors):
    """Display a 2-D scatter plot of an embedding, coloured by class label.

    Also displays a swatch of the colours in use (via ``sns.palplot``) as a
    separate figure.

    Args:
        x: array indexed as ``x[:, 0]`` / ``x[:, 1]``; only the first two
            columns are plotted (e.g. a 2-D t-SNE embedding).
        n_class: number of classes; the palette is built with this many
            colours (the current seaborn palette, cycled if necessary).
        colors: array of per-point class labels. Each label is cast to int
            and used as an index into the palette, so labels must lie in
            ``[0, n_class)``.
    """
    colour_list = sns.color_palette(n_colors=n_class)
    sns.palplot(colour_list)
    palette = np.array(colour_list)
    plt.figure(figsize=(14, 14))
    ax = plt.subplot(aspect='equal')
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    ax.scatter(x[:, 0], x[:, 1], lw=0, s=540,
               c=palette[colors.astype(int)], alpha=0.2)
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    # NOTE(review): axis('tight') re-enables autoscaling and resets the aspect
    # to 'auto'. It therefore appears to override the +/-25 limits and
    # aspect='equal' set above. Confirm which framing is intended.
    ax.axis('tight')
    plt.show()
# Embed the sampled subset into 2-D with t-SNE.
# The fixed seed makes the layout reproducible for a given subset.
tsne_model = TSNE(random_state=42)
tsne_train_xs = tsne_model.fit_transform(np_train_xs)
# Plot the embedding, coloured by digit label.
draw_scatter(x=tsne_train_xs, n_class=dim_y, colors=np_train_ys)
# Perplexity sweep: t-SNE is run nine times, so use a smaller subset of
# 1,000 distinct training examples (sampled without replacement).
ridx = np.random.choice(train_xs.shape[0], size=1000, replace=False)
np_train_xs = train_xs[ridx, :]
np_train_ys = train_ys[ridx]
print(np_train_xs.shape)
print(np_train_ys.shape)

# Show the palette; the colour at index k is used for digit label k.
sns.palplot(sns.color_palette())
palette = np.array(sns.color_palette())

# Lay the results out on a 3x3 grid, one perplexity value per panel.
# Recent scikit-learn versions require perplexity < n_samples (1000 here).
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(15, 15))
for ax, perplexity in zip(axs.flat, [2, 5, 10, 20, 30, 50, 75, 100, 750]):
    # Same fixed seed as the full-subset run above, for reproducible layouts.
    tsne_out = TSNE(
        n_components=2, perplexity=perplexity, random_state=42
    ).fit_transform(np_train_xs)
    ax.set_title('Perplexity = {}'.format(perplexity))
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    ax.scatter(tsne_out[:, 0], tsne_out[:, 1], lw=0, s=25,
               c=palette[np_train_ys.astype(int)], alpha=0.3)
    ax.axis('tight')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment