Created
September 27, 2017 06:03
-
-
Save wassname/b692f8e8686655011618dfbe8d8a9e3f to your computer and use it in GitHub Desktop.
A generic TensorBoard logger for scalars and histograms/distributions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
class TensorBoardLogger(object):
    """
    Log scalars and histograms/distributions to TensorBoard.

    Each call to `log` writes its values at the current step (``self.episode``,
    starting at 0) and then advances the step by one.

    Usage:
    ```
    logger = TensorBoardLogger(log_dir='/tmp/test')
    for i in range(10):
        logger.log(
            logs=dict(
                float_test=np.random.random(),
                int_test=np.random.randint(0, 4),
            ),
            histograms=dict(
                actions=np.random.randint(0, 3, size=np.random.randint(5, 20))
            )
        )
    logger.close()
    ```

    The logger can also be used as a context manager, which calls `close` on exit.

    Ref: https://github.com/fchollet/keras/blob/master/keras/callbacks.py
    Url: https://gist.github.com/wassname/b692f8e8686655011618dfbe8d8a9e3f
    """

    def __init__(self, log_dir, session=None):
        """
        Args:
            log_dir: directory the event files are written to.
            session: session used to evaluate the histogram summary ops. If
                omitted, the current default session is used, or else a new
                ``tf.Session()`` is created; only a session created here is
                closed by `close`.
        """
        self.log_dir = log_dir
        self.writer = tf.summary.FileWriter(self.log_dir)
        self.episode = 0
        print('TensorBoardLogger started. Run `tensorboard --logdir={}` to visualize'.format(self.log_dir))
        # tag -> histogram summary op, built lazily the first time a tag is seen
        self.histograms = {}
        # tag -> placeholder that the histogram values are fed into
        self.histogram_inputs = {}
        if session is None:
            session = tf.get_default_session()
        # Remember whether we created the session, so close() never closes a
        # session that belongs to the caller.
        self._owns_session = session is None
        self.session = session if session is not None else tf.Session()

    def log(self, logs=None, histograms=None):
        """
        Write scalars and histograms at the current step, then advance the step.

        Args:
            logs: mapping of tag -> scalar value (anything ``float()`` accepts).
            histograms: mapping of tag -> array-like of numbers (any length).
        """
        # None defaults avoid the shared-mutable-default pitfall.
        logs = {} if logs is None else logs
        histograms = {} if histograms is None else histograms

        # Scalars: one Summary (hence one event) per tag.
        for name, value in logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() accepts numpy scalars / 0-d arrays as well as Python numbers.
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.writer.add_summary(summary, self.episode)

        # Histograms: the summary op for a tag is built once and then reused.
        for name, value in histograms.items():
            if name not in self.histograms:
                # A float64 placeholder with unknown shape accepts arrays of any
                # length and any numeric type. Unlike a tf.Variable initialised
                # from the first value, the dtype is not fixed by the first call
                # (so later float data is not truncated to int), and no
                # uninitialised variable is added to the graph.
                input_tensor = tf.placeholder(tf.float64, shape=None)
                self.histogram_inputs[name] = input_tensor
                self.histograms[name] = tf.summary.histogram(name, input_tensor)
            input_tensor = self.histogram_inputs[name]
            summary_str = self.session.run(
                self.histograms[name], feed_dict={input_tensor: value})
            self.writer.add_summary(summary_str, self.episode)

        self.writer.flush()
        self.episode += 1

    def close(self):
        """Flush and close the event file, and the session if this logger created it."""
        self.writer.close()
        if self._owns_session:
            self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Do not suppress exceptions raised inside the with-block.
        return False
# # Test
# Smoke test: log one fixed step followed by ten random steps, then read the
# event file back and check the values written at step 0.
log_dir = '/tmp/test'
logger = TensorBoardLogger(log_dir)
logger.log(
    logs=dict(
        float_test=1.1,
        int_test=1,
    ),
    histograms=dict(actions=[1, 2, 3]),
)
for i in range(10):
    logger.log(
        logs=dict(
            float_test=np.random.random(),
            int_test=np.random.randint(0, 4),
        ),
        histograms=dict(actions=np.random.randint(0, 3, size=np.random.randint(5, 20))),
    )

# Make sure we can read the written messages
import glob
import os
from tensorflow.python.summary import summary_iterator

event_paths = glob.glob(os.path.join(log_dir, "event*"))
assert event_paths, "no event files found in {}".format(log_dir)
# If the test runs multiple times in the same directory there can be more
# than one matching event file. glob() returns them in arbitrary order, so
# pick the most recently modified one: the file this run's logger wrote.
event_path = max(event_paths, key=os.path.getmtime)
event_reader = summary_iterator.summary_iterator(event_path)
# Skip over the version event.
next(event_reader)
events = list(event_reader)
# Step 0 holds one event per scalar (in dict order), then the histogram event.
events0 = [e for e in events if e.step == 0]

# Make sure the first messages have the value we expect
event = events0[0]
assert event.step == 0
values = list(event.summary.value)
assert values[0].tag == 'float_test'
# simple_value is stored as float32, hence the approximate comparison.
np.testing.assert_almost_equal(values[0].simple_value, 1.1, 5)

event = events0[1]
assert event.step == 0
values = list(event.summary.value)
assert values[0].tag == 'int_test'
np.testing.assert_almost_equal(values[0].simple_value, 1, 5)

event = events0[2]
assert event.step == 0
values = list(event.summary.value)
# The tag is 'actions' in a fresh graph; TF may uniquify it ('actions_1', ...)
# if that name scope already exists, so only check the prefix.
assert values[0].tag.startswith('actions')
# The histogram logged at step 0 was [1, 2, 3].
assert values[0].histo.sum == 6.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment