Skip to content

Instantly share code, notes, and snippets.

@hidez8891
Last active December 8, 2018 11:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hidez8891/8a2b192d5194fe3fb165cb1b0e2536c0 to your computer and use it in GitHub Desktop.
Save hidez8891/8a2b192d5194fe3fb165cb1b0e2536c0 to your computer and use it in GitHub Desktop.
Progress bar, loss graph, and sample image extension for Chainer
from io import BytesIO
from math import ceil

import PIL
import PIL.Image
import six
from chainer import reporter
from chainer.training import extension
from fastprogress import master_bar, progress_bar
from IPython import display
class Progress(extension.Extension):
    """Trainer extension that shows training progress in IPython/Jupyter.

    Every ``update_interval`` iterations it advances an outer fastprogress
    ``master_bar`` by one step and starts a fresh inner ``progress_bar`` for
    the next interval. It plots the interval mean of each reported key in
    ``keys``. If ``img_fn`` is given, it also redraws the image returned by
    ``img_fn()``.

    Args:
        update_interval (int): Iterations per outer-bar step / graph point.
        keys (iterable of str or None): Observation keys to average and plot.
            ``None`` (the default) means no keys.
        img_fn (callable or None): Zero-argument callable returning an array
            for ``PIL.Image.fromarray``. This is presumably a uint8 image so it
            can be saved as PNG; confirm for your generator. ``None`` disables
            the image.
    """

    def __init__(self, update_interval=100, keys=None, img_fn=None):
        self._update_interval = update_interval
        self.total_bar = None
        self.total_itr = None
        self.loop_bar = None
        self.loop_itr = None
        # Copy, so that neither a caller-owned list nor a shared default is
        # aliased.
        self.keys = list(keys) if keys is not None else []
        self.img_fn = img_fn
        # y-values per key, plus "_iter": every iteration at which the graph
        # was refreshed (kept for backward compatibility).
        self.graphs = {k: [] for k in self.keys}
        self.graphs["_iter"] = []
        # x-values per key. Tracked separately so keys that are reported only
        # in some intervals still get x/y series of equal length.
        self._graph_iters = {k: [] for k in self.keys}
        self.next_interval = 0
        self.summary = reporter.DictSummary()
        # Display handle of the sample image, created on first use.
        self.outImg = None

    def __call__(self, trainer):
        """Accumulate this iteration's observations and refresh the display.

        Args:
            trainer: The trainer invoking the extension. Reads
                ``trainer.observation``, ``trainer.updater.iteration`` and
                ``trainer.stop_trigger``.
        """
        if self.total_bar is None:
            self._init_total_bar(trainer)

        observed = [k for k in self.keys if k in trainer.observation]
        self.summary.add({k: trainer.observation[k] for k in observed})

        iteration = trainer.updater.iteration
        if iteration >= self.next_interval:
            interval = self._update_interval
            # Jump to the next multiple of the interval. A plain
            # "+= interval" would fire on every call after resuming from a
            # snapshot, when iteration is far ahead of next_interval.
            self.next_interval = (iteration // interval + 1) * interval
            self._advance(self.total_itr)
            self.loop_bar = progress_bar(range(interval), parent=self.total_bar)
            self.loop_itr = iter(self.loop_bar)
            self._update_graph(iteration)
            if self.img_fn is not None:
                self._show_image()
        self._advance(self.loop_itr)

    def finalize(self):
        """Step the outer bar once more so it is closed when training ends."""
        if self.total_itr is not None:
            self._advance(self.total_itr)

    def _init_total_bar(self, trainer):
        """Create the outer bar, one step per update interval."""
        length, _ = trainer.stop_trigger.get_training_length()
        # NOTE(review): length is treated as an iteration count. If the stop
        # trigger is expressed in epochs, the bar will be too short.
        self.total_bar = master_bar(range(ceil(length / self._update_interval)))
        self.total_itr = iter(self.total_bar)
        self.total_bar.names = self.keys

    def _update_graph(self, iteration):
        """Close the current interval: record each key's mean and redraw."""
        stats = {
            name: float(value)  # copy to CPU
            for name, value in six.iteritems(self.summary.compute_mean())
        }
        self.summary = reporter.DictSummary()
        self.graphs["_iter"].append(iteration)
        for k in self.keys:
            if k in stats:
                self._graph_iters[k].append(iteration)
                self.graphs[k].append(stats[k])
        # Pass every key, in self.keys order, so each series lines up with
        # the legend names assigned in _init_total_bar.
        series = [[self._graph_iters[k], self.graphs[k]] for k in self.keys]
        self.total_bar.update_graph(series)

    def _show_image(self):
        """Encode ``img_fn()`` as PNG and show it, updating in place."""
        buf = BytesIO()
        PIL.Image.fromarray(self.img_fn()).save(buf, "png")
        image = display.Image(data=buf.getvalue())
        if self.outImg is None:
            # display_id=True asks IPython for a fresh unique id, so separate
            # instances or runs do not overwrite each other's image.
            self.outImg = display.display(image, display_id=True)
        else:
            self.outImg.update(image)

    @staticmethod
    def _advance(itr):
        """Step a bar iterator; exhaustion just means the bar is complete."""
        try:
            next(itr)
        except StopIteration:
            pass
def make_image(generator):
    """Wrap ``generator`` as a zero-argument image callback.

    Args:
        generator: Zero-argument callable producing the image. The original
            note describes this as a NumPy 2-D image.

    Returns:
        A new function that, each time it is called, calls ``generator()``
        and returns the result unchanged.
    """
    def _sample():
        image = generator()
        return image

    return _sample
# Example usage: attach the Progress extension to a trainer, plotting "loss"
# every 100 iterations and showing the image produced by gen().
# Assumes `trainer` (presumably a chainer.training.Trainer) and `gen`
# (a zero-argument callable returning an image array) are defined elsewhere,
# e.g. in the notebook. This line raises NameError if run on its own.
# NOTE(review): keys must match the names actually reported to the trainer's
# observation (e.g. an updater may report "main/loss" rather than "loss").
# Verify against your setup.
trainer.extend(Progress(update_interval=100, keys=["loss"], img_fn=make_image(gen)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment