Skip to content

Instantly share code, notes, and snippets.

@jaekookang
Last active November 30, 2020 01:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaekookang/7e2ca4dc2b1ab10dbb80b9e65ca91179 to your computer and use it in GitHub Desktop.
Save jaekookang/7e2ca4dc2b1ab10dbb80b9e65ca91179 to your computer and use it in GitHub Desktop.
TensorFlow2.x Keras Custom Train Logger Callback
'''
This code snippet was developed based on
- https://github.com/keras-team/keras/issues/2850#issuecomment-222494059

How to use:
```
logger = NBatchLogger(n_display, n_epoch, log_dir)
model.fit(data_generator,
          epochs=n_epoch,
          callbacks=[logger],
          verbose=0,  # must be 0 so Keras' own progress output does not interleave with the logger's lines
          ...)
```
This prints training progress on the first epoch and every `n_display` epochs.
If `log_dir` is given, the same lines are written to a log file under it
(e.g. train.log, or train_mnist.log with suffix='mnist').

History:
2020-04-20
2020-11-18 edited
2020-11-29 f-strings updated
'''
import os
from time import time, strftime, gmtime
import tensorflow as tf
tfk = tf.keras
tfkc = tfk.callbacks
class NBatchLogger(tfkc.Callback):
    """Log training metrics every ``n_display`` epochs.

    Despite its name, this callback only hooks ``on_epoch_end``, so it reports
    at epoch granularity. It emits one line on the first epoch and on every
    epoch whose 1-based index is a multiple of ``n_display``. Each line is
    printed to stdout (unless ``silent``) and, when ``save_dir`` is given,
    appended to a log file.

    Args:
        n_display: Report every ``n_display`` epochs (epoch 1 is always
            reported).
        max_epoch: Total number of epochs; shown as ``Epoch: i/max_epoch``
            and used as the zero-padding width for both numbers.
        save_dir: Existing directory in which to write the log file. When
            ``None``, nothing is written to disk.
        suffix: Optional tag; the log file becomes ``train_{suffix}.log``
            instead of ``train.log``.
        silent: When True, suppress console output (file logging still
            happens).

    Raises:
        FileNotFoundError: If ``save_dir`` is given but does not exist.
    """

    def __init__(self, n_display, max_epoch, save_dir=None, suffix=None, silent=False):
        # Initialise the Keras Callback base class state.
        super().__init__()
        self.epoch = 0  # own 1-based epoch counter; keeps counting across fit() calls
        self.display = n_display
        self.max_epoch = max_epoch
        self.logs = {}
        self.save_dir = save_dir
        self.silent = silent
        self.log_path = None
        if self.save_dir is not None:
            # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
            if not os.path.exists(self.save_dir):
                raise FileNotFoundError(f'Path:{self.save_dir} does not exist!')
            fname = 'train.log' if suffix is None else f'train_{suffix}.log'
            self.log_path = os.path.join(save_dir, fname)
            # Create/truncate the file once up front. Later writes append and
            # close immediately, so no file handle is left open.
            with open(self.log_path, 'w', encoding='utf-8'):
                pass
        self.t0 = time()

    def on_train_begin(self, logs=None):
        """Announce the start of training and restart the elapsed-time clock."""
        # Reset here so on_train_end measures this fit() call rather than
        # the time since the callback was constructed.
        self.t0 = time()
        txt = f'=== Started at {self.get_time()} ==='
        self.write_log(txt)
        if not self.silent:
            print(txt)

    def on_epoch_end(self, epoch, logs=None):
        """Report metrics on epoch 1 and on every ``display``-th epoch.

        Args:
            epoch: Keras' epoch index (unused; ``self.epoch`` is used instead).
            logs: Mapping of metric name to value passed in by Keras.
        """
        logs = {} if logs is None else logs
        self.epoch += 1
        if self.epoch % self.display == 0 or self.epoch == 1:
            width = len(str(self.max_epoch))  # zero-pad epoch numbers to this width
            header = (f' {self.get_time()} | '
                      f'Epoch: {self.epoch:0{width}d}/{self.max_epoch:0{width}d} | ')
            # NOTE(review): ':4f' is a minimum width of 4 with the default
            # 6-digit precision; '.4f' (4 decimals) may have been intended.
            metrics = ' '.join(f'{key}={value:4f}' for key, value in logs.items())
            txt = header + metrics
            if not self.silent:
                print(txt)
            self.write_log(txt)
        self.logs = logs

    def on_train_end(self, logs=None):
        """Report the wall-clock time since ``on_train_begin``, in minutes."""
        elapsed_min = (time() - self.t0) / 60
        txt = f'=== Time elapsed: {elapsed_min:.4f} min ==='
        if not self.silent:
            print(txt)
        self.write_log(txt)

    def get_time(self):
        """Return the current UTC time (from ``gmtime``) formatted for log lines."""
        return strftime('%Y-%m-%d %Hh:%Mm:%Ss', gmtime())

    def write_log(self, txt):
        """Append ``txt`` and a newline to the log file, if one is configured."""
        if self.log_path is not None:
            # Open-append-close per call: nothing stays open if training is
            # interrupted, and repeated fit() calls keep appending.
            with open(self.log_path, 'a', encoding='utf-8') as fid:
                fid.write(txt + '\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment