Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
A generic TensorBoard logger for scalars and histograms/distributions.
import tensorflow as tf
import numpy as np
class TensorBoardLogger(object):
    """
    Log scalars and histograms/distributions to TensorBoard.

    Every call to ``log`` writes its values at the current step
    (``self.episode``) and then advances that step by one.

    Usage:
    ```
    logger = TensorBoardLogger(log_dir='/tmp/test')
    for i in range(10):
        logger.log(
            logs=dict(
                float_test=np.random.random(),
                int_test=np.random.randint(0, 4),
            ),
            histograms=dict(
                actions=np.random.randint(0, 3, size=np.random.randint(5, 20))
            ),
        )
    ```
    Ref: https://github.com/fchollet/keras/blob/master/keras/callbacks.py
    Url: https://gist.github.com/wassname/b692f8e8686655011618dfbe8d8a9e3f
    """

    def __init__(self, log_dir, session=None):
        """
        Args:
            log_dir: directory the ``tf.summary.FileWriter`` writes event files to.
            session: ``tf.Session`` used to evaluate histogram summaries. Falls
                back to the default session, or a new ``tf.Session`` if none.
        """
        self.log_dir = log_dir
        self.writer = tf.summary.FileWriter(self.log_dir)
        # Step attached to every summary written by ``log``.
        self.episode = 0
        print('TensorBoardLogger started. Run `tensorboard --logdir={}` to visualize'.format(self.log_dir))
        # name -> histogram summary op, built lazily the first time a name is logged.
        self.histograms = {}
        # name -> placeholder that receives the raw values for that histogram.
        self.histogram_inputs = {}
        self.session = session or tf.get_default_session() or tf.Session()

    def log(self, logs=None, histograms=None):
        """
        Write one step of scalars and histograms, then advance the step.

        Args:
            logs: mapping of tag -> number. Each entry is written as its own
                scalar summary event.
            histograms: mapping of tag -> array-like of numbers. Each entry is
                written as a histogram summary; the length/shape may differ
                between calls.
        """
        # None sentinels avoid the shared-mutable-default pitfall.
        logs = {} if logs is None else logs
        histograms = {} if histograms is None else histograms

        # Scalars: build Summary protos directly; no graph ops are needed.
        for name, value in logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() accepts Python/numpy scalars and 0-d arrays alike.
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.writer.add_summary(summary, self.episode)

        # Histograms: create the graph ops once per name, then just feed values.
        for name, value in histograms.items():
            if name not in self.histograms:
                # Create the ops in the graph the session actually runs.
                with self.session.graph.as_default():
                    # Unknown shape, so inputs of any length can be fed; the
                    # session converts fed values to float64.
                    self.histogram_inputs[name] = tf.placeholder(tf.float64, shape=None)
                    self.histograms[name] = tf.summary.histogram(name, self.histogram_inputs[name])
            summary_str = self.session.run(
                self.histograms[name],
                feed_dict={self.histogram_inputs[name]: value},
            )
            self.writer.add_summary(summary_str, self.episode)

        self.writer.flush()
        self.episode += 1
# # Test
# Smoke test: write scalars and histograms, then read the event file back.
log_dir = '/tmp/test'
logger = TensorBoardLogger(log_dir)
# Histogram input logged at step 0; the read-back checks below compare against it.
first_actions = [1, 2, 3]
logger.log(
    logs=dict(
        float_test=1.1,
        int_test=1,
    ),
    histograms=dict(actions=first_actions),
)
for i in range(10):
    logger.log(
        logs=dict(
            float_test=np.random.random(),
            int_test=np.random.randint(0, 4),
        ),
        histograms=dict(actions=np.random.randint(0, 3, size=np.random.randint(5, 20))),
    )

# Make sure we can read the written messages
import glob, os
from tensorflow.python.summary import summary_iterator
event_paths = glob.glob(os.path.join(log_dir, "event*"))
assert event_paths, 'no event files written to {}'.format(log_dir)
# If the test runs multiple times in the same directory there can be more than
# one matching event file. glob() returns paths in arbitrary order, so pick the
# most recently modified file instead of trusting event_paths[-1].
event_reader = summary_iterator.summary_iterator(max(event_paths, key=os.path.getmtime))
# Skip over the version event.
next(event_reader)
events = list(event_reader)
events0 = [e for e in events if e.step == 0]
# Make sure the first messages have the value we expect. The logger writes one
# event per scalar (in the order of the `logs` dict), then the histogram.
event = events0[0]
assert event.step == 0
values = list(event.summary.value)
assert values[0].tag == 'float_test'
# simple_value is stored as float32, so compare approximately.
np.testing.assert_almost_equal(values[0].simple_value, 1.1, 1)
event = events0[1]
assert event.step == 0
values = list(event.summary.value)
assert values[0].tag == 'int_test'
np.testing.assert_almost_equal(values[0].simple_value, 1, 1)
event = events0[2]
assert event.step == 0
values = list(event.summary.value)
# The tag comes from the histogram's name scope: 'actions' in a fresh graph,
# possibly 'actions_<n>' if that scope was already used in the same graph.
assert values[0].tag.startswith('actions')
# histo.sum is the sum of the logged values, so derive it from the input.
np.testing.assert_almost_equal(values[0].histo.sum, float(np.sum(first_actions)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.