Skip to content

Instantly share code, notes, and snippets.

@Cospel
Created June 4, 2020 13:58
Show Gist options
  • Save Cospel/2cf1b3d6323b9763002bc9e7e23a16c9 to your computer and use it in GitHub Desktop.
Save Cospel/2cf1b3d6323b9763002bc9e7e23a16c9 to your computer and use it in GitHub Desktop.
BatchRenormCallback.py
class BatchRenormCallback(tf.keras.callbacks.Callback):
    """Update ``renorm_clipping`` on every BatchNormalization layer on a schedule.

    At the start of each epoch listed in ``bvalues``, the scheduled clipping
    dict is assigned to every ``BatchNormalization`` layer in the model
    (including layers inside nested sub-models at any depth). The applied
    bounds are also written as TensorBoard scalars.

    Args:
        log_dir: Directory for the ``tf.summary`` file writer.
        bvalues: Mapping from epoch index (as passed to ``on_epoch_begin``)
            to a ``renorm_clipping`` dict with any of the keys ``"rmax"``,
            ``"rmin"``, ``"dmax"``. Epochs without an entry leave the
            previously applied values in place.

    NOTE(review): Keras reads ``layer.renorm_clipping`` in Python while the
    layer's ``call`` is traced. When ``fit`` runs a compiled ``tf.function``
    (the default), reassigning the attribute after the first trace may not
    change the graph. Confirm the new bounds actually take effect (e.g. by
    comparing against ``run_eagerly=True``) before relying on the schedule.
    """

    # Clipping bounds accepted in a Keras ``renorm_clipping`` dict.
    _RENORM_KEYS = ("rmax", "rmin", "dmax")

    def __init__(self, log_dir, bvalues):
        super().__init__()
        self.writer = tf.summary.create_file_writer(log_dir)
        self.bvalues = bvalues

    def on_epoch_begin(self, epoch, logs=None):
        # Only epochs that appear in the schedule trigger an update.
        if epoch in self.bvalues:
            self.change_renorm(epoch)

    def change_renorm(self, epoch):
        """Log the clipping values scheduled for ``epoch`` and apply them to all BN layers."""
        print(f"Changing renorm clipping values epoch {epoch}", self.bvalues[epoch])
        renorm_clipping = self.bvalues[epoch]
        self._log_clipping(renorm_clipping, epoch)
        for layer in self._iter_batchnorm_layers(self.model):
            layer.renorm_clipping = renorm_clipping

    def _log_clipping(self, renorm_clipping, epoch):
        """Write each clipping bound present in ``renorm_clipping`` as a scalar summary."""
        with self.writer.as_default():
            for key in self._RENORM_KEYS:
                value = renorm_clipping.get(key)
                # A renorm_clipping dict may legally omit any bound; skip
                # missing/None entries instead of raising KeyError.
                if value is not None:
                    tf.summary.scalar(f"renorm_{key}", value, epoch)
        # Push the scalars to disk now rather than waiting for the writer's
        # periodic flush, so TensorBoard reflects the change promptly.
        self.writer.flush()

    @staticmethod
    def _iter_batchnorm_layers(model):
        """Yield every BatchNormalization layer in ``model``, recursing into nested models."""
        for layer in model.layers:
            if isinstance(layer, BatchNormalization):
                yield layer
            elif isinstance(layer, tf.keras.Model):
                # Recurse so BN layers in sub-models at any nesting depth
                # (e.g. a backbone wrapped inside another model) are reached.
                yield from BatchRenormCallback._iter_batchnorm_layers(layer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment