Skip to content

Instantly share code, notes, and snippets.

@pangyuteng
Forked from N-McA/keras_spatial_bias.py
Last active November 13, 2019 19:09
Show Gist options
  • Save pangyuteng/8f4f7c09b490e1baaef852d07105db77 to your computer and use it in GitHub Desktop.
Save pangyuteng/8f4f7c09b490e1baaef852d07105db77 to your computer and use it in GitHub Desktop.
Concatenates the (x, y) coordinate normalised to 0-1 to each spatial location in the image. Allows a network to learn spatial bias. Has been explored in at least one paper, "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution" https://arxiv.org/abs/1807.03247
class ConcatSpatialCoordinate(Layer):
    """Concatenates the (x, y) coordinate normalised to 0-1 to each spatial location in the image.

    Allows a network to learn spatial bias. Has been explored in at least one paper,
    "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution"
    https://arxiv.org/abs/1807.03247

    Improves performance where spatial bias is appropriate.
    Works with dynamic shapes.

    Only the 'channels_last' image data format is supported; the input is
    treated as a rank-4 tensor and two channels are appended on the last axis.

    Author/Source: N-McA, https://gist.github.com/N-McA/9bd3a81d3062340e4affaaaaad332107
    """

    def __init__(self, **kwargs):
        """Create the layer, forwarding ``kwargs`` to the base ``Layer``.

        Raises:
            ValueError: if the backend image data format is not
                'channels_last'.
        """
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        if kb.image_data_format() != 'channels_last':
            raise ValueError((
                "Only compatible with"
                " kb.image_data_format() == 'channels_last'"))
        super(ConcatSpatialCoordinate, self).__init__(**kwargs)

    def build(self, input_shape):
        """Defer to the base ``build``; this layer defines no weights of its own."""
        super(ConcatSpatialCoordinate, self).build(input_shape)

    def call(self, x):
        """Return ``x`` with the coordinate grid concatenated on the last axis."""
        # Use the symbolic shape so batch and spatial sizes may be unknown
        # at graph-construction time.
        dynamic_input_shape = kb.shape(x)
        batch_size = dynamic_input_shape[0]
        # The names follow the original code: axis -3 is passed as `width`
        # and axis -2 as `height`.
        # NOTE(review): confirm this matches what _kb_grid_coords expects.
        w = dynamic_input_shape[-3]
        h = dynamic_input_shape[-2]
        # _kb_grid_coords is defined elsewhere in the file.
        # NOTE(review): presumably returns a (w, h, 2) grid with values in
        # 0-1 and the same dtype as x - confirm against its definition.
        bias = _kb_grid_coords(width=w, height=h)
        # Add a leading batch axis, then repeat the grid for every sample.
        bias = kb.expand_dims(bias, 0)
        bias = kb.tile(bias, [batch_size, 1, 1, 1])
        return kb.concatenate([x, bias], axis=-1)

    def compute_output_shape(self, input_shape):
        """Return the static output shape: the input shape with 2 extra channels.

        An unknown (``None``) channel dimension stays unknown instead of
        raising ``TypeError`` on ``None + 2``.
        """
        batch_size, w, h, channels = input_shape
        out_channels = None if channels is None else channels + 2
        return (batch_size, w, h, out_channels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment