
@dswah
Last active September 17, 2021 22:45
Tied Convolutional Weights with Keras for CNN Auto-encoders
from keras import backend as K
from keras import activations, initializations, regularizers, constraints
from keras.engine import Layer, InputSpec
from keras.utils.np_utils import conv_output_length
from keras.layers import Convolution1D, Convolution2D
import tensorflow as tf


class Convolution1D_tied(Layer):
    '''Convolution operator for filtering neighborhoods of one-dimensional inputs,
    reusing (tying) the kernel of an existing `Convolution1D` layer.

    When using this layer as the first layer in a model,
    either provide the keyword argument `input_dim`
    (int, e.g. 128 for sequences of 128-dimensional vectors),
    or `input_shape` (tuple of integers, e.g. (10, 128) for sequences
    of 10 vectors of 128-dimensional vectors).

    # Example

    ```python
        # encoder: a convolution 1d of length 3 with 64 output filters,
        # applied to a sequence of 10 timesteps with 32 channels
        encoder = Convolution1D(64, 3, border_mode='same', input_shape=(10, 32))
        model = Sequential()
        model.add(encoder)
        # now model.output_shape == (None, 10, 64)

        # decoder: back to 32 channels, reusing the (transposed) encoder kernel
        model.add(Convolution1D_tied(32, 3, border_mode='same', tied_to=encoder))
        # now model.output_shape == (None, 10, 32)
    ```

    # Arguments
        nb_filter: Number of convolution kernels to use
            (dimensionality of the output). Must equal the number of
            input channels of `tied_to`.
        filter_length: The extension (spatial or temporal) of each filter.
            Ignored: the value is taken from `tied_to`.
        init: name of initialization function for the weights of the layer
            (see [initializations](../initializations.md)),
            or alternatively, Theano function to use for weights initialization.
            Unused here, because the kernel is borrowed from `tied_to`.
        activation: name of activation function to use
            (see [activations](../activations.md)),
            or alternatively, elementwise Theano function.
            If you don't specify anything, no activation is applied
            (i.e. "linear" activation: a(x) = x).
        weights: list of numpy arrays to set as initial weights.
            Currently ignored.
        border_mode: 'valid' or 'same'.
        subsample_length: factor by which to subsample output.
        W_regularizer: instance of [WeightRegularizer](../regularizers.md)
            (e.g. L1 or L2 regularization), applied to the main weights matrix.
            Currently ignored, because this layer does not own the kernel.
        b_regularizer: instance of [WeightRegularizer](../regularizers.md),
            applied to the bias.
        activity_regularizer: instance of [ActivityRegularizer](../regularizers.md),
            applied to the network output. Currently ignored in this layer.
        W_constraint: instance of the [constraints](../constraints.md) module
            (e.g. maxnorm, nonneg), applied to the main weights matrix.
            Currently ignored, because this layer does not own the kernel.
        b_constraint: instance of the [constraints](../constraints.md) module,
            applied to the bias.
        bias: whether to include a bias
            (i.e. make the layer affine rather than linear).
        input_dim: Number of channels/dimensions in the input.
            Either this argument or the keyword argument `input_shape` must be
            provided when using this layer as the first layer in a model.
        input_length: Length of input sequences, when it is constant.
            This argument is required if you are going to connect
            `Flatten` then `Dense` layers upstream
            (without it, the shape of the dense outputs cannot be computed).
        tied_to: the `Convolution1D` layer instance whose kernel `W` is reused,
            with its input and output channel axes swapped. It must already be
            built (called on an input) when this layer is called.

    # Input shape
        3D tensor with shape: `(samples, steps, input_dim)`.

    # Output shape
        3D tensor with shape: `(samples, new_steps, nb_filter)`.
        `steps` value might have changed due to padding.
    '''

    def __init__(self, nb_filter, filter_length,
                 init='uniform', activation='linear', weights=None,
                 border_mode='valid', subsample_length=1,
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, input_dim=None, input_length=None, tied_to=None,
                 **kwargs):

        if border_mode not in {'valid', 'same'}:
            raise Exception('Invalid border mode for Convolution1D:', border_mode)
        if tied_to is None:
            raise ValueError('Convolution1D_tied needs a Convolution1D instance as `tied_to`.')
        self.tied_to = tied_to
        # must equal the number of input channels of `tied_to`
        self.nb_filter = nb_filter
        # the kernel length is inherited from the tied layer
        self.filter_length = tied_to.filter_length
        self.init = initializations.get(init, dim_ordering='th')
        self.activation = activations.get(activation)
        assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}'
        self.border_mode = border_mode
        self.subsample_length = subsample_length

        self.subsample = (subsample_length, 1)

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.input_spec = [InputSpec(ndim=3)]
        self.initial_weights = tied_to.initial_weights
        self.input_dim = input_dim
        self.input_length = input_length
        if self.input_dim:
            kwargs['input_shape'] = (self.input_length, self.input_dim)
        super(Convolution1D_tied, self).__init__(**kwargs)

    def build(self, input_shape):
        # The kernel is not created here: `call` borrows it from `self.tied_to`,
        # so the bias is the only trainable weight of this layer.
        # input_dim = input_shape[2]
        # self.W_shape = (self.nb_filter, input_dim, self.filter_length, 1)
        # self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
        if self.bias:
            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
            self.trainable_weights = [self.b]
        # else:
        #     self.trainable_weights = [self.W]

        self.regularizers = []
        # if self.W_regularizer:
        #     self.W_regularizer.set_param(self.W)
        #     self.regularizers.append(self.W_regularizer)

        if self.bias and self.b_regularizer:
            self.b_regularizer.set_param(self.b)
            self.regularizers.append(self.b_regularizer)

        # if self.activity_regularizer:
        #     self.activity_regularizer.set_layer(self)
        #     self.regularizers.append(self.activity_regularizer)

        # initialize the dict before adding the bias constraint to it
        # (as Convolution2D_tied.build does)
        self.constraints = {}
        # if self.W_constraint:
        #     self.constraints[self.W] = self.W_constraint
        if self.bias and self.b_constraint:
            self.constraints[self.b] = self.b_constraint

        # if self.initial_weights is not None:
        #     self.set_weights(self.initial_weights)
        #     del self.initial_weights

    def get_output_shape_for(self, input_shape):
        length = conv_output_length(input_shape[1],
                                    self.filter_length,
                                    self.border_mode,
                                    self.subsample[0])
        return (input_shape[0], length, self.nb_filter)

    def call(self, x, mask=None):
        x = K.expand_dims(x, -1)  # add a dimension on the right
        x = K.permute_dimensions(x, (0, 2, 1, 3))
        # x is now (samples, input_dim, steps, 1), i.e. 'th' image ordering.
        # TF uses the last dimension as channel dimension,
        # instead of the 2nd one.
        # TH kernel shape: (depth, input_depth, rows, cols)
        # TF kernel shape: (rows, cols, input_depth, depth)
        # In the Keras version this gist targets, the tied Convolution1D kernel
        # uses the TH layout (nb_filter, input_dim, filter_length, 1), so swapping
        # axes 0 and 1 exchanges its input and output channels; rows and columns
        # stay in place.
        W = tf.transpose(self.tied_to.W, (1, 0, 2, 3))
        output = K.conv2d(x, W, strides=self.subsample,
                          border_mode=self.border_mode,
                          dim_ordering='th')
        if self.bias:
            output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
        output = K.squeeze(output, 3)  # remove the dummy last axis (axis 3)
        output = K.permute_dimensions(output, (0, 2, 1))
        output = self.activation(output)
        return output

    def get_config(self):
        config = {'nb_filter': self.nb_filter,
                  'filter_length': self.filter_length,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'border_mode': self.border_mode,
                  'subsample_length': self.subsample_length,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None,
                  'bias': self.bias,
                  'input_dim': self.input_dim,
                  'input_length': self.input_length}
        base_config = super(Convolution1D_tied, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class Convolution2D_tied(Layer):
    '''Convolution operator for filtering windows of two-dimensional inputs,
    reusing (tying) the kernel of an existing `Convolution2D` layer.

    When using this layer as the first layer in a model,
    provide the keyword argument `input_shape`
    (tuple of integers, does not include the sample axis),
    e.g. `input_shape=(3, 128, 128)` for 128x128 RGB pictures.

    # Examples

    ```python
        # encoder: a 3x3 convolution with 64 output filters on a 256x256 RGB image
        encoder = Convolution2D(64, 3, 3, border_mode='same', input_shape=(3, 256, 256))
        model = Sequential()
        model.add(encoder)
        # now model.output_shape == (None, 64, 256, 256)

        # decoder: back to 3 channels, reusing the (transposed) encoder kernel
        model.add(Convolution2D_tied(3, 3, 3, border_mode='same', tied_to=encoder))
        # now model.output_shape == (None, 3, 256, 256)
    ```

    # Arguments
        nb_filter: Number of convolution filters to use. Must equal the
            number of input channels of `tied_to`.
        nb_row: Number of rows in the convolution kernel.
            Ignored: the value is taken from `tied_to`.
        nb_col: Number of columns in the convolution kernel.
            Ignored: the value is taken from `tied_to`.
        init: name of initialization function for the weights of the layer
            (see [initializations](../initializations.md)), or alternatively,
            Theano function to use for weights initialization.
            Unused here, because the kernel is borrowed from `tied_to`.
        activation: name of activation function to use
            (see [activations](../activations.md)),
            or alternatively, elementwise Theano function.
            If you don't specify anything, no activation is applied
            (i.e. "linear" activation: a(x) = x).
        weights: list of numpy arrays to set as initial weights.
            Currently ignored.
        border_mode: 'valid' or 'same'.
        subsample: tuple of length 2. Factor by which to subsample output.
            Also called strides elsewhere.
        W_regularizer: instance of [WeightRegularizer](../regularizers.md)
            (e.g. L1 or L2 regularization), applied to the main weights matrix.
            Currently ignored, because this layer does not own the kernel.
        b_regularizer: instance of [WeightRegularizer](../regularizers.md),
            applied to the bias.
        activity_regularizer: instance of [ActivityRegularizer](../regularizers.md),
            applied to the network output.
        W_constraint: instance of the [constraints](../constraints.md) module
            (e.g. maxnorm, nonneg), applied to the main weights matrix.
            Currently ignored, because this layer does not own the kernel.
        b_constraint: instance of the [constraints](../constraints.md) module,
            applied to the bias.
        dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
            (the depth) is at index 1, in 'tf' mode it is at index 3.
            It defaults to the `image_dim_ordering` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "th".
            It should match the dim_ordering of `tied_to`.
        bias: whether to include a bias
            (i.e. make the layer affine rather than linear).
        tied_to: the `Convolution2D` layer instance whose kernel `W` is reused,
            with its input and output channel axes swapped. It must already be
            built (called on an input) when this layer is called.

    # Input shape
        4D tensor with shape:
        `(samples, channels, rows, cols)` if dim_ordering='th'
        or 4D tensor with shape:
        `(samples, rows, cols, channels)` if dim_ordering='tf'.

    # Output shape
        4D tensor with shape:
        `(samples, nb_filter, new_rows, new_cols)` if dim_ordering='th'
        or 4D tensor with shape:
        `(samples, new_rows, new_cols, nb_filter)` if dim_ordering='tf'.
        `rows` and `cols` values might have changed due to padding.
    '''

    def __init__(self, nb_filter, nb_row, nb_col,
                 init='glorot_uniform', activation='linear', weights=None,
                 border_mode='valid', subsample=(1, 1), dim_ordering='default',
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, tied_to=None, **kwargs):
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        if border_mode not in {'valid', 'same'}:
            raise Exception('Invalid border mode for Convolution2D:', border_mode)
        if tied_to is None:
            raise ValueError('Convolution2D_tied needs a Convolution2D instance as `tied_to`.')
        self.tied_to = tied_to
        # must equal the number of input channels of `tied_to`
        self.nb_filter = nb_filter
        # the kernel size is inherited from the tied layer
        self.nb_row = tied_to.nb_row
        self.nb_col = tied_to.nb_col
        self.init = initializations.get(init, dim_ordering=dim_ordering)
        self.activation = activations.get(activation)
        assert border_mode in {'valid', 'same'}, 'border_mode must be in {valid, same}'
        self.border_mode = border_mode
        self.subsample = tuple(subsample)
        assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
        self.dim_ordering = dim_ordering

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.input_spec = [InputSpec(ndim=4)]
        self.initial_weights = tied_to.initial_weights
        super(Convolution2D_tied, self).__init__(**kwargs)

    def build(self, input_shape):
        if self.dim_ordering == 'th':
            stack_size = input_shape[1]
            self.W_shape = (self.nb_filter, stack_size, self.nb_row, self.nb_col)
        elif self.dim_ordering == 'tf':
            stack_size = input_shape[3]
            self.W_shape = (self.nb_row, self.nb_col, stack_size, self.nb_filter)
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
        # self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
        if self.bias:
            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
            self.trainable_weights = [self.b]
        # else:
        #     self.trainable_weights = [self.W]

        self.regularizers = []
        # if self.W_regularizer:
        #     self.W_regularizer.set_param(self.W)
        #     self.regularizers.append(self.W_regularizer)

        if self.bias and self.b_regularizer:
            self.b_regularizer.set_param(self.b)
            self.regularizers.append(self.b_regularizer)

        if self.activity_regularizer:
            self.activity_regularizer.set_layer(self)
            self.regularizers.append(self.activity_regularizer)

        self.constraints = {}
        # if self.W_constraint:
        #     self.constraints[self.W] = self.W_constraint
        if self.bias and self.b_constraint:
            self.constraints[self.b] = self.b_constraint

        # if self.initial_weights is not None:
        #     self.set_weights(self.initial_weights)
        #     del self.initial_weights

    def get_output_shape_for(self, input_shape):
        if self.dim_ordering == 'th':
            rows = input_shape[2]
            cols = input_shape[3]
        elif self.dim_ordering == 'tf':
            rows = input_shape[1]
            cols = input_shape[2]
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)

        rows = conv_output_length(rows, self.nb_row,
                                  self.border_mode, self.subsample[0])
        cols = conv_output_length(cols, self.nb_col,
                                  self.border_mode, self.subsample[1])

        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, rows, cols)
        elif self.dim_ordering == 'tf':
            return (input_shape[0], rows, cols, self.nb_filter)
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
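
    def call(self, x, mask=None):
        # Reuse the tied layer's kernel with its input and output channel axes
        # swapped. Which axes hold the channels depends on the kernel layout:
        #   'th': (nb_filter, stack_size, nb_row, nb_col) -> swap axes 0 and 1
        #   'tf': (nb_row, nb_col, stack_size, nb_filter) -> swap axes 2 and 3
        # This assumes `tied_to` was built with the same dim_ordering.
        if self.dim_ordering == 'th':
            W = tf.transpose(self.tied_to.W, (1, 0, 2, 3))
        else:
            W = tf.transpose(self.tied_to.W, (0, 1, 3, 2))
        output = K.conv2d(x, W, strides=self.subsample,
                          border_mode=self.border_mode,
                          dim_ordering=self.dim_ordering,
                          filter_shape=self.W_shape)
        if self.bias:
            if self.dim_ordering == 'th':
                output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
            elif self.dim_ordering == 'tf':
                output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
            else:
                raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
        output = self.activation(output)
        return output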

    def get_config(self):
        config = {'nb_filter': self.nb_filter,
                  'nb_row': self.nb_row,
                  'nb_col': self.nb_col,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'border_mode': self.border_mode,
                  'subsample': self.subsample,
                  'dim_ordering': self.dim_ordering,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None,
                  'bias': self.bias}
        base_config = super(Convolution2D_tied, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
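
To make the tying concrete without Keras, here is a minimal pure-Python sketch of the computation `Convolution1D_tied` performs. It assumes 'valid' borders, stride 1, no activation, cross-correlation (what TensorFlow's conv2d computes), and the 'th'-style kernel layout `[nb_filter][input_dim][filter_length]` with the dummy trailing axis dropped. The helpers `conv1d_valid` and `tie_kernel` are hypothetical names, not Keras functions. The decoder reuses the encoder's kernel with its first two axes swapped (the role of `tf.transpose(W, (1, 0, 2, 3))` in the gist) and owns only a bias, so it maps the encoder's 3 feature channels back to the 2 input channels.

```python
def conv1d_valid(x, w, b):
    """Multi-channel 1D cross-correlation with 'valid' borders and stride 1.

    x: input as [in_channels][steps]
    w: kernel as [out_channels][in_channels][filter_length]
    b: bias as [out_channels]
    returns: [out_channels][steps - filter_length + 1]
    """
    out_ch, in_ch, k = len(w), len(w[0]), len(w[0][0])
    steps = len(x[0])
    assert len(x) == in_ch, "input channels must match the kernel"
    return [
        [b[o] + sum(w[o][i][j] * x[i][t + j]
                    for i in range(in_ch) for j in range(k))
         for t in range(steps - k + 1)]
        for o in range(out_ch)
    ]


def tie_kernel(w):
    """Swap the output and input channel axes: [out][in][k] -> [in][out][k]."""
    out_ch, in_ch = len(w), len(w[0])
    return [[list(w[o][i]) for o in range(out_ch)] for i in range(in_ch)]


# Encoder: 2 input channels -> 3 filters of length 2 (owns kernel and bias).
w_enc = [[[1, 0], [0, 1]],
         [[0, 1], [1, 0]],
         [[1, 1], [1, 1]]]
b_enc = [0, 0, 0]

# Tied decoder: 3 channels -> 2 channels, reusing w_enc; it owns only this bias.
b_dec = [0.5, -0.5]

x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
h = conv1d_valid(x, w_enc, b_enc)              # [[7, 9, 11], [7, 9, 11], [14, 18, 22]]
y = conv1d_valid(h, tie_kernel(w_enc), b_dec)  # 2 channels, like x

# Trainable parameters owned by each layer.
enc_params = len(w_enc) * len(w_enc[0]) * len(w_enc[0][0]) + len(b_enc)
dec_params = len(b_dec)
```

Swapping the channel axes without also flipping the kernel along its length makes the decoder a weight-shared convolution rather than the exact transpose (adjoint) of the encoder; the gist's layers make the same choice.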
@dswah

dswah commented Sep 2, 2016

To create convolutional auto-encoders with tied weights, first instantiate the encoder layer as usual, then pass that instance as `tied_to` to the tied layer class defined above:

```python
# Keras 1.x API; the (1, 28, 28) input assumes 'th' (channels-first) dim ordering.
# Convolution2D_tied is the class defined in the gist above.
from keras.layers import Input, Convolution2D, MaxPooling2D, UpSampling2D
from keras.models import Model

input_img = Input(shape=(1, 28, 28))

conv1 = Convolution2D(16, 3, 3, activation='relu', border_mode='same')  # create the layer instance that I want to tie to
c1 = conv1(input_img)  # then call the layer, so that conv1.W exists
m1 = MaxPooling2D((2, 2), border_mode='same')(c1)

c2 = Convolution2D(16, 3, 3, activation='relu', border_mode='same')(m1)
u1 = UpSampling2D((2, 2))(c2)
d1 = Convolution2D_tied(1, 3, 3, activation='sigmoid', border_mode='same', tied_to=conv1)(u1)  # now this layer is tied to conv1

# and compile as usual
autoencoder = Model(input_img, d1)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
```

Since the weights are tied, you will see fewer trainable parameters:

```python
autoencoder.summary()
```

```
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_17 (InputLayer)            (None, 1, 28, 28)     0                                            
____________________________________________________________________________________________________
convolution2d_82 (Convolution2D) (None, 16, 28, 28)    160         input_17[0][0]                   
____________________________________________________________________________________________________
maxpooling2d_47 (MaxPooling2D)   (None, 16, 14, 14)    0           convolution2d_82[0][0]           
____________________________________________________________________________________________________
convolution2d_83 (Convolution2D) (None, 16, 14, 14)    2320        maxpooling2d_47[0][0]            
____________________________________________________________________________________________________
upsampling2d_32 (UpSampling2D)   (None, 16, 28, 28)    0           convolution2d_83[0][0]           
____________________________________________________________________________________________________
convolution2d_tied_9 (Convolution(None, 1, 28, 28)     1           upsampling2d_32[0][0]            
====================================================================================================
Total params: 2481
____________________________________________________________________________________________________
```

The only trainable parameters in the tied layers are the biases.
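
As a quick check on that summary, here is a pure-Python sketch that recomputes the spatial sizes and parameter counts for the example model. The `conv_output_length` below is a simplified re-implementation (no dilation, no 'full' mode), assumed to mirror the Keras 1 helper of the same name that the gist imports; `conv2d_params` and `tied_conv2d_params` are hypothetical helpers. It also assumes `MaxPooling2D` sizes its output with the same formula, using the pool size as the window and stride.

```python
def conv_output_length(input_length, filter_size, border_mode, stride):
    """Output length along one spatial axis for 'valid' or 'same' borders."""
    if border_mode == 'same':
        length = input_length
    elif border_mode == 'valid':
        length = input_length - filter_size + 1
    else:
        raise ValueError('border_mode must be "valid" or "same"')
    return (length + stride - 1) // stride  # ceil(length / stride)


def conv2d_params(in_channels, nb_filter, nb_row, nb_col, bias=True):
    """Trainable parameters of an ordinary (untied) Convolution2D layer."""
    return nb_filter * in_channels * nb_row * nb_col + (nb_filter if bias else 0)


def tied_conv2d_params(nb_filter, bias=True):
    """A tied layer borrows its kernel, so it only owns its bias."""
    return nb_filter if bias else 0


# Spatial side length of the 1 x 28 x 28 input through each layer ('same' borders).
side = 28
side_conv1 = conv_output_length(side, 3, 'same', 1)       # conv1
side_pool = conv_output_length(side_conv1, 2, 'same', 2)  # MaxPooling2D((2, 2))
side_conv2 = conv_output_length(side_pool, 3, 'same', 1)  # second Convolution2D
side_up = side_conv2 * 2                                  # UpSampling2D((2, 2))
side_tied = conv_output_length(side_up, 3, 'same', 1)     # Convolution2D_tied

params = [
    conv2d_params(1, 16, 3, 3),   # convolution2d_82: 16 * 1 * 3 * 3 + 16
    conv2d_params(16, 16, 3, 3),  # convolution2d_83: 16 * 16 * 3 * 3 + 16
    tied_conv2d_params(1),        # convolution2d_tied_9: one bias
]
total_params = sum(params)
```

The tied layer contributes a single parameter, while an untied `Convolution2D(1, 3, 3)` on 16 input channels would have contributed 1 * 16 * 3 * 3 + 1 = 145.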

@dswah

dswah commented Sep 2, 2016

Lightly tested in TensorFlow, but not Theano :)

@keunwoochoi

Nice! Thanks a lot.

@keunwoochoi

keunwoochoi commented Nov 24, 2016

For Theano users, `tf.transpose()` can be replaced with `K.transpose()`, or probably with `K.permute_dimensions(x, pattern)`, since the code is:

```python
W = tf.transpose(self.tied_to.W, (1, 0, 2, 3))
```
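
Why the pattern matters when swapping in a backend function: `K.permute_dimensions(W, (1, 0, 2, 3))` takes the same permutation as the `tf.transpose` call, while `K.transpose(W)` accepts no permutation and, like `tf.transpose` without `perm`, reverses all axes. The pure-Python sketch below applies both to the shape of `conv1`'s 'th' kernel from the example; `permute_shape` is a hypothetical helper that reorders a shape tuple the way a permutation pattern reorders axes.

```python
def permute_shape(shape, perm):
    """Axis k of the result is axis perm[k] of the input."""
    return tuple(shape[p] for p in perm)


# 'th' kernel of conv1: (nb_filter, stack_size, nb_row, nb_col).
w_shape = (16, 1, 3, 3)

# Explicit pattern, as in tf.transpose(W, (1, 0, 2, 3)) or
# K.permute_dimensions(W, (1, 0, 2, 3)): only the channel axes swap.
swapped = permute_shape(w_shape, (1, 0, 2, 3))

# No pattern: every axis is reversed, which scrambles the 'th' layout.
reversed_axes = permute_shape(w_shape, tuple(reversed(range(len(w_shape)))))
```

Only `swapped`, `(1, 16, 3, 3)`, is a valid 'th' kernel mapping 16 channels back to 1; the reversed shape `(3, 3, 1, 16)` would be read as 3 filters over 3 input channels.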

@benjertho

benjertho commented Feb 4, 2017

Thanks for the layer!

I think line 346 (the kernel transpose in `Convolution2D_tied.call`) has an error, though:

```python
W = tf.transpose(self.tied_to.W, (1, 0, 2, 3))
```

It should actually be:

```python
W = tf.transpose(self.tied_to.W, (1, 0, 3, 2))
```

It becomes apparent when the layers have different numbers of channels (feature maps).
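
Whether (1, 0, 2, 3) or (1, 0, 3, 2) is right depends on the kernel layout. The pure-Python sketch below compares kernel shapes only, assuming the Keras 1 `Convolution2D` layouts used in the gist's `build` ('th': `(nb_filter, stack_size, nb_row, nb_col)`, 'tf': `(nb_row, nb_col, stack_size, nb_filter)`); `permute_shape`, `kernel_shape` and `working_permutations` are hypothetical helpers. Matching shapes is a necessary condition, not a proof of correct numerics, but it is where the candidate permutations differ.

```python
CANDIDATES = [(1, 0, 2, 3), (1, 0, 3, 2), (0, 1, 3, 2)]


def permute_shape(shape, perm):
    """Axis k of the result is axis perm[k] of the input."""
    return tuple(shape[p] for p in perm)


def kernel_shape(dim_ordering, in_channels, out_channels, nb_row, nb_col):
    """Kernel layout of a Keras 1 Convolution2D for each dim_ordering."""
    if dim_ordering == 'th':
        return (out_channels, in_channels, nb_row, nb_col)
    if dim_ordering == 'tf':
        return (nb_row, nb_col, in_channels, out_channels)
    raise ValueError('dim_ordering must be "th" or "tf"')


def working_permutations(dim_ordering, in_channels, out_channels, nb_row, nb_col):
    """Candidates that turn the encoder kernel shape into the tied decoder's."""
    encoder = kernel_shape(dim_ordering, in_channels, out_channels, nb_row, nb_col)
    # the tied decoder maps out_channels back to in_channels, same kernel size
    decoder = kernel_shape(dim_ordering, out_channels, in_channels, nb_row, nb_col)
    return [p for p in CANDIDATES if permute_shape(encoder, p) == decoder]


# A non-square 3x5 kernel mapping 1 channel to 16 filters separates the candidates.
works = {ordering: working_permutations(ordering, 1, 16, 3, 5)
         for ordering in ('th', 'tf')}

# With a square 3x3 kernel, (1, 0, 3, 2) also fits the 'tf' layout.
square_tf = working_permutations('tf', 1, 16, 3, 3)
```

The sketch suggests that (1, 0, 2, 3) fits 'th' kernels, while for 'tf' kernels the channel axes are 2 and 3: (1, 0, 3, 2) fits only square kernels (and then also transposes rows and columns), whereas (0, 1, 3, 2) keeps the spatial axes in place.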

@dswah

dswah commented Feb 13, 2017

Cool catch, let me take a look.

@vbvg2008

So what should the transpose permutation be: (1, 0, 3, 2) or (1, 0, 2, 3)?

Also, this code is not compatible with the new version of Keras. Can someone update it?

@Boialex

Boialex commented Dec 19, 2017

Hi! Thanks for the layer!

Have you tested loading a model that contains such a layer? I wrote a `Dense_tied` layer in the same manner, but when loading with `keras.models.load_model`, Keras ran its `call` before the `self.tied_to` layer was created and raised `AttributeError: 'NoneType' object has no attribute 'kernel'` (the same would happen with `W` here).

@isaacgerg

Has this been added to keras-contrib?

@viktor-ferenczi

Is this available built-in or as a contrib in latest Keras / TensorFlow? I cannot seem to find it, while it would be highly useful for autoencoders.
