@joelthchao
Last active August 31, 2021 18:02
Using the Keras TensorBoard callback with train_on_batch
import numpy as np
import tensorflow as tf
from keras.callbacks import TensorBoard
from keras.layers import Input, Dense
from keras.models import Model


def write_log(callback, names, logs, batch_no):
    # Wrap each scalar in a TF1-style Summary protobuf and push it
    # through the TensorBoard callback's event writer.
    for name, value in zip(names, logs):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()


# A toy model: 3 inputs -> 1 linear output.
net_in = Input(shape=(3,))
net_out = Dense(1)(net_in)
model = Model(net_in, net_out)
model.compile(loss='mse', optimizer='sgd', metrics=['mae'])

# Attach the TensorBoard callback manually, since train_on_batch
# does not run callbacks by itself.
log_path = './logs'
callback = TensorBoard(log_path)
callback.set_model(model)

train_names = ['train_loss', 'train_mae']
val_names = ['val_loss', 'val_mae']
for batch_no in range(100):
    X_train, Y_train = np.random.rand(32, 3), np.random.rand(32, 1)
    logs = model.train_on_batch(X_train, Y_train)
    write_log(callback, train_names, logs, batch_no)

    if batch_no % 10 == 0:
        X_val, Y_val = np.random.rand(32, 3), np.random.rand(32, 1)
        # Evaluate without updating weights; train_on_batch here would
        # also train on the validation data.
        logs = model.test_on_batch(X_val, Y_val)
        write_log(callback, val_names, logs, batch_no // 10)
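
To inspect the logged scalars, point TensorBoard at the directory used above and open the Scalars tab:

tensorboard --logdir ./logs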
@dmitrysarov

dmitrysarov commented Sep 30, 2017

What happens when I set several different models to one TensorBoard log path?

log_path = './logs'
callback = TensorBoard(log_path)
callback.set_model(model)

How can we exploit that?
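
For what it's worth: every writer pointed at the same directory emits its own event file, and TensorBoard merges all event files in one directory into a single run, so several models logging to ./logs produce one tangled curve per tag. The usual practice is one subdirectory per model; a minimal sketch (model_a and model_b are placeholder names, not from the gist):

callback_a = TensorBoard('./logs/model_a')
callback_a.set_model(model_a)

callback_b = TensorBoard('./logs/model_b')
callback_b.set_model(model_b)

Running tensorboard --logdir ./logs then overlays the two models as separate, comparable runs.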

@trianam

trianam commented Dec 4, 2017

On line 34, don't you mean to write
logs = model.test_on_batch(X_val, Y_val)
instead of
logs = model.train_on_batch(X_val, Y_val)
?
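
For context, test_on_batch computes the loss and metrics on the batch without applying a gradient update, while train_on_batch also updates the weights, so using the latter on validation data means training on it:

logs = model.train_on_batch(X_val, Y_val)  # forward + backward pass; weights change
logs = model.test_on_batch(X_val, Y_val)   # forward pass only; weights unchanged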

@NiftyGrimoire

tf.Summary() does not seem to have the value element?
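
That usually means the code is running under TensorFlow 2, where the Summary protobuf is no longer in the top-level namespace; it is still reachable through the v1 compatibility module. A minimal sketch, assuming TF2:

import tensorflow as tf

summary = tf.compat.v1.Summary()  # TF1's tf.Summary() lives here in TF2
summary_value = summary.value.add()
summary_value.simple_value = 0.5
summary_value.tag = 'train_loss'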

@erenon

erenon commented Sep 30, 2018

Here's a simpler solution, which uses the TensorBoard callback directly:
https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8

@piyush01123

What happens when I set several different models to one TensorBoard log path?

log_path = './logs'
callback = TensorBoard(log_path)
callback.set_model(model)

How can we exploit that?

I think the answer to that is to not use callbacks at all.
Also, I think in general, if you are using train_on_batch you should not write logs via a callback; rather, you should do it this way (TF1-style, where writer is a tf.summary.FileWriter):

import tensorflow as tf
from tensorflow.keras import backend as K

writer = tf.summary.FileWriter(log_path)

loss = model.train_on_batch(X, Y)
tf.summary.scalar('loss', loss)
# add other scalars to your tf.summary collection here

merged = tf.summary.merge_all()
sess = K.get_session()  # grab the session before running the merged summary op
summary = sess.run(merged)
writer.add_summary(summary)
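
One caveat with this pattern: in TF1 graph mode, tf.summary.scalar adds a new op to the graph each time it is called, so invoking it inside a training loop keeps growing the graph. The gist's write_log, which builds the Summary protobuf directly, sidesteps that and needs no session run at all.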

@rounakskm

Hello, I am trying to upgrade my code to TensorFlow 2.0, which does not have tf.Summary().
Could anyone tell me how to write the write_log function so that I can still visualize with TensorBoard?
Thank you
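
For what it's worth, a minimal TF2 sketch of an equivalent write_log (assuming eager execution, the TF2 default): tf.summary.create_file_writer replaces the callback's event writer, and tf.summary.scalar logs directly at a given step:

import tensorflow as tf

writer = tf.summary.create_file_writer('./logs')

def write_log(writer, names, logs, batch_no):
    # Log each scalar under its tag at the given step.
    with writer.as_default():
        for name, value in zip(names, logs):
            tf.summary.scalar(name, value, step=batch_no)
        writer.flush()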

@AliSaeed86

Here's a simpler solution, which uses the TensorBoard callback directly:
https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8

thanks bro, it really helped
