Skip to content

Instantly share code, notes, and snippets.

@madebyollin
Created July 30, 2023 17:47
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save madebyollin/e6e217a77db94e2a960ab6ccbd627db9 to your computer and use it in GitHub Desktop.
Helper for logging output activation-map statistics for a PyTorch module, using forward hooks
def summarize_tensor(x):
    """Return a one-line, ANSI-colored summary of a tensor: its shape plus
    min / mean / max values.

    Works for integer tensors too: `Tensor.mean()` only supports floating
    dtypes, so non-floating tensors are cast to float just for the mean
    (min/max are computed on the original tensor, preserving exact values).
    """
    # Keep float tensors untouched so the displayed mean is byte-identical
    # to the previous behavior; only cast when mean() would otherwise raise.
    mean = x.mean() if x.is_floating_point() else x.float().mean()
    return (
        f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m "
        f"(\033[31mmin {x.min().item():+.4f}\033[0m / "
        f"\033[32mmean {mean.item():+.4f}\033[0m / "
        f"\033[33mmax {x.max().item():+.4f}\033[0m)"
    )
class ModelActivationPrinter:
    """Context manager that prints summary statistics for the outputs of
    selected submodules during forward passes, using forward hooks.

    Usage:
        with ModelActivationPrinter(model, submodules_to_log):
            model(x)
    """

    def __init__(self, module, submodules_to_log):
        # Map each submodule's id() to its dotted name so hooks can label
        # their output. Loop variable renamed from `module` so it no longer
        # shadows the `module` parameter.
        self.id_to_name = {
            id(sub): str(name) for name, sub in module.named_modules()
        }
        self.submodules = submodules_to_log
        self.hooks = []

    def __enter__(self, *args, **kwargs):
        def log_activations(m, m_in, m_out):
            label = self.id_to_name.get(id(m), "(unnamed)") + " output"
            # Some modules return a tuple/list; summarize its first element.
            if isinstance(m_out, (tuple, list)):
                if not m_out:
                    return  # nothing to summarize for an empty output
                m_out = m_out[0]
                label += "[0]"
            print(label.ljust(48) + summarize_tensor(m_out))

        for m in self.submodules:
            self.hooks.append(m.register_forward_hook(log_activations))
        return self

    def __exit__(self, *args, **kwargs):
        for hook in self.hooks:
            hook.remove()
        # Reset so the manager can be re-entered without stale handles.
        self.hooks = []
if __name__ == "__main__":
    import torch

    # Tiny MLP used purely to demonstrate the activation printer:
    # log every submodule's output during one forward pass.
    demo_net = torch.nn.Sequential(
        torch.nn.Linear(1, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 1),
    )
    with ModelActivationPrinter(demo_net, demo_net):
        y = demo_net(torch.zeros(1, 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment