Skip to content

Instantly share code, notes, and snippets.

@phizaz
Last active December 6, 2017 15:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save phizaz/293bed11db89087bfc5afe08f9ee7ede to your computer and use it in GitHub Desktop.
Save phizaz/293bed11db89087bfc5afe08f9ee7ede to your computer and use it in GitHub Desktop.
Keras 2D Conv Layer without Conv2D
class MyConv(Layer):
    """2D convolution implemented without Conv2D.

    Implements the equivalent of ``Conv2D`` with ``strides=1`` and
    ``'valid'`` padding: every output pixel is the dot product of a
    flattened input patch with the flattened kernels.

    Args:
        filters: number of output channels.
        kernel: ``(kernel_height, kernel_width)`` pair.
    """

    def __init__(self, filters, kernel, **kwargs):
        self.filters = filters
        self.k_h, self.k_w = kernel
        super(MyConv, self).__init__(**kwargs)

    def build(self, input_shape):
        # assumes channels-last input: (batch, height, width, channels)
        _, self.h, self.w, self.c = input_shape
        # 'valid' padding with stride 1 shrinks each spatial dim by kernel-1
        self.out_h = self.h - self.k_h + 1
        self.out_w = self.w - self.k_w + 1
        # length of one flattened patch (== length of one flattened kernel)
        self.kernel_size = self.k_h * self.k_w * self.c
        self.kernels = self.add_weight(name='kernel',
                                       shape=[self.k_h, self.k_w,
                                              self.c, self.filters],
                                       initializer='glorot_uniform',
                                       trainable=True)
        super(MyConv, self).build(input_shape)

    def call(self, x):
        # flatten kernels [k_h, k_w, c_in, c_out] -> [k_h * k_w * c_in, c_out]
        kernel = K.reshape(self.kernels, [self.kernel_size, self.filters])
        patches = []
        for i in range(self.out_h):
            for j in range(self.out_w):
                # take one patch and flatten it to [batch, kernel_size]
                p = x[:, i:i + self.k_h, j:j + self.k_w, :]
                p = K.reshape(p, [-1, self.kernel_size])
                # convolution at (i, j):
                # [batch, kernel_size] . [kernel_size, filters] -> [batch, filters]
                patches.append(K.dot(p, kernel))
        # out_h*out_w tensors of [batch, filters] -> [batch, out_h*out_w, filters]
        # (i-major, j-minor append order matches the row-major reshape below)
        stacked = K.stack(patches, axis=1)
        # restore the spatial layout: [batch, out_h, out_w, filters]
        return K.reshape(stacked, [-1, self.out_h, self.out_w, self.filters])

    def compute_output_shape(self, input_shape):
        # preserve the (possibly known) batch dimension instead of
        # hard-coding None
        return (input_shape[0], self.out_h, self.out_w, self.filters)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment