Skip to content

Instantly share code, notes, and snippets.

@wangg12
Forked from mrdrozdov/example.py
Created February 21, 2017 15:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wangg12/cafac6d5e6a757b7229277e60d0fa165 to your computer and use it in GitHub Desktop.
Save wangg12/cafac6d5e6a757b7229277e60d0fa165 to your computer and use it in GitHub Desktop.
Logging in Tensorflow
import os

from tf_logger import TFLogger
""" Example of using TFLogger to save train & dev statistics. To visualize
in tensorboard simply do:
tensorboard --logdir /path/to/summaries
This code does depend on TensorFlow, but does not require that your model
is built using TensorFlow. For instance, you could build a model in Chainer, then
log the loss and accuracy from your Chainer model using TFLogger.
"""
# One logger per run so TensorBoard shows train and eval as separate curves.
# `model`, `batch_iterator`, `eval_step` and `evaluate` are placeholders for
# objects supplied by your own training code.
train_tf_logger = TFLogger(os.path.join('.', 'summaries', 'train'))
eval_tf_logger = TFLogger(os.path.join('.', 'summaries', 'eval'))

for step, (x_batch, y_batch) in enumerate(batch_iterator):
    acc, loss = model.train(x_batch, y_batch)
    train_tf_logger.log(step=step, accuracy=acc, loss=loss)

    # Periodically record held-out metrics at the same global step.
    if step % eval_step == 0:
        acc, loss = evaluate(model)
        eval_tf_logger.log(step=step, accuracy=acc, loss=loss)
import tensorflow as tf
class TFLogger(object):
    """Write scalar ``loss`` and ``accuracy`` summaries for TensorBoard.

    The logger builds a tiny "empty model" consisting only of two scalar
    inputs and their summary ops. It lives in its own private graph and
    session, so it neither sees nor touches variables of any model that
    happens to be defined in the default TensorFlow graph.

    Args:
        summary_dir: Directory the event files are written to. Point
            ``tensorboard --logdir`` at it (or at a parent directory).

    NOTE: this uses the pre-1.0 TensorFlow names (``tf.scalar_summary``,
    ``tf.merge_summary``, ``tf.train.SummaryWriter``). In TensorFlow 1.x they
    were renamed to ``tf.summary.scalar``, ``tf.summary.merge`` and
    ``tf.summary.FileWriter``.
    """

    def __init__(self, summary_dir):
        super(TFLogger, self).__init__()
        self.summary_dir = summary_dir
        self.__initialize()

    def __initialize(self):
        """Build the private graph, session and summary writer."""
        # A dedicated graph keeps several loggers (and the user's own model)
        # from sharing state through the process-wide default graph.
        graph = tf.Graph()
        with graph.as_default():
            # Values are supplied on every `log` call, so plain placeholders
            # are enough; nothing has to be initialised or saved.
            loss = tf.placeholder(tf.float32, shape=[], name="loss")
            acc = tf.placeholder(tf.float32, shape=[], name="accuracy")

            loss_summary = tf.scalar_summary("loss", loss)
            acc_summary = tf.scalar_summary("accuracy", acc)
            summary_op = tf.merge_summary([loss_summary, acc_summary])

        sess = tf.Session(graph=graph)
        summary_writer = tf.train.SummaryWriter(self.summary_dir, graph)

        self.sess = sess
        self.summary_op = summary_op
        self.summary_writer = summary_writer
        self.loss = loss
        self.acc = acc

    def log(self, step, loss, accuracy):
        """Record one ``loss``/``accuracy`` pair at global step ``step``.

        Args:
            step: Step value the summary is tagged with (the x-axis in
                TensorBoard).
            loss: Scalar loss value.
            accuracy: Scalar accuracy value.
        """
        feed_dict = {
            self.loss: loss,
            self.acc: accuracy,
        }
        # Fetching a single op (not a list) returns the serialized summary
        # directly, so no indexing is needed.
        summaries = self.sess.run(self.summary_op, feed_dict)
        self.summary_writer.add_summary(summaries, step)

    def close(self):
        """Flush pending events to disk and release the writer and session."""
        self.summary_writer.close()
        self.sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment