Skip to content

Instantly share code, notes, and snippets.

Created July 9, 2016 20:29
Show Gist options
  • Save EderSantana/3cebf581aeb2aec896e77e25635994ba to your computer and use it in GitHub Desktop.
Save EderSantana/3cebf581aeb2aec896e77e25635994ba to your computer and use it in GitHub Desktop.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
from keras.engine import Layer, InputSpec
from keras import backend as K, regularizers, constraints, initializations, activations
class Deconv2D(Layer):
    """Transposed 2D convolution ("deconvolution") layer built on ``tf.nn.conv2d_transpose``.

    Written against the Keras 1.x ``Layer`` API (``get_output_shape_for``,
    ``self.regularizers`` list, ``self.constraints`` dict). Because the forward
    pass calls TensorFlow directly, it only runs under the TensorFlow backend.

    Args:
        nb_filter: number of output channels.
        nb_row: kernel height.
        nb_col: kernel width.
        init: kernel initializer name, resolved with ``initializations.get``.
        activation: activation name, resolved with ``activations.get``.
        weights: optional initial weights list ``[W, b]``; applied at the end of ``build``.
        border_mode: ``'valid'`` or ``'same'``. It selects both the output size
            computed by ``get_output_shape_for`` and the padding passed to
            ``tf.nn.conv2d_transpose``.
        subsample: ``(row_stride, col_stride)`` of the transposed convolution.
        dim_ordering: ``'tf'`` for channels-last input, ``'th'`` for channels-first input.
        W_regularizer, b_regularizer, activity_regularizer: optional regularizers.
        W_constraint, b_constraint: optional constraints.

    Raises:
        ValueError: on an unknown ``border_mode`` or ``dim_ordering``.
    """

    def __init__(self, nb_filter, nb_row, nb_col,
                 init='glorot_uniform', activation='linear', weights=None,
                 border_mode='valid', subsample=(1, 1), dim_ordering='tf',
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None, **kwargs):
        # Explicit raises instead of asserts so validation survives `python -O`.
        if border_mode not in {'valid', 'same'}:
            raise ValueError('Invalid border mode for Convolution2D:', border_mode)
        if dim_ordering not in {'tf', 'th'}:
            raise ValueError('dim_ordering must be in {tf, th}')
        self.nb_filter = nb_filter
        self.nb_row = nb_row
        self.nb_col = nb_col
        self.init = initializations.get(init, dim_ordering=dim_ordering)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)
        self.dim_ordering = dim_ordering
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.input_spec = [InputSpec(ndim=4)]
        self.initial_weights = weights
        super(Deconv2D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the kernel/bias variables and register regularizers and constraints.

        The kernel layout depends on ``dim_ordering``:
        ``'th'`` -> ``(nb_filter, stack_size, nb_row, nb_col)``,
        ``'tf'`` -> ``(nb_row, nb_col, nb_filter, stack_size)``.
        The ``'tf'`` layout is what ``tf.nn.conv2d_transpose`` consumes
        directly; ``call`` transposes the ``'th'`` kernel into it.
        """
        if self.dim_ordering == 'th':
            stack_size = input_shape[1]
            self.W_shape = (self.nb_filter, stack_size, self.nb_row, self.nb_col)
        elif self.dim_ordering == 'tf':
            stack_size = input_shape[3]
            self.W_shape = (self.nb_row, self.nb_col, self.nb_filter, stack_size)
        else:
            raise ValueError('Invalid dim_ordering: ' + self.dim_ordering)
        self.W = self.init(self.W_shape, name='{}/w'.format(self.name))
        self.b = K.zeros((self.nb_filter,), name='{}/biases'.format(self.name))
        self.trainable_weights = [self.W, self.b]

        self.regularizers = []
        if self.W_regularizer:
            self.W_regularizer.set_param(self.W)
            self.regularizers.append(self.W_regularizer)
        if self.b_regularizer:
            self.b_regularizer.set_param(self.b)
            self.regularizers.append(self.b_regularizer)
        if self.activity_regularizer:
            self.activity_regularizer.set_layer(self)
            self.regularizers.append(self.activity_regularizer)

        self.constraints = {}
        if self.W_constraint:
            self.constraints[self.W] = self.W_constraint
        if self.b_constraint:
            self.constraints[self.b] = self.b_constraint

        # Apply user-supplied weights before discarding the reference to them.
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    @staticmethod
    def _deconv_length(length, stride, kernel, border_mode):
        """Return the spatial output length of a transposed conv (None stays None)."""
        if length is None:
            return None
        if border_mode == 'same':
            return length * stride
        # 'valid': the inverse of a VALID forward convolution.
        return length * stride + max(kernel - stride, 0)

    def get_output_shape_for(self, input_shape):
        """Return the 4-tuple output shape for a 4-element ``input_shape``.

        The channel axis becomes ``nb_filter``. Each spatial axis is grown
        according to ``subsample``, the kernel size and ``border_mode``.
        Unknown (None) dimensions are propagated as None.
        """
        if self.dim_ordering == 'th':
            rows, cols = input_shape[2], input_shape[3]
        elif self.dim_ordering == 'tf':
            rows, cols = input_shape[1], input_shape[2]
        else:
            raise ValueError('Invalid dim_ordering: ' + self.dim_ordering)
        rows = self._deconv_length(rows, self.subsample[0], self.nb_row, self.border_mode)
        cols = self._deconv_length(cols, self.subsample[1], self.nb_col, self.border_mode)
        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)

    def call(self, x, mask=None):
        """Apply the transposed convolution, add the bias and apply the activation.

        ``tf.nn.conv2d_transpose`` works in channels-last layout. For ``'th'``,
        the input and kernel are transposed into that layout, and the result
        is transposed back to channels-first.
        """
        out_shape = self.get_output_shape_for(x.get_shape().as_list())
        if self.dim_ordering == 'th':
            x = tf.transpose(x, (0, 2, 3, 1))
            # (nb_filter, stack, rows, cols) -> (rows, cols, nb_filter, stack)
            kernel = tf.transpose(self.W, (2, 3, 0, 1))
            out_rows, out_cols = out_shape[2], out_shape[3]
        elif self.dim_ordering == 'tf':
            kernel = self.W
            out_rows, out_cols = out_shape[1], out_shape[2]
        else:
            raise ValueError('Invalid dim_ordering: ' + self.dim_ordering)

        batch = out_shape[0]
        if batch is None:
            # An unknown batch size cannot go into a static shape list; read it
            # at run time. tf.pack was renamed tf.stack in TF 1.0.
            pack = getattr(tf, 'stack', None) or getattr(tf, 'pack')
            nhwc_shape = pack([tf.shape(x)[0], out_rows, out_cols, self.nb_filter])
        else:
            nhwc_shape = [batch, out_rows, out_cols, self.nb_filter]
        # NOTE(review): spatial dims must be statically known here; None rows/cols
        # are not supported by this layer.

        output = tf.nn.conv2d_transpose(
            x, kernel, output_shape=nhwc_shape,
            strides=[1, self.subsample[0], self.subsample[1], 1],
            padding=self.border_mode.upper())

        if self.dim_ordering == 'th':
            output = tf.transpose(output, (0, 3, 1, 2))
            output = output + K.reshape(self.b, (1, self.nb_filter, 1, 1))
        else:
            output = output + K.reshape(self.b, (1, 1, 1, self.nb_filter))
        return self.activation(output)

    def get_config(self):
        """Return the layer config merged with the base ``Layer`` config."""
        config = {'nb_filter': self.nb_filter,
                  'nb_row': self.nb_row,
                  'nb_col': self.nb_col,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'border_mode': self.border_mode,
                  'subsample': self.subsample,
                  'dim_ordering': self.dim_ordering,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None}
        base_config = super(Deconv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Copy link

MikeAmy commented Jul 10, 2016

@EderSantana In `get_output_shape_for`, these lines:

    rows = rows * self.subsample[0]
    cols = cols * self.subsample[1]

I believe that, depending on the border mode, rows and columns can grow or shrink in the reverse of the way they do in Convolution2D. I guess this code doesn't take that into account yet, right? I can have a look if not.

Copy link

MikeAmy commented Jul 11, 2016

I saw your commit from a few days ago for Deconvolution2D. How does it compare to this? It looks more up to date.

Copy link

Hi @MikeAmy, check this commit: keras-team/keras#3251
It should work for both Theano and TensorFlow.

Copy link

MikeAmy commented Jul 24, 2016

Awesome, thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment