Skip to content

Instantly share code, notes, and snippets.

@EderSantana
Created July 9, 2016 20:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EderSantana/3cebf581aeb2aec896e77e25635994ba to your computer and use it in GitHub Desktop.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
from keras.engine import Layer, InputSpec
from keras import backend as K, regularizers, constraints, initializations, activations
class Deconv2D(Layer):
    """Transposed 2D convolution ("deconvolution") layer, TensorFlow backend only.

    Upsamples its input by ``subsample`` using ``tf.nn.conv2d_transpose``.
    Weight layout follows Keras 1.x conventions for the chosen ``dim_ordering``.

    NOTE(review): the ``'th'`` dim_ordering branches build Theano-style weight
    shapes, but ``tf.nn.conv2d_transpose`` only operates on NHWC tensors with
    HWOI filters, so in practice only ``dim_ordering='tf'`` works end to end —
    confirm before relying on the 'th' path.

    # Arguments
        nb_filter: number of output filters (output channels).
        nb_row: kernel height.
        nb_col: kernel width.
        init: name of a Keras initialization function.
        activation: name of a Keras activation function.
        weights: optional list of numpy arrays to set as initial weights.
        border_mode: 'valid' or 'same' (maps to TF 'VALID'/'SAME' padding).
        subsample: (row_stride, col_stride) upsampling factors.
        dim_ordering: 'tf' (NHWC) or 'th' (NCHW).
        W_regularizer, b_regularizer, activity_regularizer: Keras regularizers.
        W_constraint, b_constraint: Keras constraints.
    """

    def __init__(self, nb_filter, nb_row, nb_col,
                 init='glorot_uniform', activation='linear', weights=None,
                 border_mode='valid', subsample=(1, 1), dim_ordering='tf',
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None, **kwargs):
        # Validate with explicit raises, not asserts: asserts vanish under -O.
        if border_mode not in {'valid', 'same'}:
            raise Exception('Invalid border mode for Deconv2D:', border_mode)
        if dim_ordering not in {'tf', 'th'}:
            raise Exception('Invalid dim_ordering: ' + str(dim_ordering))
        self.nb_filter = nb_filter
        self.nb_row = nb_row
        self.nb_col = nb_col
        self.init = initializations.get(init, dim_ordering=dim_ordering)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)
        self.dim_ordering = dim_ordering
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.input_spec = [InputSpec(ndim=4)]
        self.initial_weights = weights
        super(Deconv2D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the kernel and bias variables for the given input shape."""
        if self.dim_ordering == 'th':
            stack_size = input_shape[1]
            self.W_shape = (self.nb_filter, stack_size, self.nb_row, self.nb_col)
        elif self.dim_ordering == 'tf':
            # HWOI layout: matches tf.nn.conv2d_transpose's
            # [height, width, output_channels, in_channels] filter convention.
            stack_size = input_shape[3]
            self.W_shape = (self.nb_row, self.nb_col, self.nb_filter, stack_size)
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
        self.W = self.init(self.W_shape, name='{}/w'.format(self.name))
        self.b = K.zeros((self.nb_filter,), name='{}/biases'.format(self.name))
        self.trainable_weights = [self.W, self.b]
        self.regularizers = []
        if self.W_regularizer:
            self.W_regularizer.set_param(self.W)
            self.regularizers.append(self.W_regularizer)
        if self.b_regularizer:
            self.b_regularizer.set_param(self.b)
            self.regularizers.append(self.b_regularizer)
        if self.activity_regularizer:
            self.activity_regularizer.set_layer(self)
            self.regularizers.append(self.activity_regularizer)
        self.constraints = {}
        if self.W_constraint:
            self.constraints[self.W] = self.W_constraint
        if self.b_constraint:
            self.constraints[self.b] = self.b_constraint
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def _deconv_length(self, dim_size, stride, kernel_size):
        """Output length of a transposed convolution along one spatial axis.

        'same' padding inverts to an exact multiple of the stride; 'valid'
        padding inverts the forward-conv formula: (in - 1) * stride + kernel.
        Returns None for an unknown (symbolic) input dimension.
        """
        if dim_size is None:
            return None
        if self.border_mode == 'same':
            return dim_size * stride
        # border_mode == 'valid'
        return (dim_size - 1) * stride + kernel_size

    def get_output_shape_for(self, input_shape):
        """Compute the output shape, honoring border_mode (fixes the bug where
        'valid' padding was shaped as if it were 'same')."""
        if self.dim_ordering == 'th':
            rows = input_shape[2]
            cols = input_shape[3]
        elif self.dim_ordering == 'tf':
            rows = input_shape[1]
            cols = input_shape[2]
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
        rows = self._deconv_length(rows, self.subsample[0], self.nb_row)
        cols = self._deconv_length(cols, self.subsample[1], self.nb_col)
        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)

    def call(self, x, mask=None):
        input_shape = x.get_shape().as_list()
        output_shape = list(self.get_output_shape_for(input_shape))
        if output_shape[0] is None:
            # conv2d_transpose cannot take a None batch dim in output_shape;
            # substitute the runtime batch size and pack into a 1-D tensor.
            # NOTE(review): on TF < 1.0 this call was named tf.pack — confirm
            # against the TF version in use.
            output_shape[0] = tf.shape(x)[0]
            output_shape = tf.stack(output_shape)
        # Pass padding explicitly: conv2d_transpose defaults to 'SAME', which
        # silently ignored border_mode='valid' in the original code.
        deconv_out = tf.nn.conv2d_transpose(
            x, self.W, output_shape=output_shape,
            strides=[1, self.subsample[0], self.subsample[1], 1],
            padding=self.border_mode.upper())
        if self.dim_ordering == 'th':
            output = deconv_out + K.reshape(self.b, (1, self.nb_filter, 1, 1))
        elif self.dim_ordering == 'tf':
            output = deconv_out + K.reshape(self.b, (1, 1, 1, self.nb_filter))
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
        return self.activation(output)

    def get_config(self):
        """Return the layer configuration for serialization."""
        config = {'nb_filter': self.nb_filter,
                  'nb_row': self.nb_row,
                  'nb_col': self.nb_col,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'border_mode': self.border_mode,
                  'subsample': self.subsample,
                  'dim_ordering': self.dim_ordering,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None}
        base_config = super(Deconv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
@MikeAmy
Copy link

MikeAmy commented Jul 10, 2016

@EderSantana In get_output_shape, these lines:

    rows = rows * self.subsample[0]
    cols = cols * self.subsample[1]

I believe that, depending on the border mode, rows and columns can grow or shrink in the reverse way to Convolution2D. I don't think this code has taken that into account — I can have a look if not.

@MikeAmy
Copy link

MikeAmy commented Jul 11, 2016

@EderSantana
I saw your commit https://github.com/fchollet/keras/pull/3133/files from a few days ago for Deconvolution2D. How does it compare to this? Looks like it's more up-to-date.

@EderSantana
Copy link
Author

hi @MikeAmy check this commit keras-team/keras#3251
it should work for both theano and tensorflow

@MikeAmy
Copy link

MikeAmy commented Jul 24, 2016

Awesome, thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment