Ugly code to monitor layer activations when training tf keras models with `model.fit` or similar.
# Disclaimer: this is just an ugly solution to monitoring layer activations
# when training a model in tf keras.
# This code is executed right after the model is created.
import tensorflow as tf


class Callback(tf.keras.callbacks.Callback):
    def __init__(self, logging_data):
        super().__init__()
        # `logging_data` is a namedtuple of the following type:
        # LoggingData = namedtuple("LoggingData", ['get_val_dataset_fn', 'get_writer_fn', 'freq'])
        ### get_val_dataset_fn - a function to get the validation data
        ### get_writer_fn - a function to get the tf writer object
        ### freq - integer, write interval (in epochs)
        # I'm using these `get_val_dataset_fn` and `get_writer_fn` functions instead of the dataset and
        # the writer themselves, because they are not defined at model creation time in my code.
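        # (For illustration, such a namedtuple could be built like this, assuming a
        # validation dataset `val_dataset` and the `get_writer_fn` defined at the
        # bottom of this file exist by the time the callback fires; the names and
        # the interval here are hypothetical:
        #     logging_data = LoggingData(lambda: val_dataset, get_writer_fn, freq=5)
        # )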
        self.logging_data = logging_data

        # I create a "mirror" model whose output will be the activations I need to monitor.
        self.attention_model = tf.keras.Model(
            inputs=model.inputs,
            outputs=attention_tensors  # these are the tensors I need to monitor, defined
                                       # earlier while creating the main model
        )

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.logging_data.freq == 0:
            writer = self.logging_data.get_writer_fn()
            with writer.as_default():
                # So when I want to monitor the values I just call the "mirror" model.
                # (Note that you may want to avoid using `predict` if, for example,
                # you want to monitor train behaviour.)
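                # (A sketch of such an alternative, using only the names already in
                # scope here: run the mirror model directly in train mode on one
                # validation batch --
                #     for batch in self.logging_data.get_val_dataset_fn().map(lambda x, y: x).take(1):
                #         attention_val = [t.numpy() for t in self.attention_model(batch, training=True)]
                # )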
                attention_val = self.attention_model.predict(
                    # `get_val_dataset_fn()` returns a tf Dataset with both inputs and targets, so I map
                    # a lambda to get only the inputs.
                    self.logging_data.get_val_dataset_fn().map(lambda x, y: x)
                )
                for tensor, value in zip(attention_tensors, attention_val):
                    # in my case the tensors I monitor have some nontrivial shape, so I flatten them with `.ravel()`
                    tf.summary.histogram(name=tensor.name, data=value.ravel(), step=epoch)
                writer.flush()


# Later, once the TensorBoard callback is created, I define `get_writer_fn` as follows
# (note this pokes at private attributes of the TensorBoard callback, so it depends on the TF version):
get_writer_fn = lambda: tensorboard_callback._get_writer(tensorboard_callback._validation_run_name)
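
# ----------------------------------------------------------------------
# A minimal end-to-end sketch of how the pieces above might be wired up.
# The toy model, shapes, datasets and log directory are hypothetical
# placeholders, not part of the original setup; only `Callback`,
# `LoggingData` and `get_writer_fn` play the roles described above.
# (As noted, `_get_writer`/`_validation_run_name` are private and only
# exist in some TF 2.x versions.)
from collections import namedtuple

LoggingData = namedtuple("LoggingData", ['get_val_dataset_fn', 'get_writer_fn', 'freq'])

# Toy functional model; `attention_tensors` collects the intermediate
# activations to monitor (two of them, so `predict` returns a list).
inputs = tf.keras.Input(shape=(16,))
hidden1 = tf.keras.layers.Dense(32, activation="relu")(inputs)
hidden2 = tf.keras.layers.Dense(32, activation="relu")(hidden1)
attention_tensors = [hidden1, hidden2]
outputs = tf.keras.layers.Dense(1)(hidden2)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mse")

# <-- the Callback class above is defined at this point -->

# Toy (inputs, targets) datasets.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((256, 16)), tf.random.normal((256, 1)))).batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((64, 16)), tf.random.normal((64, 1)))).batch(32)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")
get_writer_fn = lambda: tensorboard_callback._get_writer(tensorboard_callback._validation_run_name)

logging_data = LoggingData(
    get_val_dataset_fn=lambda: val_dataset,
    get_writer_fn=get_writer_fn,
    freq=1,
)

model.fit(
    train_dataset,
    epochs=3,
    validation_data=val_dataset,
    callbacks=[tensorboard_callback, Callback(logging_data)],
)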