Skip to content

Instantly share code, notes, and snippets.

@ferrine
Created April 11, 2019 11:46
Show Gist options
  • Save ferrine/0df6ae873ad0308e2e63d24f7ae37e6b to your computer and use it in GitHub Desktop.
Save ferrine/0df6ae873ad0308e2e63d24f7ae37e6b to your computer and use it in GitHub Desktop.
Tensorboard sacred observer
import sacred.observers
import tensorboardX
import os
class TensorboardObserver(sacred.observers.FileStorageObserver, tensorboardX.SummaryWriter):
    """Sacred file-storage observer that doubles as a tensorboardX ``SummaryWriter``.

    The ``FileStorageObserver`` half is initialised in ``__init__``.  The
    ``SummaryWriter`` half is initialised lazily in ``started_event`` because
    its log directory is built from ``self.dir`` and the run's config, which
    are only available once the run has started.
    """

    VERSION = "TensorboardObserver-0.0.1"

    def __init__(self, basedir, resource_dir=None, source_dir=None,
                 template=None,
                 priority=sacred.observers.file_storage.DEFAULT_FILE_STORAGE_PRIORITY,
                 config2name=lambda c: "", **kwargs):
        """Set up the file-storage part and remember the writer options.

        Args:
            basedir, resource_dir, source_dir, template, priority: forwarded
                unchanged to ``FileStorageObserver.__init__``.
            config2name: callable mapping the run config to the last path
                component of the tensorboard log directory (default: ``""``).
            **kwargs: extra keyword arguments forwarded to
                ``tensorboardX.SummaryWriter.__init__`` when a run starts.
        """
        sacred.observers.FileStorageObserver.__init__(
            self,
            basedir=basedir,
            resource_dir=resource_dir,
            source_dir=source_dir,
            template=template,
            priority=priority
        )
        self._tb_kwargs = kwargs
        self.config2name = config2name
        # True only between a successful SummaryWriter.__init__ and close().
        self._tb_open = False

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Run the file-storage ``started_event``, then open the summary writer.

        The writer logs to ``<self.dir>/_/<config2name(config)>``.

        Returns:
            The ``_id`` returned by the parent ``started_event``.
        """
        _id = super().started_event(
            ex_info=ex_info, command=command, host_info=host_info,
            start_time=start_time, config=config,
            meta_info=meta_info, _id=_id
        )
        log_dir = os.path.join(self.dir, "_", self.config2name(config))
        # The writer kwargs are kept (not deleted) so that the same observer
        # instance can go through started_event again for another run.
        tensorboardX.SummaryWriter.__init__(self, log_dir=log_dir, **self._tb_kwargs)
        self._tb_open = True
        return _id

    def __eq__(self, other):
        """Two observers of this class are equal when their ``basedir`` matches."""
        if isinstance(other, self.__class__):
            return self.basedir == other.basedir
        return False

    def _close_writer(self):
        """Close the summary writer if ``started_event`` actually opened it."""
        # getattr guards against instances created without running __init__.
        if getattr(self, "_tb_open", False):
            self._tb_open = False
            self.close()

    def completed_event(self, stop_time, result):
        """Record completion in file storage, then close the summary writer."""
        try:
            super().completed_event(stop_time=stop_time, result=result)
        finally:
            # Close even if the file-storage handler raised, to avoid leaking
            # the writer; the original exception still propagates.
            self._close_writer()

    def interrupted_event(self, interrupt_time, status):
        """Record the interruption in file storage, then close the summary writer."""
        try:
            super().interrupted_event(interrupt_time=interrupt_time, status=status)
        finally:
            self._close_writer()

    def failed_event(self, fail_time, fail_trace):
        """Record the failure in file storage, then close the summary writer."""
        try:
            super().failed_event(fail_time=fail_time, fail_trace=fail_trace)
        finally:
            # The failure may have happened before the writer was opened;
            # _close_writer is a no-op in that case.
            self._close_writer()

    @classmethod
    def create(cls, basedir, resource_dir=None, source_dir=None,
               template=None,
               priority=sacred.observers.file_storage.DEFAULT_FILE_STORAGE_PRIORITY,
               **kwargs):
        """Create ``basedir`` if needed, resolve defaults and build an observer.

        Args:
            basedir: root directory; created when missing.
            resource_dir: defaults to ``<basedir>/_resources`` when falsy.
            source_dir: defaults to ``<basedir>/_sources`` when falsy.
            template: path to a template file.  When ``None``,
                ``<basedir>/template.html`` is used if it exists, else no
                template is used.
            priority: forwarded to the constructor.
            **kwargs: forwarded to the constructor (``config2name`` and
                summary-writer options).

        Raises:
            FileNotFoundError: if an explicit ``template`` does not exist.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(basedir, exist_ok=True)
        resource_dir = resource_dir or os.path.join(basedir, '_resources')
        source_dir = source_dir or os.path.join(basedir, '_sources')
        if template is not None:
            if not os.path.exists(template):
                raise FileNotFoundError("Couldn't find template file '{}'"
                                        .format(template))
        else:
            template = os.path.join(basedir, 'template.html')
            if not os.path.exists(template):
                template = None
        return cls(basedir, resource_dir, source_dir, template, priority, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment