@erenon
Last active August 15, 2023 21:52
# This example shows how to use the Keras TensorBoard callback
# with model.train_on_batch
import tensorflow.keras as keras

# Set up the model
model = keras.models.Sequential()
model.add(...)      # Add your layers
model.compile(...)  # Compile as usual

batch_size = 256

# Create the TensorBoard callback,
# which we will drive manually
tensorboard = keras.callbacks.TensorBoard(
    log_dir='/tmp/my_tf_logs',
    histogram_freq=0,
    batch_size=batch_size,
    write_graph=True,
    write_grads=True
)
tensorboard.set_model(model)

# Transform the train_on_batch return value
# into the dict expected by the on_batch_end callback
def named_logs(model, logs):
    result = {}
    for l in zip(model.metrics_names, logs):
        result[l[0]] = l[1]
    return result

# Run training batches, notify TensorBoard at the end of each epoch
for batch_id in range(1000):
    x_train, y_train = create_training_data(batch_size)
    logs = model.train_on_batch(x_train, y_train)
    tensorboard.on_epoch_end(batch_id, named_logs(model, logs))

tensorboard.on_train_end(None)
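
Note that create_training_data(batch_size) is not defined in the gist; it stands for whatever code produces one batch of inputs and labels. A minimal stand-in, with an assumed 10-dimensional input and binary labels purely for illustration, could be:

import numpy as np

def create_training_data(batch_size):
    # Hypothetical stand-in: return one random batch of
    # 10-dimensional inputs and binary labels, just so the
    # training loop above has something to consume.
    x = np.random.rand(batch_size, 10).astype('float32')
    y = np.random.randint(0, 2, size=(batch_size, 1)).astype('float32')
    return x, y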
@mirkow

mirkow commented Oct 20, 2018

Which Keras version is this for?
For me, train_on_batch() returns a single float (the loss, I guess).

@ArashHosseini

It returns a scalar training loss (if the model has a single output and no metrics) or a list of scalars (if the model has multiple outputs and/or metrics).
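
To make the two cases concrete, here is a small sketch; the toy model, the random data, and the printed values are just for illustration:

import numpy as np
import tensorflow.keras as keras

x = np.random.rand(8, 4).astype('float32')
y = np.random.rand(8, 1).astype('float32')

# Single output, no metrics -> train_on_batch returns one scalar loss
m1 = keras.models.Sequential([keras.layers.Dense(1, input_shape=(4,))])
m1.compile(optimizer='sgd', loss='mse')
print(m1.train_on_batch(x, y))   # e.g. 0.73

# Extra metric -> train_on_batch returns a list, one value per metric name
m2 = keras.models.Sequential([keras.layers.Dense(1, input_shape=(4,))])
m2.compile(optimizer='sgd', loss='mse', metrics=['mae'])
print(m2.train_on_batch(x, y))   # e.g. [0.73, 0.69]
print(m2.metrics_names)          # ['loss', 'mae']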

@jaassoon

jaassoon commented Nov 12, 2018

Which Keras version is this for?
For me, train_on_batch() returns a single float (the loss, I guess).

You can wrap the single value in a list, as below:
tensorboard.on_epoch_end(batch_id, named_logs(model, [logs]))
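
The wrapping is needed because named_logs zips model.metrics_names against logs, and zip cannot iterate over a bare float. A quick sketch (the 0.42 value is made up):

metrics_names = ['loss']
logs = 0.42                        # what train_on_batch returns with no metrics

# dict(zip(metrics_names, logs))   # TypeError: zip argument #2 must support iteration
dict(zip(metrics_names, [logs]))   # {'loss': 0.42} -- the dict the callback expects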

@albertorb

albertorb commented Sep 30, 2019

This tensorboard.on_epoch_end(batch_id, named_logs(model, logs))
should instead be
tensorboard.on_batch_end(batch_id, named_logs(model, logs))

To avoid the logs['size'] issue, just include your batch_size in the returned dict:

  result = {'size': batch_size}
  for l in zip(model.metrics_names, logs):
    result[l[0]] = l[1]
  return result
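
Putting that together, the whole helper would look like this (a sketch based on the suggestion above; batch_size is assumed to be the variable already defined in the gist):

def named_logs(model, logs):
    # Seed the dict with the batch size so that on_batch_end's
    # logs['size'] lookup has a value to read.
    result = {'size': batch_size}
    for metric_name, value in zip(model.metrics_names, logs):
        result[metric_name] = value
    return result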

@kaiche12

How do you apply early stopping when the validation loss stops improving?

@garryyan2

This tensorboard.on_epoch_end(batch_id, named_logs(model, logs)) should instead be tensorboard.on_batch_end(batch_id, named_logs(model, logs))

I was also uncomfortable using tensorboard.on_epoch_end() instead of tensorboard.on_batch_end(). But I couldn't get the correct result just by making that switch; I also had to change the TensorBoard callback's update_freq from the default "epoch" to "batch":

tensorboard = keras.callbacks.TensorBoard(
    log_dir='/tmp/my_tf_logs',
    histogram_freq=0,
    batch_size=batch_size,
    write_graph=True,
    write_grads=True,
    update_freq="batch"
)
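
For reference, the training loop that pairs with this configuration would then call on_batch_end rather than on_epoch_end. Below is a sketch combining the fixes from this thread (create_training_data is the same assumed helper as in the gist, and named_logs is the version that includes 'size'):

tensorboard.set_model(model)

for batch_id in range(1000):
    x_train, y_train = create_training_data(batch_size)
    logs = model.train_on_batch(x_train, y_train)
    # Per-batch notification; named_logs must include 'size'
    # (as suggested above) so the batch counter advances correctly.
    tensorboard.on_batch_end(batch_id, named_logs(model, logs))

tensorboard.on_train_end(None)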
