Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save lidopypy/b53b2f9757357057d389c92945612e88 to your computer and use it in GitHub Desktop.
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
# Data preprocessing: load the Fashion-MNIST test split from CSV.
# Column 0 of each row is used as the class label (y_test); the remaining
# columns are pixel values (x_test), later reshaped to 28x28 thumbnails.
# NOTE: a raw string needs only single backslashes. The original
# r'C:\\Users\\...' kept the doubled backslashes literally in the path and
# only worked because Windows collapses repeated separators.
test_data = np.array(
    pd.read_csv(r'C:\Users\lido_lee\Downloads\fasion_mnist\fashion-mnist_test.csv'),
    dtype='float32',
)
# Number of test samples to show in the embedding projector.
embed_count = 1600
# Scale pixels by 1/255 -- assumes raw values are 0..255 (TODO confirm for this CSV).
x_test = test_data[:embed_count, 1:] / 255
y_test = test_data[:embed_count, 0]
# Output directory for the projector checkpoint, sprite and metadata;
# callback logs can also be placed here.
logdir = r'C:\Users\lido_lee\Downloads\fmnist_callbacks'
# Set up the summary writer and a variable holding the embedding vectors
# (one row per sample in x_test). Uses the TF1-style API (tf.Session,
# tf.contrib projector).
summary_writer = tf.summary.FileWriter(logdir)
embedding_var = tf.Variable(x_test, name='fmnist_embedding')
config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
# The label metadata and sprite image are written to these paths further
# down in the script.
embedding.metadata_path = os.path.join(logdir, 'metadata.tsv')
embedding.sprite.image_path = os.path.join(logdir, 'sprite.png')
# Size of a single thumbnail inside the sprite atlas.
embedding.sprite.single_image_dim.extend([28, 28])
projector.visualize_embeddings(summary_writer, config)
# Run a session to initialize the variable and save the checkpoint the
# projector reads the embedding values from.
with tf.Session() as sesh:
    sesh.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sesh, os.path.join(logdir, 'model.ckpt'))
# Release the writer's event file; it is not used after this point.
summary_writer.close()
# Create the sprite atlas and the metadata file referenced by the projector
# config (embedding.sprite.image_path / embedding.metadata_path).
rows = 28  # thumbnail height in pixels
cols = 28  # thumbnail width in pixels
# Class names indexed by the integer label stored in y_test.
class_names = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boot']
num_images = x_test.shape[0]
# Smallest square grid that fits every image. Flooring sqrt (as before)
# dropped images and their metadata rows whenever the count was not a
# perfect square, so metadata.tsv no longer matched the embedding rows.
sprite_dim = int(np.ceil(np.sqrt(num_images)))
# Start all-ones (white) so any unused trailing cells stay blank.
sprite_image = np.ones((rows * sprite_dim, cols * sprite_dim))
labels = []
for index in range(num_images):
    # Row-major placement: grid row i, grid column j.
    i, j = divmod(index, sprite_dim)
    labels.append(class_names[int(y_test[index])])
    # Invert intensities (1 - x) so items render dark on a white background.
    sprite_image[
        i * rows: (i + 1) * rows,
        j * cols: (j + 1) * cols
    ] = 1 - x_test[index].reshape(rows, cols)
# One metadata row per embedding point, preceded by a header line.
with open(embedding.metadata_path, 'w', encoding='utf-8') as meta:
    meta.write('Index\tLabel\n')
    for idx, class_name in enumerate(labels):
        meta.write('{}\t{}\n'.format(idx, class_name))
plt.imsave(embedding.sprite.image_path, sprite_image, cmap='gray')
# Preview the generated sprite.
plt.imshow(sprite_image, cmap='gray')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment