Skip to content

Instantly share code, notes, and snippets.

@N-McA
Last active November 13, 2019 19:15
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save N-McA/9bd3a81d3062340e4affaaaaad332107 to your computer and use it in GitHub Desktop.
Save N-McA/9bd3a81d3062340e4affaaaaad332107 to your computer and use it in GitHub Desktop.
Concatenates the (x, y) coordinate normalised to 0-1 to each spatial location in the image. Allows a network to learn spatial bias. Has been explored in at least one paper, "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution" https://arxiv.org/abs/1807.03247
import keras.backend as kb
from keras.layers import Layer
def _kb_linspace(num):
    """Return `num` evenly spaced values from 0 to 1 as a 1-D backend tensor.

    # Arguments
        num: number of samples; a Python int or a scalar backend tensor
            (it is cast to `kb.floatx()` before use).

    # Returns
        A tensor of dtype `kb.floatx()` holding `i / (num - 1)` for
        `i` in `0 .. num - 1`. When `num` is 1 the single value is 0.
    """
    num = kb.cast(num, kb.floatx())
    # With a single sample `num - 1` is zero and 0 / 0 would produce NaN;
    # clamping the denominator to 1 leaves every other case untouched.
    denominator = kb.maximum(num - 1.0, 1.0)
    return kb.arange(0, num, dtype=kb.floatx()) / denominator
def _kb_grid_coords(width, height):
    """Build a `[width, height, 2]` tensor of normalised grid coordinates.

    Entry `[i, j]` holds the pair
    `(_kb_linspace(width)[i], _kb_linspace(height)[j])`.

    # Arguments
        width: size of the first grid axis (int or scalar backend tensor).
        height: size of the second grid axis (int or scalar backend tensor).

    # Returns
        A backend tensor reshaped to `[width, height, 2]`.
    """
    # First coordinate: a column of `width` values repeated `height` times
    # along a new trailing axis, then flattened so each value fills one row.
    first_axis_column = kb.expand_dims(_kb_linspace(num=width), -1)
    first_axis_flat = kb.reshape(
        kb.tile(first_axis_column, [1, height]),
        [-1],
    )

    # Second coordinate: the `height` values repeated once per row.
    second_axis_flat = kb.tile(_kb_linspace(num=height), [width])

    flat_pairs = kb.stack([first_axis_flat, second_axis_flat], axis=-1)
    return kb.reshape(flat_pairs, [width, height, 2])
class ConcatSpatialCoordinate(Layer):
    """Concatenates the (x, y) coordinate normalised to 0-1 to each spatial location in the image.

    Allows a network to learn spatial bias. Has been explored in at least one paper,
    "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution"
    https://arxiv.org/abs/1807.03247

    Improves performance where spatial bias is appropriate.
    Works with dynamic shapes.

    # Input shape
        4D tensor `(batch_size, w, h, channels)` (`channels_last` only).

    # Output shape
        4D tensor `(batch_size, w, h, channels + 2)`; the two coordinate
        channels are appended after the input channels.

    # Raises
        ValueError: on construction, if `kb.image_data_format()` is not
            `'channels_last'`.

    # Example
    ```python
    x_input = Input([None, None, 1])
    x = ConcatSpatialCoordinate()(x_input)
    model = Model(x_input, x)
    output = model.predict(np.zeros([1, 3, 3, 1]))
    spatial_features = output[0, :, :, -2:]
    assert np.all(spatial_features[0, 0] == [0, 0])
    assert np.all(spatial_features[-1, -1] == [1, 1])
    assert np.all(spatial_features[0, -1] == [0, 1])
    # Because this example was 3x3, coordinates are [0, 0.5, 1], so
    assert np.all(spatial_features[1, 1] == [0.5, 0.5])
    ```
    """

    def __init__(self, **kwargs):
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        if kb.image_data_format() != 'channels_last':
            raise ValueError((
                "Only compatible with"
                " kb.image_data_format() == 'channels_last'"))
        super(ConcatSpatialCoordinate, self).__init__(**kwargs)

    def build(self, input_shape):
        """No weights to create; defers to the base implementation."""
        super(ConcatSpatialCoordinate, self).build(input_shape)

    def call(self, x):
        """Append the two coordinate channels to every sample in `x`."""
        # Read sizes from the runtime shape so that `None` spatial or batch
        # dimensions in the static shape are supported.
        dynamic_input_shape = kb.shape(x)
        batch_size = dynamic_input_shape[0]
        w = dynamic_input_shape[-3]
        h = dynamic_input_shape[-2]
        bias = _kb_grid_coords(width=w, height=h)
        # Concatenation does not broadcast: the grid must be repeated along
        # the batch axis, otherwise any batch size other than 1 fails.
        bias = kb.tile(kb.expand_dims(bias, 0), [batch_size, 1, 1, 1])
        return kb.concatenate([x, bias], axis=-1)

    def compute_output_shape(self, input_shape):
        """Return the input shape with two extra channels on the last axis."""
        batch_size, w, h, channels = input_shape
        # An unknown channel count stays unknown instead of raising on None + 2.
        out_channels = None if channels is None else channels + 2
        return (batch_size, w, h, out_channels)
def test_ConcatSpatialCoordinate():
    """Smoke-test the layer on a single all-zero 3x3 one-channel image."""
    import numpy as np
    from keras.layers import Input
    from keras.models import Model

    inputs = Input([None, None, 1])
    model = Model(inputs, ConcatSpatialCoordinate()(inputs))
    prediction = model.predict(np.zeros([1, 3, 3, 1]))
    # The coordinate channels are the last two channels of the output.
    coords = prediction[0, :, :, -2:]

    expected_by_position = [
        # Corner values hold for any grid size.
        ((0, 0), [0, 0]),
        ((-1, -1), [1, 1]),
        ((0, -1), [0, 1]),
        # A 3x3 grid has coordinates [0, 0.5, 1], so the centre is (0.5, 0.5).
        ((1, 1), [0.5, 0.5]),
    ]
    for position, expected in expected_by_position:
        assert np.all(coords[position] == expected)
# Running the file directly executes the smoke test; importing it does not.
if __name__ == "__main__":
    test_ConcatSpatialCoordinate()
@pangyuteng
Copy link

Thanks for sharing!

For visitors who stop by this page: the fork below contains a minor update to the `call` method so that the coordinate channels are tiled across the batch dimension before being concatenated to the original input. For example, if the input is 16x128x128x1, the output from ConcatSpatialCoordinate is 16x128x128x3.
https://gist.github.com/pangyuteng/8f4f7c09b490e1baaef852d07105db77

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment