# Receptive fields for Keras
#
# Computes the receptive-field size and cumulative stride of every layer in a
# Keras model and builds a graphviz record diagram of them.
import graphviz
import numpy as np
import keras
def gather_layer_stats(layer_dict, layer, r, s):
    """Propagate receptive-field size and stride from ``layer`` downstream.

    Walks the layer graph starting at ``layer``. Every layer that gets an
    entry is recorded as ``layer_dict[layer] = dict(lr=..., ls=..., r=..., s=...)``:

    - ``lr`` / ``ls`` are the layer's own kernel (or pool) size and stride.
    - ``r`` / ``s`` are the cumulative receptive field and stride at its output.

    Which layers get an entry:

    - BatchNormalization and Activation layers get none.
    - Layers that do not change spatial geometry get ``lr = ls = 1``.

    When a layer is reachable along several paths (e.g. a residual merge),
    the largest ``r`` is kept and the strides of all paths must agree.

    Args:
        layer_dict: dict keyed by layer object; updated in place. Existing
            entries are merged into rather than replaced.
        layer: layer to start the walk from.
        r: receptive field at the input of ``layer``.
        s: cumulative stride at the input of ``layer``.
    """
    # Largest incoming r already expanded for each (layer, incoming stride).
    # For a fixed stride the outgoing r grows monotonically with the incoming
    # r, so a revisit that does not bring a larger r cannot change any
    # recorded value. Skipping it avoids enumerating every path of the DAG,
    # which is exponential in the number of residual merges.
    expanded = {}
    stack = [(layer, r, s)]
    while stack:
        current, r, s = stack.pop()
        best = expanded.get((current, s))
        if best is not None and r <= best:
            continue
        expanded[(current, s)] = r

        lr, ls = None, None
        if hasattr(current, 'kernel_size'):
            assert current.kernel_size[0] == current.kernel_size[1]
            assert current.strides[0] == current.strides[1]
            # A k-tap kernel with dilation d spans (k - 1) * d + 1 inputs.
            dilation = getattr(current, 'dilation_rate', (1, 1))
            assert dilation[0] == dilation[1]
            lr = current.kernel_size[0]
            ls = current.strides[0]
            r += (lr - 1) * dilation[0] * s
            s *= ls
        elif hasattr(current, 'pool_size'):
            assert current.pool_size[0] == current.pool_size[1]
            assert current.strides[0] == current.strides[1]
            lr = current.pool_size[0]
            ls = current.strides[0]
            r += (lr - 1) * s
            s *= ls
        elif isinstance(current, keras.layers.Lambda):
            # Treat a Lambda as a pure resize: scale r and s by the ratio
            # of input to output spatial size (must match on both axes).
            ratio = current.input_shape[1] / current.output_shape[1]
            assert ratio == current.input_shape[2] / current.output_shape[2]
            r *= ratio
            s *= ratio
            lr = r
            ls = s
        elif not isinstance(current, (keras.layers.BatchNormalization,
                                      keras.layers.Activation)):
            # NOTE(review): the tuple above lists pass-through layers that
            # are kept out of the graph; extend it (e.g. Dropout) as needed.
            lr, ls = 1, 1

        if lr and ls:
            entry = layer_dict.get(current)
            if entry is None:
                layer_dict[current] = dict(lr=lr, ls=ls, r=r, s=s)
            else:
                # Reached again along another path: keep the largest
                # receptive field; strides must agree across paths.
                entry['r'] = max(entry['r'], r)
                assert entry['s'] == s

        # Keras renamed ``outbound_nodes`` to ``_outbound_nodes`` in some
        # releases; accept either.
        nodes = getattr(current, 'outbound_nodes', None)
        if nodes is None:
            nodes = getattr(current, '_outbound_nodes', None)
        # Push in reverse so successors are visited in their original order.
        for node in reversed(nodes or []):
            stack.append((node.outbound_layer, r, s))


def remove_batch_size(shape):
    """Drop the leading (batch) dimension from a Keras shape.

    Args:
        shape: a shape tuple such as ``(None, 224, 224, 3)``, or a list of
            such shapes (layers with several inputs/outputs). Lists are
            processed element-wise, recursively.

    Returns:
        A tuple without its first element, or a list of such tuples.
    """
    if isinstance(shape, list):
        return [remove_batch_size(s) for s in shape]
    return tuple(shape[1:])


def draw_nodes(g, layer_dict):
    """Add one graphviz record node per layer in ``layer_dict``.

    Each node is keyed by ``layer.name``. Its record label has the fields::

        name | lr / ls | r / s | input shape -> output shape | layer type

    Here ``lr``/``ls`` are the layer's own kernel (or pool) size and stride,
    and ``r``/``s`` are the cumulative receptive field and stride, read from
    ``layer_dict[layer]``. Batch dimensions are dropped from the shapes.

    Args:
        g: graph providing ``node(name, label)``, e.g. ``graphviz.Digraph``.
        layer_dict: mapping from layer object to a dict with keys
            ``'lr'``, ``'ls'``, ``'r'`` and ``'s'``.
    """
    for layer, kv in layer_dict.items():
        input_shape = remove_batch_size(layer.input_shape)
        output_shape = remove_batch_size(layer.output_shape)
        layer_type = type(layer).__name__
        # '|' separates record fields. A literal '>' must be written '\>'
        # inside a graphviz record label; the backslash is doubled here so
        # Python does not treat '\>' as an (invalid) escape sequence.
        label = '%s | %s / %s | %s / %s | %s -\\> %s | %s' % (
            layer.name, kv['lr'], kv['ls'], kv['r'], kv['s'],
            input_shape, output_shape, layer_type)
        g.node(layer.name, label)


def draw_edges(g, layer_dict, layer, last_graph_layer=None, drawn=None):
    """Connect consecutive graph layers reachable from ``layer`` with edges.

    Layers that are not keys of ``layer_dict`` are walked through
    transparently: their nearest graph predecessor is connected directly to
    their next graph successor. Each (tail, head) pair is drawn once.

    Args:
        g: graph providing ``edge(tail_name, head_name)``, e.g.
            ``graphviz.Digraph``; edges are keyed by ``layer.name``.
        layer_dict: mapping whose keys are the layers present as nodes.
        layer: layer to start the walk from.
        last_graph_layer: most recent graph layer on the path leading to
            ``layer``, or ``None`` at the start of the walk.
        drawn: optional dict recording ``(tail, head)`` layer pairs already
            drawn; updated in place.
    """
    if drawn is None:
        drawn = {}
    # The walk below a layer depends only on (layer, last graph layer), so
    # each such state is expanded once instead of once per path through the
    # DAG (which is exponential for residual networks).
    expanded = set()
    stack = [(layer, last_graph_layer)]
    while stack:
        current, last = stack.pop()
        if current in layer_dict:
            if last is not None and (last, current) not in drawn:
                drawn[(last, current)] = 1
                g.edge(last.name, current.name)
            last = current
        if (current, last) in expanded:
            continue
        expanded.add((current, last))
        # Keras renamed ``outbound_nodes`` to ``_outbound_nodes`` in some
        # releases; accept either.
        nodes = getattr(current, 'outbound_nodes', None)
        if nodes is None:
            nodes = getattr(current, '_outbound_nodes', None)
        # Push in reverse so successors are visited in their original order.
        for node in reversed(nodes or []):
            stack.append((node.outbound_layer, last))


def print_receptive_fields(name, model, reference_shape):
    """Compute per-layer receptive fields of ``model`` and render them.

    Starts from the model's first input layer. Per-layer stats are collected
    with ``gather_layer_stats`` and drawn as a graphviz record diagram.

    Args:
        name: output filename for the graphviz source. ``graphviz`` writes
            the rendered PNG next to it.
        model: Keras model; its first input layer is the root of the walk.
        reference_shape: (height, width) the receptive fields are measured
            against; must be square. The initial receptive field and stride
            are ``reference_shape[0] / input height``, so a value other than
            the model input size scales all reported numbers.

    Returns:
        The rendered ``graphviz.Digraph``.
    """
    # Some Keras releases expose the input layers only as ``_input_layers``.
    input_layers = getattr(model, 'input_layers', None) or model._input_layers
    layer = input_layers[0]
    assert reference_shape[0] == reference_shape[1]
    r = s = reference_shape[0] / layer.input_shape[1]
    layer_dict = {layer: dict(lr=r, ls=s, r=r, s=s)}
    gather_layer_stats(layer_dict, layer, r, s)

    g = graphviz.Digraph(filename=name, format='png',
                         node_attr={'shape': 'record'})
    g.node_attr.update(fillcolor='lightblue2', style='filled')
    draw_nodes(g, layer_dict)
    draw_edges(g, layer_dict, layer)
    # Without rendering, the graph built above is simply discarded.
    g.render()
    return g


# Build ResNet50 and render its receptive fields only when run as a script,
# not when this module is imported.
if __name__ == '__main__':
    model = keras.applications.ResNet50(weights=None)
    print_receptive_fields('resnet50.gv', model, reference_shape=(224, 224))
# Discussion from the original gist comments (Jun 23, 2018):
#
# mrgloom: Which Keras version is this written for? On Keras 2.1.3 it fails
#   with: AttributeError: 'InputLayer' object has no attribute 'outbound_nodes'
#
# Reply: the attribute is now called `_outbound_nodes`.