Created
November 22, 2018 07:19
-
-
Save Mainvooid/c413353bcf57546524d891a4d629a130 to your computer and use it in GitHub Desktop.
嵌入可视化 (Embedding visualization) #TensorFlow #Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# Embedding visualization: export MNIST test images, their labels and a
# projector config so TensorBoard's Embedding Projector can display them.
import os

import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data

PATH_TO_MNIST_DATA = "MNIST_data"
LOG_DIR = "projector/data"
IMAGE_NUM = 300  # number of test images to embed

# File names of the projector assets. Both files live inside LOG_DIR and are
# referenced in the projector config by *name only*: TensorBoard resolves
# relative asset paths against the directory that contains
# projector_config.pbtxt (which is LOG_DIR), so a path built with
# os.path.join(LOG_DIR, ...) would be resolved to LOG_DIR/LOG_DIR/...
METADATA_NAME = "metadata.tsv"
SPRITE_NAME = "mnist_10k_sprite.png"

# Neither np.savetxt nor the checkpoint saver creates missing directories.
if not os.path.isdir(LOG_DIR):
    os.makedirs(LOG_DIR)

# Read the MNIST data, with labels as integers (not one-hot vectors).
mnist = input_data.read_data_sets(PATH_TO_MNIST_DATA, one_hot=False)

# Take the first IMAGE_NUM test images; shape: (n_observations, n_features).
plot_array = mnist.test.images[:IMAGE_NUM]

# Metadata: one integer label per line, in the same order as plot_array rows.
np.savetxt(os.path.join(LOG_DIR, METADATA_NAME),
           mnist.test.labels[:IMAGE_NUM], fmt='%d')

# Sprite image: this script does NOT download it. Fetch
# https://www.tensorflow.org/images/mnist_10k_sprite.png (a 100x100 grid of
# thumbnails) and save it as LOG_DIR/mnist_10k_sprite.png.
PATH_TO_SPRITE_IMAGE = os.path.join(LOG_DIR, SPRITE_NAME)
if not os.path.exists(PATH_TO_SPRITE_IMAGE):
    print("Warning: sprite image not found at %s; "
          "thumbnails will not be shown." % PATH_TO_SPRITE_IMAGE)

# 1) Create a 2-D variable that holds the embedding.
session = tf.InteractiveSession()
embedding_var = tf.Variable(plot_array, name='embedding')
tf.global_variables_initializer().run()

# 2) Save the embedding to a checkpoint in LOG_DIR. During training this
#    would normally happen periodically. Here it is saved only once, so
#    global_step is set to a fixed value.
saver = tf.train.Saver()
saver.save(session, os.path.join(LOG_DIR, "model.ckpt"), global_step=0)

# 3) Link the embedding with its metadata and sprite image. The config goes
#    into the same LOG_DIR as the checkpoint.
summary_writer = tf.summary.FileWriter(LOG_DIR)
config = projector.ProjectorConfig()
# Several embeddings could be added; only one is added here.
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
# Link the tensor to its metadata file (labels); the path is relative to LOG_DIR.
embedding.metadata_path = METADATA_NAME
# Link the tensor to its sprite (preview) image; the path is relative to LOG_DIR.
embedding.sprite.image_path = SPRITE_NAME
embedding.sprite.single_image_dim.extend([28, 28])
# Save the projector config so TensorBoard can read it.
projector.visualize_embeddings(summary_writer, config)

# Release the event-file writer and the session.
summary_writer.close()
session.close()

# To start TensorBoard, either run
#     tensorboard --logdir=.
# from inside LOG_DIR (the directory holding the .tsv file), or pass the
# directory's relative or absolute path:
#     tensorboard --logdir=projector/data
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment