Skip to content

Instantly share code, notes, and snippets.

@ay27
Created July 31, 2019 02:48
Show Gist options
  • Save ay27/d83a1e0e9aa2aca312dbaf08caf16a29 to your computer and use it in GitHub Desktop.
Save ay27/d83a1e0e9aa2aca312dbaf08caf16a29 to your computer and use it in GitHub Desktop.
[Fancy Logger] my logger
import collections
import datetime
import io
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

try:
    import scipy.misc
except ImportError:
    scipy = None

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x
def embedding_logger(tensor, save_path, meta_data=None):
    """Save ``tensor`` as an embedding checkpoint for the TensorBoard projector.

    Creates a ``tf.Variable`` initialised from ``tensor``. It then writes an
    event file and projector config into ``save_path`` and saves the variable
    to ``<save_path>/embedding.ckpt`` with global step 1.

    Args:
        tensor: initial value of the embedding variable (anything
            ``tf.Variable`` accepts).
        save_path: output directory; created if it does not exist.
        meta_data: optional iterable of per-row labels. Each item is
            converted with ``str`` and written on its own line of
            ``<save_path>/meta.csv``, which is then linked as the embedding's
            metadata. ``None`` or an empty iterable means no metadata file.
    """
    embedding_var = tf.Variable(tensor)
    os.makedirs(save_path, exist_ok=True)

    # Compare against None rather than relying on truthiness: a numpy array
    # of labels has an ambiguous truth value and would raise.
    labels = [] if meta_data is None else [str(w) for w in meta_data]
    meta_path = None
    if labels:
        meta_path = os.path.join(save_path, 'meta.csv')
        with open(meta_path, 'w') as f:
            f.writelines(label + '\n' for label in labels)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(save_path, sess.graph)
        try:
            sess.run(embedding_var.initializer)
            config = projector.ProjectorConfig()
            embedding = config.embeddings.add()
            embedding.tensor_name = embedding_var.name
            # Only link metadata that was actually written. Assigning None to
            # a proto string field is not a portable way to leave it unset
            # (it raises TypeError in some protobuf versions).
            if meta_path is not None:
                embedding.metadata_path = meta_path
            projector.visualize_embeddings(writer, config)
            saver_embed = tf.train.Saver([embedding_var])
            saver_embed.save(sess, os.path.join(save_path, 'embedding.ckpt'), 1)
        finally:
            # Close the event file even if initialisation or saving fails.
            writer.close()
class BoardLogger(object):
    """Write scalar, image and histogram summaries to a TensorBoard event file.

    Builds ``tf.Summary`` protos directly and hands them to a
    ``tf.summary.FileWriter``; no session is used.
    """

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images, each under the tag ``'<tag>/<index>'``.

        Every image is PNG-encoded with ``scipy.misc.toimage``. Its height and
        width are taken from ``img.shape[0]`` and ``img.shape[1]``. When scipy
        or ``scipy.misc.toimage`` is unavailable, a message is printed and
        nothing is logged.
        """
        if scipy is None:
            print('not scipy found, skip image_summary')
            return
        # scipy.misc.toimage was removed in SciPy 1.2. Skip instead of failing
        # with AttributeError on newer installs.
        toimage = getattr(getattr(scipy, 'misc', None), 'toimage', None)
        if toimage is None:
            print('scipy.misc.toimage not available, skip image_summary')
            return

        img_summaries = []
        for i, img in enumerate(images):
            # PNG data is binary, so encode into a bytes buffer.
            buf = io.BytesIO()
            toimage(img).save(buf, format="png")
            img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of ``values`` and flush the writer.

        ``values`` may be any array-like accepted by ``numpy.asarray``.
        ``bins`` is forwarded to ``numpy.histogram``.
        """
        values = np.asarray(values)
        counts, bin_edges = np.histogram(values, bins=bins)

        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(values.size)
        hist.sum = float(np.sum(values))
        # Square in float64 so integer inputs cannot overflow.
        hist.sum_squares = float(np.sum(np.square(values.astype(np.float64))))

        # bucket_limit holds each bucket's upper edge, so the left edge of the
        # first bin is dropped.
        hist.bucket_limit.extend(bin_edges[1:].tolist())
        hist.bucket.extend(counts.tolist())

        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def flush(self):
        """Flush pending events to disk."""
        self.writer.flush()

    def close(self):
        """Flush and close the underlying summary writer."""
        self.writer.close()
class Target(object):
    """Base class for objects driven by a :class:`Schedule`.

    ``Schedule.ticktock`` calls ``_update`` on every step at which the target
    is due. Subclasses must override it.
    """

    def _update(self):
        raise NotImplementedError


class Schedule(object):
    """Step counter that fires registered targets at chosen steps.

    Each registered entry is either a :class:`Target` (its ``_update`` is
    called) or any other callable (called with no arguments). An interval is
    either an int ``k``, meaning "fire on every step divisible by ``k``", or a
    collection (list, tuple, set, frozenset, range) of explicit step numbers
    to fire on. Steps are counted from 1.
    """

    # Interval types interpreted as "fire exactly on these steps".
    _STEP_COLLECTIONS = (list, tuple, set, frozenset, range)

    def __init__(self):
        self._step = 0
        self._t = []
        self._intervals = []

    def add_schedule(self, target, interval=1):
        """Register ``target``, or each item of a list of targets, with ``interval``.

        Returns ``self`` so registrations can be chained.
        """
        if isinstance(target, list):
            self._t.extend(target)
            self._intervals.extend([interval] * len(target))
        else:
            self._t.append(target)
            self._intervals.append(interval)
        return self

    def ticktock(self):
        """Advance one step and fire every target that is due on the new step."""
        self._step += 1
        for t, val in zip(self._t, self._intervals):
            # The two interval forms are mutually exclusive. A step
            # collection must never reach the modulo test, because
            # `int % list` raises TypeError.
            if isinstance(val, self._STEP_COLLECTIONS):
                due = self._step in val
            else:
                due = self._step % val == 0
            if not due:
                continue
            if isinstance(t, Target):
                t._update()
            else:
                t()
class ValueTarget(Target):
    """Schedulable target that carries a ``value``.

    Subclasses fold the current value in through ``_update_func`` and clear
    their internal state in ``_reset``.

    NOTE(review): ``_update`` only calls ``_update_func`` once a previous
    value has been recorded. The value present at the very first tick is
    therefore only stored, never folded in. ``reset`` does not clear the
    recorded previous value, so the skip happens only while it is still None.
    Confirm this is intended for accumulating subclasses.
    """

    def __init__(self, value):
        super().__init__()
        self.value = value
        # Holds self.value as it was left after the previous tick (after
        # _update_func ran); None until the first tick.
        self._old_value = None

    def _update(self):
        previous, current = self._old_value, self.value
        if previous is not None:
            self._update_func(previous, current)
        # Read self.value again: _update_func may have replaced it.
        self._old_value = self.value

    def reset(self, value=0.0):
        """Store ``value`` and let the subclass clear its state via ``_reset``."""
        self.value = value
        self._reset(value)

    def _update_func(self, old_value, new_value):
        """Fold ``new_value`` into the target. Subclasses must override."""
        raise NotImplementedError

    def _reset(self, value):
        """Clear subclass state during ``reset``. Subclasses must override."""
        raise NotImplementedError
class MovingMean(ValueTarget):
    """Mean of the most recent ``mean_steps`` values seen at scheduled ticks.

    When ``_update_func`` runs, the current ``value`` is pushed into a sliding
    window of at most ``mean_steps`` samples, and ``value`` is replaced by the
    window mean. Callers therefore assign each new raw sample to ``value``
    before the tick.

    Raises:
        ValueError: if ``mean_steps`` is smaller than 1.
    """

    def __init__(self, mean_steps=100):
        mean_steps = int(mean_steps)
        if mean_steps < 1:
            raise ValueError('mean_steps must be >= 1, got %d' % mean_steps)
        super().__init__(0)
        self._mean_steps = mean_steps
        # A bounded deque evicts the oldest sample in O(1); list.pop(0) is
        # O(window).
        self._fifo = collections.deque(maxlen=mean_steps)
        self._sum = 0.0

    def _update_func(self, old_value, new_value):
        if len(self._fifo) == self._fifo.maxlen:
            # Window is full: the append below evicts the oldest sample, so
            # remove it from the running sum first.
            self._sum -= self._fifo[0]
        self._fifo.append(new_value)
        self._sum += new_value
        self.value = self._sum / float(len(self._fifo))

    def _reset(self, value):
        # Start again with an empty window. `value` is deliberately not added
        # to the running sum: it would not be in the window, so it could never
        # be evicted and would bias every later mean. ValueTarget.reset has
        # already stored it in self.value.
        self._sum = 0.0
        self._fifo.clear()
class Average(ValueTarget):
    """Cumulative mean of the values folded in since construction or the last reset.

    When ``_update_func`` runs, the current ``value`` is added to a running
    sum, and ``value`` is replaced by ``sum / count``.
    """

    def __init__(self):
        super().__init__(0)
        self._reset(0)

    def _update_func(self, old_value, new_value):
        self._sum += new_value
        self._iters += 1
        self.value = self._sum / float(self._iters)

    def _reset(self, value):
        # Restart from an empty accumulation. Seeding the sum with `value`
        # while the count is 0 would bias every later mean (e.g. reset(v)
        # followed by one tick at v would report 2*v). ValueTarget.reset has
        # already stored `value` in self.value.
        self._sum = 0.0
        self._iters = 0
class TimeStamp(ValueTarget):
    """Target whose ``value`` is the ``time.time()`` reading of the last refresh.

    The value is set at construction and refreshed whenever ``_update_func``
    runs or the target is reset.
    """

    def __init__(self):
        super().__init__(time.time())

    def _update_func(self, old_value, new_value):
        self._stamp()

    def _reset(self, value):
        # The value passed to reset() is ignored in favour of the current time.
        self._stamp()

    def _stamp(self):
        """Overwrite ``value`` with the current wall-clock time."""
        self.value = time.time()
class CsvLogger(Target):
    """Write one CSV row of target values per scheduled tick.

    The header row is the keys of ``name_target_pairs`` in dict order. Each
    later row holds ``str(target.value)`` for the same keys. Rows are
    buffered and written every ``flush_interval`` ticks, and when the logger
    is closed.

    Args:
        output_file: path of the CSV file. Missing parent directories are
            created, and an existing file is truncated.
        name_target_pairs: mapping of column name to ``ValueTarget``.
        flush_interval: number of buffered rows that triggers a write.
        append_time: if true, the current local time (format
            ``d%d-H%H-M%M-S%S``) is appended to ``output_file``.

    Raises:
        TypeError: if a value of ``name_target_pairs`` is not a ValueTarget.
    """

    def __init__(self, output_file, name_target_pairs, flush_interval=100, append_time=False):
        super().__init__()
        # Set first so close()/__del__ are safe even if construction fails.
        self._file = None
        self._cache = []

        # Validate before touching the filesystem. An explicit raise is used
        # because asserts are stripped under `python -O`.
        for name, t in name_target_pairs.items():
            if not isinstance(t, ValueTarget):
                raise TypeError('column %r: expected a ValueTarget, got %s'
                                % (name, type(t).__name__))

        if append_time:
            time_now = datetime.datetime.now().strftime("d%d-H%H-M%M-S%S")
            output_file = output_file + time_now
        self._output_file = output_file

        out_dir = os.path.dirname(output_file)
        if out_dir:
            # makedirs (unlike mkdir) also creates missing intermediate dirs.
            os.makedirs(out_dir, exist_ok=True)

        self._name_target_pairs = name_target_pairs
        self._flush_interval = flush_interval
        # Snapshot the column order so later changes to the dict cannot
        # desynchronise rows from the header.
        self._header = list(name_target_pairs.keys())
        self._file = open(output_file, 'w')
        self._file.write(','.join(self._header) + '\n')

    def _update(self):
        """Buffer one row with the current value of every column's target."""
        row = ','.join(str(self._name_target_pairs[h].value) for h in self._header)
        self._cache.append(row + '\n')
        if len(self._cache) >= self._flush_interval:
            self.flush()

    def flush(self):
        """Write all buffered rows to the file and flush it."""
        if self._cache:
            self._file.writelines(self._cache)
            self._cache = []
        self._file.flush()

    def close(self):
        """Write any buffered rows and close the file. Safe to call repeatedly."""
        f = getattr(self, '_file', None)
        if f is None or f.closed:
            return
        try:
            self.flush()
        finally:
            f.close()

    def __del__(self):
        # Best effort only: finalizers may run late or not at all, so callers
        # that need the tail of the log should call close() explicitly.
        self.close()
if __name__ == '__main__':
    # Demo: feed 103 random samples through an Average, a 10-step MovingMean
    # and a TimeStamp, logging all three to /tmp/test via a CsvLogger.
    samples = np.random.rand(1000)

    avg = Average()
    moving = MovingMean(10)
    ts = TimeStamp()
    log = CsvLogger('/tmp/test', {'ts': ts, 'avg': avg, 'moving': moving}, flush_interval=10)

    s = Schedule()
    s.add_schedule([ts, avg, moving, log])

    for step, sample in enumerate(samples[:103]):
        print(step)
        avg.value = sample
        moving.value = sample
        s.ticktock()

    # NOTE(review): the original indentation was lost; these resets and the
    # final tick are assumed to run once, after the loop.
    for target in (avg, moving):
        target.reset()
    s.ticktock()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment