Last active
April 21, 2019 22:36
-
-
Save mkroutikov/d9854611edda15b1262ab9850878ba35 to your computer and use it in GitHub Desktop.
Hack to teach tensorboardX to write summaries to cloud locations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Super dirty trick: replace the file I/O used by tensorboardX.SummaryWriter with a cloud-capable one (from tensorflow.gfile).
'''
import tensorboardX | |
import tensorflow.gfile as gio | |
from unittest import mock | |
import logging | |
import time | |
def hacked_open(*av, **kav):
    '''Open a file through ``gio.Open`` and wrap it in a FlushSuppressor.

    All positional and keyword arguments are forwarded unchanged to
    ``gio.Open``.
    '''
    # tensorboardX flushes on every event. With cloud URLs this causes
    # warnings about re-writing the cloud file too often, so the handle
    # is wrapped in a flush suppressor.
    cloud_file = gio.Open(*av, **kav)
    return FlushSuppressor(cloud_file)
class FlushSuppressor:
    '''Wrap a file-like object and suppress too-frequent flushes.

    ``write`` and ``close`` are passed straight through to the wrapped
    object. ``flush`` reaches the wrapped object at most once every
    ``min_flush_seconds`` seconds; calls made in between are dropped.
    '''

    def __init__(self, engine, min_flush_seconds=60):
        '''Store the wrapped object and the throttling interval.

        engine: the wrapped file-like object. It must provide ``write``,
            ``flush`` and ``close``, plus ``__exit__`` when the wrapper
            is used in a ``with`` statement.
        min_flush_seconds: minimum number of seconds between two flushes
            that are actually forwarded to ``engine``.
        '''
        self._engine = engine
        self._min_flush_seconds = min_flush_seconds
        # Monotonic timestamp of the last accepted flush; None until the
        # first flush() call.
        self._last_flush = None

    def write(self, data):
        '''Write ``data`` to the wrapped object and return its result.'''
        return self._engine.write(data)

    def close(self):
        '''Close the wrapped object and return its result.'''
        return self._engine.close()

    def flush(self):
        '''Flush the wrapped object unless it was flushed too recently.

        The very first call only starts the clock and does not flush.
        '''
        # A monotonic clock keeps the interval correct even if the
        # wall-clock time is adjusted while the writer is running.
        now = time.monotonic()
        if self._last_flush is None:
            self._last_flush = now
        elif now - self._last_flush >= self._min_flush_seconds:
            self._last_flush = now
            logging.debug('Flushing IO')
            self._engine.flush()

    def __enter__(self):
        '''Return the wrapper itself as the context-manager target.'''
        return self

    def __exit__(self, *av, **kav):
        '''Delegate context-manager exit to the wrapped object.'''
        # The wrapped object exposes the protocol method ``__exit__``;
        # there is no plain ``exit`` attribute to call.
        return self._engine.__exit__(*av, **kav)
class SummaryWriter(tensorboardX.SummaryWriter):
    '''Summary writer capable of writing to cloud URLs, like gs://my-training-bucket/experiment3'''

    def __init__(self, output_dir):
        '''Run the base-class constructor with its file I/O hooks patched.

        While the parent ``__init__`` runs, ``open`` in
        ``tensorboardX.record_writer`` is replaced by ``hacked_open`` and
        ``directory_check`` in ``tensorboardX.event_file_writer`` is
        replaced by ``gio.MakeDirs``. Both patches are undone as soon as
        the constructor returns.
        '''
        open_patch = mock.patch('tensorboardX.record_writer.open', hacked_open)
        mkdir_patch = mock.patch(
            'tensorboardX.event_file_writer.directory_check', gio.MakeDirs)
        with open_patch, mkdir_patch:
            super().__init__(output_dir)
if __name__ == '__main__':
    # Demo: log 10000 scalar values, one every half second, to a cloud
    # location. DEBUG logging makes the 'Flushing IO' messages visible.
    logging.basicConfig(level=logging.DEBUG)
    sw = SummaryWriter('gs://my-bucket/my-file.txt')
    try:
        for i in range(10000):
            sw.add_scalar('count', i, i)
            time.sleep(0.5)
    finally:
        # Close the writer even if the long-running loop is interrupted
        # or fails, so that close() is never skipped.
        sw.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment