Skip to content

Instantly share code, notes, and snippets.

@BassyKuo
Created May 5, 2018 18:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BassyKuo/44a61e9078f211d5fd9ad8ab0b0bee60 to your computer and use it in GitHub Desktop.
Save BassyKuo/44a61e9078f211d5fd9ad8ab0b0bee60 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import os

import numpy as np
import tensorflow as tf  # tf.VERSION == 1.4.0
from tensorflow.contrib.tensorboard.plugins import projector
# ... build the model here: `x`, `z`, `discriminator`, `save_folder`, `img_len`
# and `global_step` are used below but are defined elsewhere ...

# Run both inputs through the discriminator; its second output is used as the
# embedding that the projector visualizes.
dx, x_embedding = discriminator(x, is_training=True)
dz, z_embedding = discriminator(z, is_training=True)

# Stack the x embeddings on top of the z embeddings. The metadata file must
# list its rows in this same order (x rows first, then z rows).
embedding_tensor = tf.concat([x_embedding, z_embedding], axis=0)

# Variable holding a snapshot of the embeddings so it can be checkpointed and
# shown by the projector. Marked non-trainable so that optimizers or
# regularizers iterating over tf.trainable_variables() never update it.
embedding_var = tf.Variable(tf.zeros_like(embedding_tensor),
                            trainable=False, name='embedding')
# Running this op copies the current embeddings into `embedding_var`.
assign_embedding = tf.assign(embedding_var, embedding_tensor)
writer = tf.summary.FileWriter(save_folder)
# ===[ Create tf.summary
# ...

# ===[ Create projector
# Register each embedding variable with the TensorBoard projector.
vis_config = projector.ProjectorConfig()
for e in [embedding_var]:
    embedding = vis_config.embeddings.add()
    embedding.tensor_name = e.name
    # Link this tensor to its metadata file (e.g. labels) and its sprite
    # image; both file names are derived from the variable's op name.
    embedding.metadata_path = e.op.name + '.tsv'
    embedding.sprite.image_path = e.op.name + '.png'
    # Size of one thumbnail inside the sprite; assumes square images with a
    # side of `img_len` pixels -- TODO confirm.
    embedding.sprite.single_image_dim.extend([img_len, img_len])
# Save the projector configuration for the writer's log directory.
projector.visualize_embeddings(writer, vis_config)
# ===[ Create label tsv (metadata)
# One row per embedding row, in the same order as `embedding_tensor`:
# the x rows are labelled 1, followed by the z rows labelled 0.
n_x = x_embedding.shape.as_list()[0]
n_z = z_embedding.shape.as_list()[0]
if n_x is None or n_z is None:
    # The row counts must be known at graph-construction time; with a dynamic
    # batch dimension range(None) would fail with an unhelpful TypeError.
    raise ValueError('x_embedding and z_embedding need a static batch '
                     'dimension to write the projector metadata file')
labels = [1] * n_x + [0] * n_z
metadata_path = os.path.join(save_folder, embedding_var.op.name + '.tsv')
with open(metadata_path, 'w', encoding='utf-8') as f:
    # A header row is written because the file has more than one column.
    f.write("Index\tLabel\n")
    for idx, label in enumerate(labels):
        f.write("%d\t%d\n" % (idx, label))
# ===[ Session
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Only the embedding snapshot is checkpointed; the projector config above
# refers to this variable by name.
saver = tf.train.Saver([embedding_var])
# ... processing ...

# Assign the embeddings and fetch the matching inputs in ONE run. If `x`/`z`
# come from an input pipeline or a random op, a second sess.run would produce
# a different batch and the sprite would not match the stored embeddings.
_, x_value, z_value = sess.run([assign_embedding, x, z])
_embedding_map_to_img = np.concatenate([x_value, z_value], axis=0)
# `save_visualization` is defined elsewhere; presumably it tiles the images
# into the sprite grid the projector expects -- confirm.
sprite_path = os.path.join(save_folder, embedding_var.op.name + '.png')
save_visualization(_embedding_map_to_img, save_path=sprite_path)
# NOTE(review): `global_step` is not defined in this snippet; presumably it
# comes from the elided processing code.
saver.save(sess, save_path=os.path.join(save_folder, 'model.ckpt'),
           global_step=global_step)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment