Last active
November 30, 2020 01:46
-
-
Save jaekookang/7e2ca4dc2b1ab10dbb80b9e65ca91179 to your computer and use it in GitHub Desktop.
TensorFlow2.x Keras Custom Train Logger Callback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
This code snippet was developed based on
- https://github.com/keras-team/keras/issues/2850#issuecomment-222494059

How to use:
```
logger = NBatchLogger(n_display, n_epoch, log_dir)
model.fit(data_generator,
          epochs=n_epoch,
          callbacks=[logger],
          verbose=0,  # this needs to be set "0" to properly display logger during training
          ...)
# This will display training progress based on `n_display`
# log file will be written under `log_dir` (eg. train.log or train_mnist.log)
```

2020-04-20
2020-11-18 edited
2020-11-29 fstring updated
'''
import os | |
from time import time, strftime, gmtime | |
import tensorflow as tf | |
tfk = tf.keras | |
tfkc = tfk.callbacks | |
class NBatchLogger(tfkc.Callback):
    '''A logger that prints/records averaged metrics every `n_display` epochs.

    Epoch 1 and every `n_display`-th epoch are reported. Output goes to the
    console (unless `silent=True`) and, if `save_dir` is given, is appended
    to a log file under that directory.

    Args:
        n_display: report every `n_display` epochs (epoch 1 is always shown).
        max_epoch: total number of epochs; used only to zero-pad the epoch counter.
        save_dir: existing directory for the log file, or None to skip file logging.
        suffix: optional file-name suffix (`train_{suffix}.log` instead of `train.log`).
        silent: if True, suppress console printing (file logging still happens).

    Raises:
        FileNotFoundError: if `save_dir` is given but does not exist.
    '''

    def __init__(self, n_display, max_epoch, save_dir=None, suffix=None, silent=False):
        super().__init__()  # initialize the Keras Callback base class
        self.epoch = 0
        self.display = n_display
        self.max_epoch = max_epoch
        self.logs = {}
        self.save_dir = save_dir
        self.silent = silent
        if self.save_dir is not None:
            # Raise explicitly: `assert` is stripped under `python -O`, and the
            # original `assert cond, Exception(...)` misused the assert message.
            if not os.path.exists(self.save_dir):
                raise FileNotFoundError(f'Path:{self.save_dir} does not exist!')
            fname = 'train.log' if suffix is None else f'train_{suffix}.log'
            self.fid = open(os.path.join(save_dir, fname), 'w')
        self.t0 = time()

    def on_train_begin(self, logs=None):
        # `logs=None` avoids the mutable-default-argument pitfall.
        logs = logs or self.logs
        txt = f'=== Started at {self.get_time()} ==='
        self.write_log(txt)
        if not self.silent:
            print(txt)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch += 1
        # Zero-pad the epoch counter to the width of `max_epoch` (e.g. 007/100).
        precision = len(str(self.max_epoch))
        if (self.epoch % self.display == 0) or (self.epoch == 1):
            header = ' {} | Epoch: {:0{}d}/{:0{}d} | '.format(
                self.get_time(), self.epoch, precision, self.max_epoch, precision)
            # `.4f` = 4 decimal places; the original `:4f` was a width typo
            # (it printed 6 decimals), inconsistent with `:.4f` used elsewhere.
            metrics = ' '.join(f'{key}={logs[key]:.4f}' for key in logs)
            txt = header + metrics
            if not self.silent:
                print(txt)  # single print also fixes the missing newline when `logs` is empty
            self.write_log(txt)
        self.logs = logs  # keep the last metrics so on_train_end can fall back to them

    def on_train_end(self, logs=None):
        logs = logs or self.logs
        elapsed_min = (time() - self.t0) / 60
        txt = f'=== Time elapsed: {elapsed_min:.4f} min ==='
        if not self.silent:
            print(txt)
        self.write_log(txt)

    def get_time(self):
        '''Return the current UTC time as a formatted string.'''
        return strftime('%Y-%m-%d %Hh:%Mm:%Ss', gmtime())

    def write_log(self, txt):
        '''Append one line to the log file (no-op when file logging is off).'''
        if self.save_dir is not None:
            self.fid.write(txt + '\n')
            self.fid.flush()

    def close(self):
        '''Close the log file handle. Call after training is fully done;
        the file is deliberately not closed in on_train_end so the same
        logger instance can be reused across multiple fit() calls.'''
        if self.save_dir is not None:
            self.fid.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment