Skip to content

Instantly share code, notes, and snippets.

@lidopypy
Created February 18, 2019 05:35
Show Gist options
  • Save lidopypy/fce5fa45452c58fb99e3f47265625015 to your computer and use it in GitHub Desktop.
Save lidopypy/fce5fa45452c58fb99e3f47265625015 to your computer and use it in GitHub Desktop.
import os
from keras.datasets import mnist
import matplotlib.pyplot as plt
from keras import backend as K
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
# Training hyper-parameters; not referenced by the visible part of this script.
batch_size = 128
num_classes = 10
epochs = 12
# Output folder: the projector config, checkpoint, metadata and sprite files
# produced further down are all written here (originally described as the
# folder that receives logs saved by callbacks).
log_dir = 'C:\\Users\\Lido_Lee\\Downloads\\mnist_callbacks'
# input image dimensions
img_rows, img_cols = 28, 28
# Load the MNIST data directly from the Keras datasets module.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Insert the single channel axis where the Keras backend's image data format
# expects it.
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
# Preprocessing: cast to float32. Note that only x_test is rescaled (divided
# by 255 below); x_train keeps its raw pixel values in this script.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# Flatten every test image into one row vector. The sample count is taken
# from the array itself instead of hard-coding the MNIST test-set size.
x_test = x_test.reshape((x_test.shape[0], img_rows * img_cols))
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# Keep only the first `embed_count` test samples (and their labels) for the
# embedding, dividing the pixel values by 255.
embed_count = 1600
x_test = x_test[:embed_count] / 255
y_test = y_test[:embed_count]
# Set up the summary writer and the embedding tensor.
summary_writer = tf.summary.FileWriter(log_dir)
# The flattened, rescaled test images themselves are stored as the embedding.
embedding_var = tf.Variable(x_test, name='mnist_embedding')
config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
# Paths of the label metadata and sprite image files; both files are created
# later in this script.
embedding.metadata_path = os.path.join(log_dir, 'metadata.tsv')
embedding.sprite.image_path = os.path.join(log_dir, 'sprite.png')
# Size in pixels of a single thumbnail inside the sprite image.
embedding.sprite.single_image_dim.extend([28, 28])
projector.visualize_embeddings(summary_writer, config)
# The writer is not used again after the projector config has been emitted,
# so close it here instead of leaking its event file until interpreter exit.
summary_writer.close()
# Run a session to initialise the variable and save it as a model checkpoint.
with tf.Session() as sesh:
    sesh.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sesh, os.path.join(log_dir, 'model.ckpt'))
# Create the sprite image and the metadata file.
rows = 28
cols = 28
# Lookup table from class id to the label text written into the metadata.
label = ['0', '1', '2', '3', '4',
         '5', '6', '7', '8', '9']
num_images = x_test.shape[0]
# Round the grid side UP so that every embedded sample gets a sprite cell and
# a metadata row even when the sample count is not a perfect square (flooring
# would silently drop the trailing samples and leave the metadata shorter than
# the embedding tensor). Unused cells keep the 1.0 fill value.
sprite_dim = int(np.ceil(np.sqrt(num_images)))
sprite_image = np.ones((rows * sprite_dim, cols * sprite_dim))
# One label per embedded sample, in the same order as the rows of x_test.
labels = [label[int(class_id)] for class_id in y_test[:num_images]]
for index in range(num_images):
    # Thumbnails are laid out row by row: sample `index` lands in grid cell
    # (index // sprite_dim, index % sprite_dim).
    i, j = divmod(index, sprite_dim)
    # `* -1 + 1` inverts the intensities (value v becomes 1 - v).
    sprite_image[
        i * rows: (i + 1) * rows,
        j * cols: (j + 1) * cols
    ] = x_test[index].reshape(rows, cols) * -1 + 1
# Tab-separated metadata: a header line, then one "index<TAB>label" line per
# sample. Distinct loop names keep `label` (the lookup list) intact.
with open(embedding.metadata_path, 'w') as meta:
    meta.write('Index\tLabel\n')
    for row_index, row_label in enumerate(labels):
        meta.write('{}\t{}\n'.format(row_index, row_label))
# Save the sprite to the path registered in the projector config, then show it.
plt.imsave(embedding.sprite.image_path, sprite_image, cmap='gray')
plt.imshow(sprite_image, cmap='gray')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment