@UrusuLambda
Created August 14, 2020 10:59
Interactive 3D plot of MNIST data for IPython (Jupyter) notebooks
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.manifold import TSNE
from tensorflow.keras.datasets import mnist
# Load the MNIST test set (the training split is not needed here)
(_, _), (x_test, y_test) = mnist.load_data()
# Convert to float32 and scale pixel values from [0, 255] to [0, 1]
x_test = np.array(x_test, np.float32) / 255.
# Flatten each 28x28 image into a 784-dimensional vector
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1] * x_test.shape[2])
# Keep only the first 1000 samples, because t-SNE on all 10,000 test images is too slow
x_test = x_test[:1000]
y_test = y_test[:1000]
# Reduce from 784 to 3 dimensions with t-SNE so the data can be shown in a 3D plot
x_test_transformed_3 = TSNE(n_components=3).fit_transform(x_test)
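# Build a DataFrame of the 3D coordinates and draw an interactive scatter plot
df = pd.DataFrame(x_test_transformed_3, columns=list("XYZ"))
# Cast labels to str so plotly colors the 10 digits as discrete categories
df["label"] = y_test.astype(str)
fig = px.scatter_3d(df, x='X', y='Y', z='Z', color='label', opacity=0.7)
# size_max only has an effect when `size` is set, so set the marker size directly
fig.update_traces(marker=dict(size=4))
fig.show()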
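The script keeps only 1000 samples because t-SNE is slow on 784-dimensional input. The scikit-learn `TSNE` documentation recommends first reducing high-dimensional data to a reasonable size (for example 50 dimensions) with PCA. The sketch below applies that PCA-then-t-SNE pipeline. It is an assumed variant, not the author's method: the helper name `embed_3d`, its `n_pca=50` default, and the fixed `random_state` are all hypothetical choices for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE


def embed_3d(x, labels, n_pca=50, random_state=0):
    """Hypothetical helper: PCA down to n_pca dims, then t-SNE down to 3 dims.

    x: array of shape (n_samples, n_features), e.g. (1000, 784) scaled to [0, 1].
    labels: array of shape (n_samples,), e.g. the MNIST digit labels.
    n_pca: assumed default (50), not taken from the original gist.
    """
    x = np.asarray(x, dtype=np.float32)
    # PCA cannot keep more components than min(n_samples, n_features)
    n_pca = min(n_pca, x.shape[0], x.shape[1])
    x_reduced = PCA(n_components=n_pca, random_state=random_state).fit_transform(x)
    # Run t-SNE on the PCA output; a fixed random_state makes runs repeatable
    coords = TSNE(n_components=3, random_state=random_state).fit_transform(x_reduced)
    df = pd.DataFrame(coords, columns=list("XYZ"))
    # Store labels as strings so plotly.express colors them as discrete categories
    df["label"] = np.asarray(labels).astype(str)
    return df
```

With the `x_test` and `y_test` arrays prepared by the script above, `df = embed_3d(x_test, y_test)` yields a DataFrame that can be passed to `px.scatter_3d(df, x='X', y='Y', z='Z', color='label')` in the same way.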