Skip to content

Instantly share code, notes, and snippets.

@Dref360
Last active February 11, 2020 14:40
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save Dref360/b330e75cb121c03a0066d9587a7bfee5 to your computer and use it in GitHub Desktop.
Save Dref360/b330e75cb121c03a0066d9587a7bfee5 to your computer and use it in GitHub Desktop.
Un-scaled version of CoordConv2D
import keras.backend as K
import tensorflow as tf
from tensorflow.keras.layers import Layer
"""Not tested, I'll play around with GANs soon with it."""
from tensorflow.keras.layers import Conv2D
import numpy as np
class CoordConv2D(Layer):
    """Conv2D that first appends two coordinate channels to its input.

    Implements the CoordConv idea: before convolving, two extra feature maps
    holding each pixel's row and column position, rescaled to [-1, 1], are
    concatenated to the input along the channel axis. The channel axis is
    chosen from ``tf.keras.backend.image_data_format()``.

    Args:
        channel: Number of output filters of the wrapped ``Conv2D``.
        kernel: Kernel size passed to the wrapped ``Conv2D``.
        padding: Padding mode passed to the wrapped ``Conv2D``
            (default ``'valid'``).
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, channel, kernel, padding='valid', **kwargs):
        # Initialise the base Layer before assigning sub-layers so Keras
        # attribute tracking exists when the Conv2D is attached.
        super(CoordConv2D, self).__init__(**kwargs)
        self.layer = Conv2D(channel, kernel, padding=padding)

    @staticmethod
    def _coordinate_range(size):
        """Return ``size`` evenly spaced float32 values from -1 to 1.

        A size-1 axis yields ``[-1.]``; ``max(size - 1, 1)`` keeps the
        division well-defined in that case.
        """
        positions = tf.cast(tf.range(size), tf.float32)
        denominator = tf.cast(tf.maximum(size - 1, 1), tf.float32)
        return positions / denominator * 2.0 - 1.0

    def call(self, inputs, **kwargs):
        """Concatenate coordinate channels to ``inputs`` and apply the Conv2D.

        Args:
            inputs: 4-D image batch laid out as (batch, rows, cols, channels),
                or (batch, channels, rows, cols) when the Keras image data
                format is ``'channels_first'``.

        Returns:
            Output of the wrapped ``Conv2D`` applied to ``inputs`` with two
            extra channels appended (row coordinate, then column coordinate).
        """
        # Keras reports the format as 'channels_first' / 'channels_last'.
        channels_first = (
            tf.keras.backend.image_data_format() == 'channels_first')
        shape = tf.shape(inputs)
        if channels_first:
            batch, rows, cols = shape[0], shape[2], shape[3]
            channel_axis = 1
        else:
            batch, rows, cols = shape[0], shape[1], shape[2]
            channel_axis = -1

        # (rows, cols, 2) grid: [..., 0] is the row coordinate,
        # [..., 1] the column coordinate.
        row_grid, col_grid = tf.meshgrid(
            self._coordinate_range(rows),
            self._coordinate_range(cols),
            indexing='ij',
        )
        canvas = tf.stack([row_grid, col_grid], axis=-1)
        # Repeat the same grid for every sample in the batch.
        canvas = tf.broadcast_to(canvas[tf.newaxis], [batch, rows, cols, 2])
        if channels_first:
            canvas = tf.transpose(canvas, [0, 3, 1, 2])
        # Match the input dtype so the concatenation cannot fail on a dtype
        # mismatch (e.g. non-float32 inputs).
        canvas = tf.cast(canvas, inputs.dtype)

        augmented = tf.concat([inputs, canvas], axis=channel_axis)
        return self.layer(augmented)

    def compute_output_shape(self, input_shape):
        """Return the wrapped Conv2D's output shape for ``input_shape``."""
        # NOTE(review): the Conv2D actually receives two extra channels; this
        # relies on Conv2D's output shape not depending on the input channel
        # count (output channels come from the filter count).
        return self.layer.compute_output_shape(input_shape)
class CustomModel(tf.keras.Model):
    """Minimal model holding one CoordConv2D (63 filters, 3x3 kernel, 'same')."""

    def __init__(self):
        super().__init__()
        # The model's only sub-layer, stored under the attribute name ``l``.
        self.l = CoordConv2D(channel=63, kernel=3, padding='same')

    def call(self, inputs):
        """Pass ``inputs`` through the CoordConv2D layer and return its output."""
        return self.l(inputs)
def main():
    """Smoke check: run a batch of ones through CustomModel and print the output shape."""
    model = CustomModel()
    model.compile('sgd', 'mse')
    # Execute eagerly instead of tracing the model into a graph function.
    model.run_eagerly = True
    batch = np.ones((3, 32, 32, 3))
    prediction = model.predict(batch)
    print(prediction.shape)


if __name__ == '__main__':
    main()
@thomasaarholt
Copy link

@Dref360, I've tried modifying the above for TensorFlow 2.1 (see my fork). Do you have any idea what causes the error that appears when one runs it?

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: model/coord_conv2d/unstack:0

@Dref360
Copy link
Author

Dref360 commented Feb 11, 2020

Hi Thomas,
I updated the gist to TF2.
Thank you!

@thomasaarholt
Copy link

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment