Skip to content

Instantly share code, notes, and snippets.

@kretes
Last active September 6, 2021 06:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kretes/f1c4261d0152fb5ab9fc82c4ff7f1c85 to your computer and use it in GitHub Desktop.
Save kretes/f1c4261d0152fb5ab9fc82c4ff7f1c85 to your computer and use it in GitHub Desktop.
TensorFlow callback that prints some basic statistics about a model's weights
class ModelWeightsPrinter(Callback):
    """Keras callback that prints simple statistics about a model's weights.

    A full report (overall weight sum and histogram, then per-array sum,
    shape and histogram) is printed when training begins and after every
    epoch; only the total sum of all weights is printed after every
    training batch.

    Args:
        model: Model whose weights are inspected. May be ``None``; Keras
            attaches the model being trained via ``set_model`` during
            ``fit``, which also replaces any model given here.
        bins: ``bins`` argument forwarded to ``np.histogram``. Defaults to
            ``np.linspace(-1, 1, 5)``, i.e. edges [-1, -0.5, 0, 0.5, 1];
            weights outside the outermost edges are not counted.
    """

    def __init__(self, model=None, bins=None) -> None:
        super().__init__()
        if model is not None:
            # set_model works on tf.keras 2.x and on Keras 3, where
            # ``Callback.model`` is a read-only property and assigning
            # ``self.model`` directly raises AttributeError.
            self.set_model(model)
        self.bins = np.linspace(-1, 1, 5) if bins is None else bins

    def print_stats(self, hist):
        """Print weight statistics for ``self.model``.

        Args:
            hist: If true, print the sum of all weights with an overall
                histogram, followed by the index, sum, shape and histogram
                of each weight array. If false, print only the total sum.
        """
        # get_weights() copies every weight into a NumPy array; fetch once.
        weights = self.model.get_weights()
        if hist:
            if weights:
                allw = np.concatenate([w.ravel() for w in weights])
            else:
                # np.hstack / np.concatenate reject an empty list.
                allw = np.zeros(0)
            counts, edges = np.histogram(allw, bins=self.bins)
            print("weights_histogram")
            print(allw.sum(), counts, edges)
            for i, w in enumerate(weights):
                print("--------------------")
                print(i, w.sum(), w.shape, np.histogram(w, bins=self.bins))
        else:
            sum_of_w = sum(w.sum() for w in weights)
            # NOTE(review): the blank line presumably moves output off the
            # line Keras's progress bar redraws in place — confirm.
            print()
            print(f"weights {sum_of_w}")

    def on_train_begin(self, logs=None):
        """Print the full weight report before training starts."""
        self.print_stats(hist=True)

    def on_epoch_end(self, epoch, logs=None):
        """Print the full weight report after each epoch."""
        self.print_stats(hist=True)

    def on_train_batch_end(self, batch, logs=None):
        """Print only the total weight sum after each training batch."""
        self.print_stats(hist=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment