Skip to content

Instantly share code, notes, and snippets.

@jbcdnr
Last active February 19, 2021 09:58
Show Gist options
  • Save jbcdnr/50d416dec405d88f576ba497c418b04e to your computer and use it in GitHub Desktop.
Save jbcdnr/50d416dec405d88f576ba497c418b04e to your computer and use it in GitHub Desktop.
Record input and output of any PyTorch layer
"""
# Quick start
wget -O recorder.py https://gist.github.com/jbcdnr/50d416dec405d88f576ba497c418b04e/raw/
# Small example
from recorder import trace_layers
import torchvision
resnet = torchvision.models.resnet18()
with trace_layers(resnet, "layer1[0].conv1", "layer1[0].conv2") as recorded_values:
resnet(torch.randn(1, 3, 224, 224))
recorded_values["layer1[0].conv1"].output
"""
import ast
import re
from collections import OrderedDict, namedtuple
from contextlib import contextmanager

import torch.nn as nn
# One captured call: ``input`` is the ``(args, kwargs)`` pair the layer was
# called with and ``output`` is the value the layer returned.
Record = namedtuple("Record", ["input", "output"])


class RecorderLayer(nn.Module):
    """Wrap a layer and report each of its calls to a callback.

    Every call is forwarded unchanged to ``original_layer``. Once it has
    returned, a ``Record`` of the call is handed to ``record_callback``, and
    the wrapped layer's output is returned as-is.

    Args:
        original_layer: The layer being wrapped. It receives exactly the
            positional and keyword arguments the wrapper was called with.
        record_callback: Callable invoked with one ``Record`` per call,
            after ``original_layer`` has produced its output.
    """

    def __init__(self, original_layer, record_callback):
        super().__init__()
        self.original_layer = original_layer
        self.record_callback = record_callback

    def forward(self, *args, **kwargs):
        result = self.original_layer(*args, **kwargs)
        call_record = Record(input=(args, kwargs), output=result)
        self.record_callback(call_record)
        return result
# One step of a layer path. The parser always prefixes the path with ".", so
# every step (including the first) is either ``.name`` or ``[literal]``.
# Whitespace around steps is tolerated, as it was when paths went to eval().
_LAYER_PATH_STEP = re.compile(
    r"""\s*(?:
        \.\s*(?P<attr>\w+)                                # .name  (digits allowed)
      | \[\s*(?P<index>'[^']*'|"[^"]*"|[^\]'"]+?)\s*\]     # [literal]
    )\s*""",
    re.VERBOSE,
)


def _parse_layer_path(name):
    """Split a path like ``"layer1[0].conv1"`` into ``(is_attribute, key)`` steps.

    Raises:
        ValueError: If ``name`` is empty or is not made only of attribute
            steps and literal subscripts (e.g. it contains calls, slices or
            operators).
    """
    path = "." + name
    steps = []
    pos = 0
    while pos < len(path):
        match = _LAYER_PATH_STEP.match(path, pos)
        if match is None:
            raise ValueError(f"Invalid layer name {name!r}.")
        attribute = match.group("attr")
        if attribute is not None:
            steps.append((True, attribute))
        else:
            index_text = match.group("index")
            try:
                # literal_eval only accepts Python literals, so no code can run here.
                key = ast.literal_eval(index_text.strip())
            except (ValueError, TypeError, SyntaxError) as err:
                raise ValueError(
                    f"Invalid index [{index_text}] in layer name {name!r}."
                ) from err
            steps.append((False, key))
        pos = match.end()
    return steps


def _walk_layer_path(obj, steps):
    """Apply parsed path ``steps`` to ``obj`` and return the object reached."""
    for is_attribute, key in steps:
        obj = getattr(obj, key) if is_attribute else obj[key]
    return obj


def get_layer(model, name):
    """Return the sub-layer of ``model`` that ``name`` points to.

    ``name`` is an access path relative to ``model`` made of attribute steps
    (``layer1``, ``.conv1``) and subscripts with a literal key (``[0]``,
    ``[-1]``, ``['head']``), e.g. ``"layer1[0].conv1"``. Attribute steps may
    also be all digits (``"layer1.0.conv1"``); they are looked up with
    ``getattr``, which ``nn.Module`` resolves against its registered children
    (``nn.Sequential`` registers its children as ``"0"``, ``"1"``, ...).

    The path is parsed instead of being passed to ``eval``, so it cannot
    execute arbitrary code. Calls, slices and operators are rejected.

    Raises:
        ValueError: If ``name`` is not a valid path.
        Whatever ``getattr`` or indexing raises (e.g. AttributeError,
        IndexError, KeyError) if a step does not exist.
    """
    return _walk_layer_path(model, _parse_layer_path(name))


def set_layer(model, name, new_layer):
    """Replace the sub-layer of ``model`` at path ``name`` with ``new_layer``.

    ``name`` uses the same path syntax as ``get_layer``. The final step is
    applied with ``setattr`` (attribute) or item assignment (subscript).
    Returns ``None``.

    Raises:
        ValueError: If ``name`` is not a valid path.
    """
    *parent_steps, (is_attribute, key) = _parse_layer_path(name)
    parent = _walk_layer_path(model, parent_steps)
    if is_attribute:
        setattr(parent, key, new_layer)
    else:
        parent[key] = new_layer
@contextmanager
def trace_layers(model, *layer_names):
    """Temporarily wrap named sub-layers of ``model`` so their calls are recorded.

    Each name is resolved with ``get_layer``, and the layer it points to is
    swapped (via ``set_layer``) for a ``RecorderLayer`` around it. The
    original layers are put back when the ``with`` block exits, whether it
    exits normally or through an exception. This also holds if wrapping
    fails part-way, in which case only the layers already swapped are
    restored.

    Args:
        model: The model containing the layers to trace.
        *layer_names: Paths of the layers to trace, e.g. ``"layer1[0].conv1"``.
            A repeated name is traced once. A name may extend another one
            (e.g. ``"layer1[0]"`` and ``"layer1[0].conv1"``).

    Yields:
        An ``OrderedDict`` keyed by the layer names in the order given. Each
        value is ``None`` until that layer runs, and then the ``Record``
        produced for that call.

    Raises:
        AssertionError: If a traced layer runs more than once while traced.
    """
    recorded_values = OrderedDict((name, None) for name in layer_names)

    def record_callback(name):
        def callback(value):
            # Explicit raise rather than ``assert`` so the check still runs
            # under ``python -O``.
            if recorded_values[name] is not None:
                raise AssertionError(f"'{name}' already recorded.")
            recorded_values[name] = value

        return callback

    # (name, original layer) for each layer swapped so far, in swap order.
    # Cleanup restores exactly these, instead of re-resolving every name.
    swapped = []
    try:
        # Longest names first: a path that extends another one (e.g.
        # "layer1[0].conv1" extends "layer1[0]") is wrapped before its
        # ancestor, while the layers along the way are still the real ones.
        for layer_name in sorted(recorded_values, key=len, reverse=True):
            original_layer = get_layer(model, layer_name)
            set_layer(
                model,
                layer_name,
                RecorderLayer(original_layer, record_callback(layer_name)),
            )
            swapped.append((layer_name, original_layer))
        yield recorded_values
    finally:
        # Reverse order: ancestors are unwrapped before their descendants are
        # reached through them.
        for layer_name, original_layer in reversed(swapped):
            set_layer(model, layer_name, original_layer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment