Last active
December 8, 2018 11:54
-
-
Save hidez8891/8a2b192d5194fe3fb165cb1b0e2536c0 to your computer and use it in GitHub Desktop.
Progress bar, loss graph, and sample image extension for Chainer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from io import BytesIO
from math import ceil

import PIL
import PIL.Image
import six
from chainer import reporter
from chainer.training import extension
from fastprogress import master_bar, progress_bar
from IPython import display
class Progress(extension.Extension):
    """Chainer trainer extension that shows live training progress in Jupyter.

    Draws a fastprogress ``master_bar`` with one child ``progress_bar`` per
    reporting window. Below the bar it plots the per-window mean of each
    tracked observation key. It can optionally display a sample image
    produced by ``img_fn``, refreshed once per window.

    Args:
        update_interval: Number of iterations per reporting window. Must be
            positive.
        keys: Observation keys (as found in ``trainer.observation``) to
            average over each window and plot. Defaults to no keys.
        img_fn: Optional zero-argument callable whose result is passed to
            ``PIL.Image.fromarray`` and shown as a PNG in the notebook.

    Raises:
        ValueError: If ``update_interval`` is not positive.
    """

    def __init__(self, update_interval=100, keys=None, img_fn=None):
        if update_interval <= 0:
            raise ValueError("update_interval must be positive")
        self._update_interval = update_interval
        self.total_bar = None
        self.total_itr = None
        self.loop_bar = None
        self.loop_itr = None
        # Copy so the caller's list is never aliased (and no shared default).
        self.keys = list(keys) if keys is not None else []
        self.img_fn = img_fn
        # Per-key plotted means. "_iter" records the iteration of every window.
        self.graphs = {k: [] for k in self.keys}
        self.graphs["_iter"] = []
        # Per-key x values: a key may be missing from some windows, so each
        # key needs its own iteration list to keep x/y lengths equal.
        self._graph_iters = {k: [] for k in self.keys}
        self.next_interval = 0
        self.summary = reporter.DictSummary()
        # Display handle for the sample image; created on first use.
        self.outImg = None

    def __call__(self, trainer):
        if self.total_bar is None:
            self._init_total_bar(trainer)

        # Accumulate this iteration's values for the tracked keys.
        observed = {k: trainer.observation[k]
                    for k in self.keys if k in trainer.observation}
        self.summary.add(observed)

        iteration = trainer.updater.iteration
        if iteration >= self.next_interval:
            # Jump to the next multiple of the interval instead of adding one
            # interval. This way, resuming from a snapshot at a large
            # iteration does not open a new window on every call.
            interval = self._update_interval
            self.next_interval = (iteration // interval + 1) * interval
            self._advance_total_bar()
            self.loop_bar = progress_bar(range(self._update_interval),
                                         parent=self.total_bar)
            self.loop_itr = iter(self.loop_bar)
            self._update_graph(iteration)
            if self.img_fn is not None:
                self._show_image()
        # Default avoids leaking StopIteration out of the extension if the
        # inner bar is already exhausted.
        next(self.loop_itr, None)

    def _init_total_bar(self, trainer):
        """Create the outer bar with one step per reporting window."""
        # NOTE(review): ``length`` is treated as an iteration count. With an
        # epoch-based stop trigger (unit == 'epoch') the number of outer steps
        # will not match the real number of windows.
        length, _unit = trainer.stop_trigger.get_training_length()
        interval = self._update_interval
        # Integer ceiling division: correct under both Python 2 and 3.
        n_windows = (length + interval - 1) // interval
        self.total_bar = master_bar(range(n_windows))
        self.total_itr = iter(self.total_bar)
        # Separate list: fastprogress may pad ``names`` in place.
        self.total_bar.names = list(self.keys)

    def _advance_total_bar(self):
        """Step the outer bar, ignoring steps past its planned length."""
        try:
            next(self.total_itr)
        except StopIteration:
            # More windows than planned (rounding or a changed stop trigger);
            # keep training rather than failing.
            pass

    def _update_graph(self, iteration):
        """Record the window's mean values and redraw the loss graph."""
        stats = self.summary.compute_mean()
        self.summary = reporter.DictSummary()
        self.graphs["_iter"].append(iteration)
        for k in self.keys:
            # Use every key seen during the window, not only those present at
            # the window's final iteration.
            if k in stats:
                self._graph_iters[k].append(iteration)
                # float() copies the (possibly device-side) scalar to host.
                self.graphs[k].append(float(stats[k]))
        # One [x, y] series per key, aligned with ``total_bar.names``.
        graphs = [[self._graph_iters[k], self.graphs[k]] for k in self.keys]
        if graphs:
            self.total_bar.update_graph(graphs)

    def _show_image(self):
        """Render ``img_fn()`` as PNG and show or refresh it in the notebook."""
        img = self.img_fn()
        buf = BytesIO()
        PIL.Image.fromarray(img).save(buf, "png")
        image = display.Image(data=buf.getvalue())
        if self.outImg is None:
            # display_id=True allocates a unique id. A fixed id would make
            # several Progress instances overwrite each other's output.
            self.outImg = display.display(image, display_id=True)
        else:
            self.outImg.update(image)

    def finalize(self):
        """Advance the outer bar once more so it renders as complete."""
        if self.total_itr is not None:
            self._advance_total_bar()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_image(generator):
    """Return a zero-argument callable suitable for ``Progress(img_fn=...)``.

    The returned function calls ``generator()`` each time it is invoked and
    hands back the result unchanged. That result is expected to be an image
    array (e.g. a 2-D NumPy image).
    """
    def image_fn():
        produced = generator()
        return produced
    return image_fn
# Example usage. Assumes `trainer` and `gen` (a zero-argument callable
# returning a sample image array) were defined earlier -- TODO confirm.
trainer.extend(
    Progress(
        update_interval=100,
        keys=["loss"],
        img_fn=make_image(gen),
    )
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment