Skip to content

Instantly share code, notes, and snippets.

@rosinality
Last active April 8, 2017 14:22
Show Gist options
  • Save rosinality/94043851cbe0e127859fa073cba06df6 to your computer and use it in GitHub Desktop.
Save rosinality/94043851cbe0e127859fa073cba06df6 to your computer and use it in GitHub Desktop.
Simple live log plotting tool for Visdom
from vislog import Logger
from time import sleep
import numpy as np
import shutil
# Demo: write a sine/cosine pair and a small noise trace to the 'test' log,
# one sample per second, until the user presses Ctrl-C.
log = Logger('test')
brown1 = log.line('brown1')
brown2 = log.line('brown2')
image1 = log.image('image1')

x = 0
try:
    while True:
        # 'brown1' carries two traces: the unnamed sine and the cosine named '2'.
        brown1.log(np.sin(x))
        brown1.log(np.cos(x), name='2')
        brown2.log(np.random.randn() * 0.1)
        #image1.log(np.random.rand(3, 256, 128))
        x += 0.02
        sleep(1)
except KeyboardInterrupt:
    # Ctrl-C simply ends the demo loop.
    pass
import os
import pickle
from filelock import FileLock
from io import BytesIO
from visdom import Visdom
import numpy as np
import tarfile
from time import sleep
def wrap_numpy(x):
    """Return *x* as something with at least one dimension.

    Objects exposing ``ndim`` are returned unchanged unless they are
    0-dimensional, in which case they are reshaped to one element.  Anything
    that raises ``AttributeError`` along the way (e.g. a plain Python number)
    is wrapped as ``np.array([x])``.
    """
    try:
        return x.reshape(-1) if x.ndim < 1 else x
    except AttributeError:
        return np.array([x])
def squeeze(x):
    """Drop length-1 axes from *x* without ever returning a 0-d array.

    Arrays with at most one dimension are returned as they are.  Higher
    dimensional arrays are squeezed; if that leaves a scalar array it is
    turned back into a one-element 1-D array.
    """
    if x.ndim <= 1:
        return x
    return np.atleast_1d(np.squeeze(x))
class VisLoader(object):
    """Replay a pickled vislog file into Visdom and keep polling it for new records.

    Records are dispatched on their class name ('NewPlot', 'PlotLine',
    'PlotImage'); records of any other type are ignored.

    NOTE: the log is decoded with ``pickle.load``, which can execute arbitrary
    code -- only open log files that come from a trusted writer.
    """

    def __init__(self, filename):
        """Connect to Visdom, open *filename* for binary reading and replay it once."""
        self.viz = Visdom()
        self.file = open(filename, 'rb')
        # plot_opts and max_buffer are initialised but not read anywhere in this class.
        self.plot_opts = {}
        self.max_buffer = 1000
        self.buffer = {}  # plot_name -> pending 'X' / 'Y' / 'name' / 'opts' lists
        self.plot = {}    # plot_name -> plot description plus its Visdom window
        self.X_pos = {}   # plot_name -> {trace name -> next implicit X value}
        self.pos = 0      # file offset of the first record not consumed yet
        self.initial_read()

    def initial_read(self):
        """Replay everything that is already in the file."""
        self.process()

    def loop_process(self):
        """Poll the file once per second until interrupted.

        The loop ends quietly on ``KeyError`` (the original behaviour) and on
        ``KeyboardInterrupt`` so that Ctrl-C stops the viewer without a
        traceback, matching the demo writer script.
        """
        while True:
            try:
                self.process()
                sleep(1)
            except (KeyError, KeyboardInterrupt):
                break

    def process(self):
        """Handle every record that is currently readable, then flush the line buffers."""
        for log in self.read():
            self.message_loop(type(log).__name__, log)
        self.plot_buffer()

    def message_loop(self, message, log):
        """Dispatch one record.

        Args:
            message: class name of the record ('NewPlot', 'PlotLine', 'PlotImage').
            log: the unpickled record itself.
        """
        if message == 'NewPlot':
            # A plot is registered once; repeated declarations are ignored.
            if log.plot_name in self.plot:
                return
            self.plot[log.plot_name] = {'plot_type': log.plot_type,
                                        'env': log.env,
                                        'opts': log.opts,
                                        'meta': log.meta,
                                        'win': None}
            print('Create', log)
        elif message == 'PlotLine':
            # Y is needed on both branches; previously it was only bound when
            # X was missing, so records with an explicit X crashed.
            Y = wrap_numpy(log.Y)
            if log.X is None:
                # No X given: continue this trace's running sample index.
                counters = self.X_pos.setdefault(log.plot_name, {})
                start = counters.get(log.name, 0)
                X = np.arange(start, start + len(Y))
                counters[log.name] = start + len(Y)
            else:
                X = wrap_numpy(log.X)
            self.add_buffer(log.plot_name, X, Y, log.name, log.opts)
        elif message == 'PlotImage':
            plot_data = self.plot[log.plot_name]
            # Only the most recently appended array in the tar file is shown.
            with tarfile.TarFile(plot_data['meta']['filename']) as tar:
                lastfile = tar.getnames()[-1]
                imgfile = tar.extractfile(lastfile)
                imgbyte = BytesIO(imgfile.read())
            img = np.load(imgbyte)
            if plot_data['win'] is None:
                win = self.viz.image(img, env=plot_data['env'],
                                     opts=plot_data['opts'])
                plot_data['win'] = win
            else:
                self.viz.image(img, win=plot_data['win'],
                               env=plot_data['env'], opts=plot_data['opts'])

    def add_buffer(self, plot_name, X, Y, name, opts):
        """Queue one line sample for *plot_name* until the next ``plot_buffer`` call."""
        pending = self.buffer.setdefault(
            plot_name, {'X': [], 'Y': [], 'name': [], 'opts': []})
        pending['X'].append(X)
        pending['Y'].append(Y)
        pending['name'].append(name)
        pending['opts'].append(opts)

    def plot_buffer(self):
        """Send all queued line/scatter samples to Visdom and empty the queue.

        Consecutive samples that share the same trace name and opts are sent
        as one batch.
        """
        for plot_name, plot_data in self.plot.items():
            plot_type = plot_data['plot_type']
            if plot_type not in {'line', 'scatter'}:
                continue
            if plot_name not in self.buffer:
                continue
            buffer = self.buffer[plot_name]
            X_buf, Y_buf = [], []
            prev_name, prev_opts = None, None
            for X_e, Y_e, name, opts in zip(buffer['X'], buffer['Y'],
                                            buffer['name'], buffer['opts']):
                if prev_opts != opts or prev_name != name:
                    # Trace or options changed: flush the batch collected so far.
                    self.create_or_update(plot_name, plot_type,
                                          X_buf, Y_buf, prev_name, prev_opts)
                    X_buf, Y_buf, prev_name, prev_opts = [X_e], [Y_e], name, opts
                else:
                    X_buf.append(X_e)
                    Y_buf.append(Y_e)
            self.create_or_update(plot_name, plot_type,
                                  X_buf, Y_buf, prev_name, prev_opts)
        self.buffer = {}

    def create_or_update(self, plot_name, plot_type, X_buf, Y_buf, name, opts):
        """Create the Visdom window for *plot_name*, or append a batch to it.

        Does nothing when the batch is empty.  A window is only created for
        plots of type 'line'.
        """
        plot_data = self.plot[plot_name]
        if len(X_buf) < 1:
            return
        X = squeeze(np.vstack(X_buf))
        Y = squeeze(np.vstack(Y_buf))
        if plot_data['win'] is None:
            if plot_type == 'line':
                win = self.viz.line(Y, X, env=plot_data['env'],
                                    opts=plot_data['opts'])
                plot_data['win'] = win
        else:
            # NOTE(review): updateTrace appears to have been dropped from later
            # Visdom releases -- confirm against the installed version.
            self.viz.updateTrace(X, Y, win=plot_data['win'], name=name,
                                 env=plot_data['env'], opts=opts)

    def read(self):
        """Yield every complete record from the current file position onwards.

        When the end of the file (or a record the writer has only partly
        flushed) is reached, the file is rewound to the start of that record
        so the next call retries it instead of decoding from the middle.
        ``self.pos`` is set to that offset.  A record that is permanently
        corrupt is therefore retried on every call rather than raising.
        """
        while True:
            record_start = self.file.tell()
            try:
                record = pickle.load(self.file)
            except (EOFError, pickle.UnpicklingError):
                # Truncated data surfaces as either exception depending on
                # where inside the record the bytes ran out.
                self.file.seek(record_start)
                self.pos = record_start
                break
            yield record
if __name__ == '__main__':
    # Replay the log file named 'test' and keep polling it for new records.
    VisLoader('test').loop_process()
import os
import tarfile
import numpy as np
from filelock import FileLock
from tempfile import TemporaryFile
import pickle
from collections import namedtuple
# Record types pickled into the log file; the reader dispatches on the class name.
NewPlot = namedtuple(
    'NewPlot', ['plot_name', 'plot_type', 'env', 'opts', 'meta'])
PlotLine = namedtuple(
    'PlotLine', ['plot_name', 'Y', 'X', 'name', 'opts'])
PlotScatter = namedtuple(
    'PlotScatter', ['plot_name', 'X', 'Y', 'opts'])
PlotHistogram = namedtuple(
    'PlotHistogram', ['plot_name', 'X', 'opts'])
PlotImage = namedtuple(
    'PlotImage', ['plot_name', 'opts'])
def validate_filename(filename):
    """Return *filename* with every character in ``\\/:*?"<>|`` replaced by '_'.

    Args:
        filename: the candidate file name.

    Returns:
        A string of the same length with the listed characters replaced.
    """
    # Raw string: a bare '\/' is an invalid escape sequence and triggers a
    # warning on current Python versions.
    invalid = r'\/:*?"<>|'
    table = str.maketrans(invalid, '_' * len(invalid))
    return filename.translate(table)
class PickleWriter(object):
    """Pickle values one after another into an already-open file object."""

    def __init__(self, logfile):
        """Remember *logfile*, the writable binary file object to append to."""
        self.logfile = logfile

    def write(self, value, flush=True):
        """Pickle *value* into the log file; flush afterwards unless *flush* is falsy."""
        pickle.dump(value, self.logfile)
        if not flush:
            return
        self.logfile.flush()
class BaseLogger(PickleWriter):
    """Common state for per-plot loggers: the log file and the plot's record."""

    def __init__(self, logfile, plot):
        """Store *plot* (the record declaring the plot) and its ``opts``.

        Args:
            logfile: writable binary file object shared with the parent logger.
            plot: record describing the plot; its ``opts`` attribute is kept
                as ``self.opts``.
        """
        super().__init__(logfile)
        self.plot = plot
        # Kept separately so each log() call can compare against it.
        self.opts = self.plot.opts
def check_opt(prev, current):
    """Return *current* when it differs from *prev*, otherwise ``None``."""
    return current if prev != current else None
class LineLogger(BaseLogger):
    """Writes line-plot samples for one plot."""

    def log(self, Y, X=None, name=None, opts=None):
        """Append one ``PlotLine`` record to the log file.

        Args:
            Y: the value(s) to plot.
            X: optional x position(s); stored as given (``None`` by default).
            name: optional trace name.
            opts: per-call options; stored only when they differ from the
                plot's own opts, otherwise ``None`` is stored.
        """
        changed = check_opt(self.opts, opts)
        record = PlotLine(plot_name=self.plot.plot_name,
                          Y=Y, X=X, name=name, opts=changed)
        self.write(record)
class ScatterLogger(BaseLogger):
    """Placeholder logger for scatter plots."""

    def log(self, X, Y=None, opts=None):
        """Not implemented yet: accepts the arguments and writes nothing."""
class HistogramLogger(BaseLogger):
    """Placeholder logger for histograms."""

    def log(self, X, opts=None):
        """Not implemented yet: accepts the arguments and writes nothing."""
class ImageLogger(BaseLogger):
    """Stores images as ``.npy`` members of the plot's tar file and logs a marker."""

    def log(self, img, opts=None):
        """Append *img* to the plot's tar archive and write a ``PlotImage`` record.

        Args:
            img: array-like exposing ``dtype`` and ``max()``.  Float images are
                converted to uint8; those whose maximum is <= 1 are scaled by
                255 first.
            opts: per-call options; stored only when they differ from the
                plot's own opts.
        """
        opts = check_opt(self.opts, opts)
        if 'float' in str(img.dtype):
            if img.max() <= 1:
                img = img * 255
            img = np.uint8(img)

        # The scratch file is closed (and removed) as soon as the array has
        # been copied into the archive; previously it was never closed.
        with TemporaryFile() as scratch:
            np.save(scratch, img)
            size = scratch.tell()
            scratch.seek(0)
            with tarfile.open(self.plot.meta['filename'], 'a') as tar:
                # Members are named by their index: 0.npy, 1.npy, ...
                member = tarfile.TarInfo(
                    '{}.npy'.format(len(tar.getnames())))
                member.size = size
                tar.addfile(member, scratch)

        self.write(PlotImage(self.plot.plot_name, opts))
class Logger(object):
    """Entry point for writing a vislog file.

    Opens the log file for appending and hands out per-plot loggers.  May be
    used as a context manager so the file is closed when the block exits.
    """

    def __init__(self, filename):
        """Open *filename* in binary append mode."""
        self.file = open(filename, 'ab')
        # Base name without directory or extension; used to name image archives.
        self.file_without_ext = \
            os.path.splitext(os.path.split(filename)[1])[0]
        self.log = PickleWriter(self.file)
        self.plot = {}
        self.plot_meta = {}

    @staticmethod
    def _titled_opts(opts, plot_name):
        """Return *opts* with a default 'title' added, without touching the caller's dict.

        ``None`` and dicts that already carry a title are returned unchanged.
        """
        if opts is None or 'title' in opts:
            return opts
        titled = dict(opts)
        titled['title'] = plot_name
        return titled

    def close(self):
        """Close the underlying log file."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def line(self, plot_name, env=None, opts=None):
        """Declare a line plot and return a ``LineLogger`` bound to it.

        Args:
            plot_name: name of the plot; also the default title when *opts*
                is given without one.
            env: stored in the plot record as-is.
            opts: optional dict of plot options; it is not modified.
        """
        record = NewPlot(plot_name, 'line', env,
                         self._titled_opts(opts, plot_name), None)
        self.log.write(record, True)
        return LineLogger(self.file, record)

    def image(self, plot_name, env=None, opts=None):
        """Declare an image plot and return an ``ImageLogger`` bound to it.

        The images themselves go to '<log base name>-<plot_name>.tar' (after
        filename sanitising); that name is stored in the record's ``meta``.

        Args:
            plot_name: name of the plot; also the default title when *opts*
                is given without one.
            env: stored in the plot record as-is.
            opts: optional dict of plot options; it is not modified.
        """
        image_filename = validate_filename(
            '{}-{}.tar'.format(self.file_without_ext, plot_name))
        record = NewPlot(plot_name, 'image', env,
                         self._titled_opts(opts, plot_name),
                         {'filename': image_filename})
        self.log.write(record)
        return ImageLogger(self.file, record)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment