@SiLiKhon
Last active May 2, 2021 14:26
Ugly code to monitor layer activations when training tf.keras models with `model.fit` or similar.
# Disclaimer: this is just an ugly solution to monitoring layer activations
# when training a model in tf.keras.
# This code is executed right after the model is created, so `model` and
# `attention_tensors` below refer to objects defined in the enclosing scope.
import tensorflow as tf


class Callback(tf.keras.callbacks.Callback):
    def __init__(self, logging_data):
        super().__init__()
        # `logging_data` is a namedtuple of the following type:
        # LoggingData = namedtuple("LoggingData", ['get_val_dataset_fn', 'get_writer_fn', 'freq'])
        ###   get_val_dataset_fn - a function returning the validation dataset
        ###   get_writer_fn      - a function returning the tf summary writer object
        ###   freq               - integer, write interval in epochs
        # I'm using these `get_val_dataset_fn` and `get_writer_fn` functions instead of
        # the dataset and the writer themselves because they are not yet defined at
        # model creation time in my code.
        self.logging_data = logging_data
        # I create a "mirror" model whose outputs are the activations I need to monitor.
        self.attention_model = tf.keras.Model(
            inputs=model.inputs,
            outputs=attention_tensors,  # these are the tensors I need to monitor,
                                        # defined earlier while creating the main model
        )

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.logging_data.freq == 0:
            writer = self.logging_data.get_writer_fn()
            with writer.as_default():
                # When I want to monitor the values, I just call the "mirror" model.
                # (Note that you may want to avoid `predict` if, for example, you
                # want to monitor training-time behaviour.)
                attention_val = self.attention_model.predict(
                    # `get_val_dataset_fn()` returns a tf Dataset yielding both inputs
                    # and targets, so I map a lambda over it to keep only the inputs.
                    self.logging_data.get_val_dataset_fn().map(lambda x, y: x)
                )
                for tensor, value in zip(attention_tensors, attention_val):
                    # In my case the monitored tensors have some nontrivial shape,
                    # so I flatten them with `.ravel()`.
                    tf.summary.histogram(name=tensor.name, data=value.ravel(), step=epoch)
                writer.flush()


# Later, once the TensorBoard callback has been created, I define `get_writer_fn`
# as follows:
get_writer_fn = lambda: tensorboard_callback._get_writer(tensorboard_callback._validation_run_name)
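
# Note that the line above reaches into private attributes of the TensorBoard
# callback (`_get_writer`, `_validation_run_name`), which may change between TF
# versions. A minimal sketch of a more stable alternative, using only the public
# `tf.summary.create_file_writer` API, is below. The `log_dir` path here is an
# assumption: it should match the directory given to the TensorBoard callback so
# the histograms show up under the same run.
import os

log_dir = "logs/my_run"  # hypothetical; reuse whatever you pass to TensorBoard
val_writer = tf.summary.create_file_writer(os.path.join(log_dir, "validation"))
get_writer_fn = lambda: val_writer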
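
# For completeness, a minimal sketch of how the pieces might be wired together.
# The names `train_dataset`, `val_dataset`, and `tensorboard_callback` are
# assumptions standing in for objects defined elsewhere in the training script;
# `model` and `attention_tensors` must already exist, as noted in the disclaimer
# at the top.
from collections import namedtuple

LoggingData = namedtuple("LoggingData", ["get_val_dataset_fn", "get_writer_fn", "freq"])

logging_data = LoggingData(
    get_val_dataset_fn=lambda: val_dataset,  # tf.data.Dataset yielding (inputs, targets)
    get_writer_fn=get_writer_fn,             # either of the definitions above
    freq=5,                                  # write histograms every 5 epochs
)

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=[tensorboard_callback, Callback(logging_data)],
)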