-
-
Save MuLx10/080e5097662a54fc3e29b57bbe703ad8 to your computer and use it in GitHub Desktop.
A generic TensorBoard logger for scalars and histograms/distributions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
class TensorBoardLogger(object):
    """
    Log scalars and histograms/distributions to TensorBoard (TF1 graph-mode API).

    Usage:
    ```
    logger = TensorBoardLogger(log_dir='/tmp/test')
    for i in range(10):
        logger.log(
            logs=dict(
                float_test=np.random.random(),
                int_test=np.random.randint(0, 4),
            ),
            histograms=dict(
                actions=np.random.randint(0, 3, size=np.random.randint(5, 20))
            ),
        )
    logger.close()
    ```

    Ref: https://github.com/fchollet/keras/blob/master/keras/callbacks.py
    Url: https://gist.github.com/wassname/b692f8e8686655011618dfbe8d8a9e3f
    """

    def __init__(self, log_dir, session=None):
        """
        Create the event-file writer and pick a session for histogram ops.

        Args:
            log_dir: directory that ``tf.summary.FileWriter`` writes event files to.
            session: ``tf.Session`` used to evaluate histogram summary ops. When
                omitted, the default session is used; if there is none, a new
                ``tf.Session()`` is created and owned (closed by ``close()``).
        """
        self.log_dir = log_dir
        self.writer = tf.summary.FileWriter(self.log_dir)
        # Step counter used as the global step when `log` is not given one.
        self.episode = 0
        print('TensorBoardLogger started. Run `tensorboard --logdir={}` to visualize'.format(self.log_dir))
        # name -> histogram summary op, built lazily on first use of that name.
        self.histograms = {}
        # name -> placeholder fed with that histogram's values on each `log` call.
        self.histogram_inputs = {}
        if session is None:
            session = tf.get_default_session()
        # Remember whether we created the session, so `close()` only closes our own.
        self._owns_session = session is None
        if self._owns_session:
            session = tf.Session()
        self.session = session

    def log(self, logs=None, histograms=None, step=None):
        """
        Write scalars and histograms for one step, then advance the step counter.

        Args:
            logs: mapping of tag -> scalar number. Each scalar is written as its
                own event, in the mapping's iteration order.
            histograms: mapping of name -> array-like of numbers, written as a
                TensorBoard histogram/distribution summary.
            step: global step to record. Defaults to the internal counter
                ``self.episode``, which is incremented on every call either way.
        """
        # None defaults avoid the shared-mutable-default pitfall of `={}`.
        logs = {} if logs is None else logs
        histograms = {} if histograms is None else histograms
        if step is None:
            step = self.episode

        # Scalars: build the Summary proto directly; no graph ops are needed.
        for name, value in logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() also accepts numpy scalar types the proto field may reject.
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.writer.add_summary(summary, step)

        # Histograms: one placeholder + histogram op per name, created on first use.
        for name, value in histograms.items():
            if name not in self.histograms:
                # float64 with unknown shape accepts any length. Unlike a Variable
                # initialised from the first value, it does not lock the dtype to
                # that first call (an int first value would truncate later floats).
                self.histogram_inputs[name] = tf.placeholder(tf.float64, shape=None)
                self.histograms[name] = tf.summary.histogram(name, self.histogram_inputs[name])
            summary_str = self.session.run(
                self.histograms[name],
                feed_dict={self.histogram_inputs[name]: value},
            )
            self.writer.add_summary(summary_str, step)

        self.writer.flush()
        self.episode += 1

    def close(self):
        """Flush and close the event writer; close the session if this logger created it."""
        self.writer.close()
        if self._owns_session:
            self.session.close()
# # Test
# Smoke test: write a few steps, then read the event file back and check step 0.
import glob
import os

from tensorflow.python.summary import summary_iterator

log_dir = '/tmp/test'
logger = TensorBoardLogger(log_dir)
logger.log(
    logs=dict(
        float_test=1.1,
        int_test=1,
    ),
    histograms=dict(actions=[1, 2, 3]),
)
for _ in range(10):
    logger.log(
        logs=dict(
            float_test=np.random.random(),
            int_test=np.random.randint(0, 4),
        ),
        histograms=dict(actions=np.random.randint(0, 3, size=np.random.randint(5, 20))),
    )

# Make sure we can read the written messages.
event_paths = glob.glob(os.path.join(log_dir, "event*"))
assert event_paths, 'no event files found in {}'.format(log_dir)
# If the test runs multiple times in the same directory, there can be more than one
# matching event file. glob() returns them in arbitrary order, so pick the newest.
event_path = max(event_paths, key=os.path.getmtime)
event_reader = summary_iterator.summary_iterator(event_path)
# Skip over the version event.
next(event_reader)
events = list(event_reader)
events0 = [e for e in events if e.step == 0]

# Make sure the first messages have the values we expect.
event = events0[0]
assert event.step == 0
values = list(event.summary.value)
assert values[0].tag == 'float_test'
# Scalars are stored as float32, hence the 1-decimal tolerance.
np.testing.assert_almost_equal(values[0].simple_value, 1.1, 1)

event = events0[1]
assert event.step == 0
values = list(event.summary.value)
assert values[0].tag == 'int_test'
np.testing.assert_almost_equal(values[0].simple_value, 1, 1)

event = events0[2]
assert event.step == 0
values = list(event.summary.value)
# The tag may carry a uniquifying suffix (e.g. 'actions_1') if the name was taken in the graph.
assert values[0].tag.startswith('actions')
# The step-0 histogram was fed [1, 2, 3]: three samples summing to 6.
np.testing.assert_almost_equal(values[0].histo.sum, 6.0)
assert values[0].histo.num == 3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment