Last active: April 1, 2020 20:50
-
-
Save alsrgv/b25f1d021c53ea6ae725b3372d90d62a to your computer and use it in GitHub Desktop.
Receptive fields for Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import graphviz | |
import numpy as np | |
import keras | |
def gather_layer_stats(layer_dict, layer, r, s, _visited=None):
    """Propagate receptive-field size and stride through the graph below ``layer``.

    Walks every path downstream of ``layer`` (following its outbound nodes)
    and records, for each layer that is tracked, a stats dict
    ``layer_dict[layer] = dict(lr=..., ls=..., r=..., s=...)`` where:

      lr -- the layer's own window (kernel/pool size), 1 for generic layers,
            or the cumulative receptive field after a Lambda layer
      ls -- the layer's own stride, 1 for generic layers, or the cumulative
            stride after a Lambda layer
      r  -- the largest cumulative receptive field over all paths reaching it
      s  -- the cumulative stride (all paths must agree)

    Layer handling:
      * layers with ``kernel_size`` or ``pool_size``: r += (window - 1) * s,
        then s *= stride;
      * ``keras.layers.Lambda``: treated as a uniform spatial rescaling; r and
        s are multiplied by input_shape[1] / output_shape[1];
      * BatchNormalization / Activation: not recorded; r and s pass through;
      * anything else: recorded with lr = ls = 1, r and s unchanged.

    Shapes are indexed at [1] and [2] as the two spatial dimensions
    (presumably channels-last layout -- verify for your model).

    Args:
        layer_dict: dict mapping layer -> stats dict; updated in place.
        layer: the Keras layer to process.
        r: cumulative receptive field size before ``layer``.
        s: cumulative stride before ``layer``.
        _visited: internal memo of (layer, r, s) states already expanded.

    Raises:
        AssertionError: (via ``assert``) on non-square windows/strides,
            non-uniform Lambda rescaling, or when two paths reach a layer
            with different cumulative strides.
    """
    if _visited is None:
        _visited = set()
    state = (layer, r, s)
    if state in _visited:
        # Re-expanding an identical state would only repeat the same
        # max()/assert updates below.  Without this check the walk is
        # exponential in the number of branch/merge points (residual blocks).
        return
    _visited.add(state)

    lr, ls = None, None
    if hasattr(layer, 'kernel_size') or hasattr(layer, 'pool_size'):
        # Convolution or pooling window: each extra window tap spans one
        # cumulative stride of the input.
        window = (layer.kernel_size if hasattr(layer, 'kernel_size')
                  else layer.pool_size)
        assert window[0] == window[1]
        assert layer.strides[0] == layer.strides[1]
        lr = window[0]
        ls = layer.strides[0]
        r += (lr - 1) * s
        s *= ls
    elif isinstance(layer, keras.layers.Lambda):
        # Rescaling factor; > 1 means the Lambda downsamples spatially.
        ratio = layer.input_shape[1] / layer.output_shape[1]
        assert ratio == layer.input_shape[2] / layer.output_shape[2]
        r *= ratio
        s *= ratio
        lr, ls = r, s
    elif not isinstance(layer, (keras.layers.BatchNormalization,
                                keras.layers.Activation)):
        lr, ls = 1, 1

    if lr and ls:
        if layer in layer_dict:
            # Several paths can reach the same layer (e.g. residual branches):
            # keep the largest receptive field; strides must agree.
            layer_dict[layer]['r'] = max(layer_dict[layer]['r'], r)
            assert layer_dict[layer]['s'] == s
        else:
            layer_dict[layer] = dict(lr=lr, ls=ls, r=r, s=s)

    # Newer Keras releases (reported for 2.1.3) renamed ``outbound_nodes``
    # to ``_outbound_nodes``; support both.
    nodes = getattr(layer, '_outbound_nodes', None)
    if nodes is None:
        nodes = layer.outbound_nodes
    for node in nodes or ():
        gather_layer_stats(layer_dict, node.outbound_layer, r, s, _visited)
def remove_batch_size(shape):
    """Drop the leading (batch) dimension from a shape or a list of shapes.

    A ``list`` is treated as a collection of shapes and handled element-wise
    (recursively); any other value is treated as one shape and returned as a
    tuple without its first entry.
    """
    if type(shape) is not list:
        return tuple(shape[1:])
    return list(map(remove_batch_size, shape))
def draw_nodes(g, layer_dict):
    """Add one graphviz record node per layer in ``layer_dict``.

    Each node is keyed by the layer's name and labelled with the fields
    ``name | lr / ls | r / s | input_shape -> output_shape | output size``,
    where shapes have the batch dimension stripped and the output size is the
    product of the (batch-less) output shape's dimensions.

    Args:
        g: object exposing ``node(name, label)`` (a graphviz Digraph here).
        layer_dict: dict mapping layer -> dict with keys 'lr', 'ls', 'r', 's'.
    """
    for layer, kv in layer_dict.items():
        input_shape = remove_batch_size(layer.input_shape)
        output_shape = remove_batch_size(layer.output_shape)
        # Number of output elements per example; reuse the stripped shape.
        rs = np.prod(output_shape)
        # '\\>' yields a literal backslash before '>': '>' is reserved in
        # graphviz record labels, so it must be escaped there.
        label = '%s | %s / %s | %s / %s | %s -\\> %s | %s' % (
            layer.name, kv['lr'], kv['ls'], kv['r'], kv['s'],
            input_shape, output_shape, rs)
        g.node(layer.name, label)
def draw_edges(g, layer_dict, layer, last_graph_layer=None, drawn=None,
               _visited=None):
    """Add graphviz edges between layers in ``layer_dict``, walking down from ``layer``.

    Layers that are absent from ``layer_dict`` are skipped over: an edge goes
    from the nearest preceding layer that is in ``layer_dict`` directly to the
    next one found.  Each (tail, head) edge is emitted at most once, in
    depth-first order.

    Args:
        g: object exposing ``edge(tail_name, head_name)`` (a graphviz Digraph).
        layer_dict: layers to draw; only membership is used here.
        layer: the Keras layer to start from.
        last_graph_layer: nearest drawn ancestor of ``layer``, or None.
        drawn: dict keyed by (tail_layer, head_layer) of edges already emitted.
        _visited: internal memo of (last_graph_layer, layer) pairs already
            expanded; a repeat visit could not emit any new edge.
    """
    if drawn is None:
        drawn = {}
    if _visited is None:
        _visited = set()
    key = (last_graph_layer, layer)
    if key in _visited:
        # Without this, the walk is exponential in the number of
        # branch/merge points (e.g. residual blocks).
        return
    _visited.add(key)

    if layer in layer_dict:
        if last_graph_layer is not None and key not in drawn:
            g.edge(last_graph_layer.name, layer.name)
            drawn[key] = 1
        last_graph_layer = layer

    # Newer Keras releases (reported for 2.1.3) renamed ``outbound_nodes``
    # to ``_outbound_nodes``; support both.
    nodes = getattr(layer, '_outbound_nodes', None)
    if nodes is None:
        nodes = layer.outbound_nodes
    for node in nodes or ():
        draw_edges(g, layer_dict, node.outbound_layer,
                   last_graph_layer, drawn, _visited)
def print_receptive_fields(name, model, reference_shape):
    """Render a receptive-field graph of ``model`` with graphviz.

    Despite the name, nothing is printed: a PNG-format graphviz Digraph whose
    ``filename`` is ``name`` is built and ``render()``-ed.  Values are scaled
    so that they are expressed in units of a square ``reference_shape`` input:
    the scale factor reference_shape[0] / (model input height) seeds both the
    receptive field and the stride at the first input layer.

    Args:
        name: graphviz output filename.
        model: Keras model; only ``model.input_layers[0]`` is used as the root.
        reference_shape: (height, width) pair; must be square (``assert``).
    """
    input_layer = model.input_layers[0]
    assert reference_shape[0] == reference_shape[1]
    scale = reference_shape[0] / input_layer.input_shape[1]

    stats = {input_layer: dict(lr=scale, ls=scale, r=scale, s=scale)}
    gather_layer_stats(stats, input_layer, scale, scale)

    graph = graphviz.Digraph(filename=name, format='png',
                             node_attr={'shape': 'record'})
    graph.node_attr.update(fillcolor='lightblue2', style='filled')
    draw_nodes(graph, stats)
    draw_edges(graph, stats, input_layer)
    graph.render()
# Build ResNet50 with randomly initialised weights (weights=None) and render
# its receptive-field graph relative to a 224x224 reference input.
model = keras.applications.ResNet50(weights=None)
print_receptive_fields(name='resnet50.gv', model=model,
                       reference_shape=(224, 224))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
What Keras version is used?
On Keras 2.1.3 I get:
AttributeError: 'InputLayer' object has no attribute 'outbound_nodes'
Update: the attribute is now named `_outbound_nodes`; see
https://stackoverflow.com/questions/48485937/attributeerror-inputlayer-object-has-no-attribute-inbound-nodes