Skip to content

Instantly share code, notes, and snippets.

@dominikandreas
Created July 12, 2018 14:03
Show Gist options
  • Save dominikandreas/2fd56d24bd4f8b594db52f352d5bb862 to your computer and use it in GitHub Desktop.
Save dominikandreas/2fd56d24bd4f8b594db52f352d5bb862 to your computer and use it in GitHub Desktop.
CoordConv coordinate channels, as described in https://arxiv.org/abs/1807.03247
import tensorflow as tf
def add_coord_channels(inputs, with_r=False):
    """Append normalized coordinate channels to a batch of feature maps.

    Two channels are concatenated onto the last axis of ``inputs``: the first
    holds each position's coordinate along axis 2, the second its coordinate
    along axis 1, both scaled linearly to the range [-1, 1]. This is the
    input augmentation step of CoordConv; the convolution itself is applied
    by the caller.

    Args:
        inputs: Rank-4 tensor. Axis 0 is tiled over as the batch axis, axes 1
            and 2 are treated as the spatial axes, and the new channels are
            appended on the last axis. Spatial dimensions may be statically
            unknown; the runtime shape is used in that case.
        with_r: If true, additionally append a radius channel
            ``sqrt(x**2 + y**2)`` computed from the two coordinate channels.

    Returns:
        A tensor equal to ``inputs`` with 2 (or 3 when ``with_r`` is set)
        extra channels appended in the order ``x, y[, r]``.
    """
    runtime_shape = tf.shape(inputs)
    batch_size_tensor = runtime_shape[0]

    # Static shape entries are ``Dimension`` objects (with ``.value``) in some
    # TensorFlow versions and plain ints / None in others; accept both, and
    # fall back to the dynamic shape when the size is not known statically.
    spatial_dims = []
    for axis in (1, 2):
        dim = inputs.shape[axis]
        dim = getattr(dim, "value", dim)
        spatial_dims.append(runtime_shape[axis] if dim is None else dim)
    x_dim, y_dim = spatial_dims

    x_range = tf.linspace(-1., 1., x_dim)
    y_range = tf.linspace(-1., 1., y_dim)

    # x_channel varies along axis 2 and y_channel along axis 1; lay each range
    # out on its own axis, then tile across the batch and the other axis.
    x_channel = tf.tile(tf.reshape(y_range, [1, 1, -1, 1]),
                        [batch_size_tensor, x_dim, 1, 1])
    y_channel = tf.tile(tf.reshape(x_range, [1, -1, 1, 1]),
                        [batch_size_tensor, 1, y_dim, 1])

    extra_channels = [x_channel, y_channel]
    if with_r:
        extra_channels.append(
            tf.sqrt(tf.square(x_channel) + tf.square(y_channel)))

    # tf.linspace produces float32 here, while tf.concat requires a single
    # dtype; match the input's floating dtype so float16/float64 inputs work.
    # Non-floating inputs are left alone (a cast would truncate the
    # coordinates), so concat still rejects them as before.
    if inputs.dtype.is_floating:
        extra_channels = [tf.cast(c, inputs.dtype) for c in extra_channels]

    return tf.concat([inputs] + extra_channels, axis=-1)
if __name__ == "__main__":
    # Demo: augment a random batch and plot every channel of the first sample.
    import numpy as np
    from matplotlib import pyplot as plt

    img = tf.constant(np.random.normal(size=[2, 32, 32, 1]).astype("float32"))
    img_coord = add_coord_channels(img, with_r=True)
    # actual conv can be added using e.g. tf.layers.conv2d(img_coord, ...)

    # Use the session as a context manager so its resources are released
    # once the tensor has been evaluated.
    with tf.Session() as sess:
        img_c = sess.run(img_coord)[0]

    # One subplot row per channel: input, x, y, r.
    channels = img_c.shape[-1]
    for i in range(channels):
        plt.subplot(channels, 1, i + 1).imshow(img_c[:, :, i])
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment