Skip to content

Instantly share code, notes, and snippets.

@valtron
Last active February 19, 2017 20:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save valtron/8f09f51bac33fcfe824e3a8dd7682009 to your computer and use it in GitHub Desktop.
Save valtron/8f09f51bac33fcfe824e3a8dd7682009 to your computer and use it in GitHub Desktop.
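"""
Filter sharing convolution: https://arxiv.org/abs/1612.02575

Excluding the bias, each layer has M*N*P + P*S parameters (vs. M*N*S for a
regular convolution), where M and N are the numbers of output filters and
input channels, P is the number of seed filters, and S is the number of
weights in one spatial kernel (the product of the kernel dimensions).
"""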
import functools
from keras import backend as K
from keras import activations, initializations, regularizers, constraints
from keras.engine import Layer, InputSpec
from keras.utils.np_utils import conv_output_length, conv_input_length
class SharedFilterConvolution2D(Layer):
    """2D convolution whose kernels are mixed from a shared bank of seed filters.

    Instead of one dense kernel tensor, the layer owns two weights:

    * `S`: `nb_seed_filter` spatial seed filters of size
      `kernel_dim1 x kernel_dim2`;
    * `A`: mixing coefficients, one vector of length `nb_seed_filter` for
      every (input channel, output filter) pair.

    The convolution kernel `W` is recomputed on every call as the `K.dot`
    product of the two and then handed to `K.conv2d`.

    # Arguments
        nb_filter: Number of output filters.
        nb_seed_filter: Number of shared seed filters stored in `S`.
        kernel_dim1: Kernel size along the first convolved axis.
        kernel_dim2: Kernel size along the second convolved axis.
        init: Initializer identifier, resolved with `initializations.get`
            and used for both `S` and `A`.
        activation: Activation identifier, resolved with `activations.get`.
        weights: Optional initial weights, passed to `set_weights` at the
            end of `build` (presumably ordered S, A, then b; confirm
            against `Layer.set_weights`).
        border_mode: One of 'valid', 'same' or 'full'.
        subsample: Strides, one entry per convolved axis.
        dim_ordering: 'th' (channels at axis 1), 'tf' (channels at axis 3)
            or 'default' to use `K.image_dim_ordering()`.
        S_regularizer, A_regularizer, b_regularizer: Regularizers for the
            corresponding weights, resolved with `regularizers.get`.
        activity_regularizer: Resolved with `regularizers.get` and stored
            on the layer.
        S_constraint, A_constraint, b_constraint: Constraints for the
            corresponding weights, resolved with `constraints.get`.
        bias: Whether to add a per-filter bias to the output.

    # Raises
        ValueError: If `border_mode` or `dim_ordering` is not recognised.
    """

    def __init__(
            self, nb_filter, nb_seed_filter, kernel_dim1, kernel_dim2,
            init='glorot_uniform', activation=None, weights=None,
            border_mode='valid', subsample=(1, 1), dim_ordering='default',
            S_regularizer=None, A_regularizer=None, b_regularizer=None, activity_regularizer=None,
            S_constraint=None, A_constraint=None, b_constraint=None,
            bias=True, **kwargs):
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        if border_mode not in {'valid', 'same', 'full'}:
            raise ValueError('Invalid border mode for SharedFilterConvolution2D:', border_mode)
        self.nb_filter = nb_filter
        self.nb_seed_filter = nb_seed_filter
        self.kernel_dim1 = kernel_dim1
        self.kernel_dim2 = kernel_dim2
        self.init = initializations.get(init)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)
        if dim_ordering not in {'tf', 'th'}:
            raise ValueError('dim_ordering must be in {tf, th}.')
        self.dim_ordering = dim_ordering

        self.S_regularizer = regularizers.get(S_regularizer)
        self.A_regularizer = regularizers.get(A_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.S_constraint = constraints.get(S_constraint)
        self.A_constraint = constraints.get(A_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.input_spec = [InputSpec(ndim=4)]
        # Applied (and then discarded) at the end of build().
        self.initial_weights = weights
        super(SharedFilterConvolution2D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the `S`, `A` and optional `b` weights for a 4-D input.

        Also records `W_shape`, the shape of the kernel that `call` passes
        to `K.conv2d` as `filter_shape`.
        """
        assert len(input_shape) == 4
        if self.dim_ordering == 'th':
            # Channels-first: the channel count sits right after the batch axis.
            stack_size = input_shape[1]
            self.W_shape = (self.nb_filter, stack_size, self.kernel_dim1, self.kernel_dim2)
            # The seed axis is second-to-last so that dot(A, S) contracts over it.
            self.S_shape = (self.kernel_dim1, self.nb_seed_filter, self.kernel_dim2)
            self.A_shape = (self.nb_filter, stack_size, self.nb_seed_filter)
        elif self.dim_ordering == 'tf':
            # Channels-last: with a 4-D input the channel axis is index 3.
            # (Index 4 is only valid for the 5-D input of the 3D layer.)
            stack_size = input_shape[3]
            self.W_shape = (self.kernel_dim1, self.kernel_dim2, stack_size, self.nb_filter)
            self.S_shape = (self.kernel_dim1, self.kernel_dim2, self.nb_seed_filter)
            self.A_shape = (stack_size, self.nb_seed_filter, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        self.S = self.add_weight(
            self.S_shape,
            initializer=functools.partial(self.init, dim_ordering=self.dim_ordering),
            name='{}_S'.format(self.name),
            regularizer=self.S_regularizer,
            constraint=self.S_constraint
        )
        self.A = self.add_weight(
            self.A_shape,
            initializer=functools.partial(self.init),
            name='{}_A'.format(self.name),
            regularizer=self.A_regularizer,
            constraint=self.A_constraint
        )
        if self.bias:
            self.b = self.add_weight(
                (self.nb_filter,), initializer='zero',
                name='{}_b'.format(self.name),
                regularizer=self.b_regularizer,
                constraint=self.b_constraint
            )
        else:
            self.b = None

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def get_output_shape_for(self, input_shape):
        """Return the output shape tuple for the given 4-D `input_shape`."""
        if self.dim_ordering == 'th':
            conv_dim1 = input_shape[2]
            conv_dim2 = input_shape[3]
        elif self.dim_ordering == 'tf':
            conv_dim1 = input_shape[1]
            conv_dim2 = input_shape[2]
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        conv_dim1 = conv_output_length(conv_dim1, self.kernel_dim1, self.border_mode, self.subsample[0])
        conv_dim2 = conv_output_length(conv_dim2, self.kernel_dim2, self.border_mode, self.subsample[1])

        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, conv_dim1, conv_dim2)
        elif self.dim_ordering == 'tf':
            return (input_shape[0], conv_dim1, conv_dim2, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

    def call(self, x, mask=None):
        """Rebuild the kernel from `A` and `S`, convolve, add bias, activate.

        `mask` is accepted but not used.
        """
        # NOTE(review): this relies on K.dot contracting the last axis of its
        # first argument with the second-to-last axis of its second, so that
        # the product has shape W_shape. Confirm for the backend in use.
        if self.dim_ordering == 'th':
            W = K.dot(self.A, self.S)
        elif self.dim_ordering == 'tf':
            W = K.dot(self.S, self.A)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        output = K.conv2d(
            x, W, strides=self.subsample,
            border_mode=self.border_mode,
            dim_ordering=self.dim_ordering,
            filter_shape=self.W_shape
        )
        if self.bias:
            # Reshape the bias so it broadcasts along the filter axis only.
            if self.dim_ordering == 'th':
                output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
            elif self.dim_ordering == 'tf':
                output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
            else:
                raise ValueError('Invalid dim_ordering:', self.dim_ordering)
        output = self.activation(output)
        return output

    def get_config(self):
        """Return the constructor arguments merged over the base layer config."""
        config = {
            'nb_filter': self.nb_filter,
            'nb_seed_filter': self.nb_seed_filter,
            'kernel_dim1': self.kernel_dim1,
            'kernel_dim2': self.kernel_dim2,
            'dim_ordering': self.dim_ordering,
            'init': self.init.__name__,
            'activation': self.activation.__name__,
            'border_mode': self.border_mode,
            'subsample': self.subsample,
            'S_regularizer': self.S_regularizer.get_config() if self.S_regularizer else None,
            'A_regularizer': self.A_regularizer.get_config() if self.A_regularizer else None,
            'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
            'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
            'S_constraint': self.S_constraint.get_config() if self.S_constraint else None,
            'A_constraint': self.A_constraint.get_config() if self.A_constraint else None,
            'b_constraint': self.b_constraint.get_config() if self.b_constraint else None,
            'bias': self.bias
        }
        base_config = super(SharedFilterConvolution2D, self).get_config()
        # Entries from this class win over same-named base entries.
        return dict(list(base_config.items()) + list(config.items()))
class SharedFilterConvolution3D(Layer):
    """3D convolution whose kernels are mixed from a shared bank of seed filters.

    The layer stores `S` (`nb_seed_filter` volumetric seed filters of size
    `kernel_dim1 x kernel_dim2 x kernel_dim3`) and `A` (one coefficient
    vector of length `nb_seed_filter` per input-channel/output-filter pair).
    The kernel given to `K.conv3d` is the `K.dot` product of the two,
    recomputed on every call.

    # Arguments
        nb_filter: Number of output filters.
        nb_seed_filter: Number of shared seed filters stored in `S`.
        kernel_dim1, kernel_dim2, kernel_dim3: Kernel size along each of
            the three convolved axes.
        init: Initializer identifier (via `initializations.get`), used for
            both `S` and `A`.
        activation: Activation identifier (via `activations.get`).
        weights: Optional initial weights, applied with `set_weights` at
            the end of `build`.
        border_mode: One of 'valid', 'same' or 'full'.
        subsample: Strides, one entry per convolved axis.
        dim_ordering: 'th' (channels at axis 1), 'tf' (channels at axis 4)
            or 'default' to use `K.image_dim_ordering()`.
        S_regularizer, A_regularizer, b_regularizer, activity_regularizer:
            Regularizer identifiers (via `regularizers.get`).
        S_constraint, A_constraint, b_constraint: Constraint identifiers
            (via `constraints.get`).
        bias: Whether to add a per-filter bias to the output.

    # Raises
        ValueError: If `border_mode` or `dim_ordering` is not recognised.
    """

    def __init__(
            self, nb_filter, nb_seed_filter, kernel_dim1, kernel_dim2, kernel_dim3,
            init='glorot_uniform', activation=None, weights=None,
            border_mode='valid', subsample=(1, 1, 1), dim_ordering='default',
            S_regularizer=None, A_regularizer=None, b_regularizer=None, activity_regularizer=None,
            S_constraint=None, A_constraint=None, b_constraint=None,
            bias=True, **kwargs):
        # Resolve the backend default before any validation takes place.
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        if border_mode not in {'valid', 'same', 'full'}:
            raise ValueError('Invalid border mode for SharedFilterConvolution3D:', border_mode)

        self.nb_filter, self.nb_seed_filter = nb_filter, nb_seed_filter
        self.kernel_dim1, self.kernel_dim2, self.kernel_dim3 = (
            kernel_dim1, kernel_dim2, kernel_dim3)

        self.init = initializations.get(init)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)

        if dim_ordering not in {'tf', 'th'}:
            raise ValueError('dim_ordering must be in {tf, th}.')
        self.dim_ordering = dim_ordering

        # Resolve regularizer and constraint identifiers into objects.
        for attr, spec in (
                ('S_regularizer', S_regularizer),
                ('A_regularizer', A_regularizer),
                ('b_regularizer', b_regularizer),
                ('activity_regularizer', activity_regularizer)):
            setattr(self, attr, regularizers.get(spec))
        for attr, spec in (
                ('S_constraint', S_constraint),
                ('A_constraint', A_constraint),
                ('b_constraint', b_constraint)):
            setattr(self, attr, constraints.get(spec))

        self.bias = bias
        self.input_spec = [InputSpec(ndim=5)]
        # Kept only until build() has applied them.
        self.initial_weights = weights
        super(SharedFilterConvolution3D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the `S`, `A` and optional `b` weights for a 5-D input.

        Also records `W_shape`, which `call` passes to `K.conv3d` as
        `filter_shape`.
        """
        assert len(input_shape) == 5
        ordering = self.dim_ordering
        if ordering not in ('th', 'tf'):
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        k1, k2, k3 = self.kernel_dim1, self.kernel_dim2, self.kernel_dim3
        n_out, n_seed = self.nb_filter, self.nb_seed_filter
        if ordering == 'th':
            # Channels-first input: channel count directly follows the batch axis.
            n_in = input_shape[1]
            self.W_shape = (n_out, n_in, k1, k2, k3)
            # The seed axis is second-to-last so that dot(A, S) contracts over it.
            self.S_shape = (k1, k2, n_seed, k3)
            self.A_shape = (n_out, n_in, n_seed)
        else:
            # Channels-last input: the channel count is the final (fifth) axis.
            n_in = input_shape[4]
            self.W_shape = (k1, k2, k3, n_in, n_out)
            self.S_shape = (k1, k2, k3, n_seed)
            self.A_shape = (n_in, n_seed, n_out)

        seed_init = functools.partial(self.init, dim_ordering=self.dim_ordering)
        self.S = self.add_weight(
            self.S_shape,
            initializer=seed_init,
            name='{}_S'.format(self.name),
            regularizer=self.S_regularizer,
            constraint=self.S_constraint
        )
        mix_init = functools.partial(self.init)
        self.A = self.add_weight(
            self.A_shape,
            initializer=mix_init,
            name='{}_A'.format(self.name),
            regularizer=self.A_regularizer,
            constraint=self.A_constraint
        )

        if not self.bias:
            self.b = None
        else:
            self.b = self.add_weight(
                (self.nb_filter,), initializer='zero',
                name='{}_b'.format(self.name),
                regularizer=self.b_regularizer,
                constraint=self.b_constraint
            )

        preset = self.initial_weights
        if preset is not None:
            self.set_weights(preset)
            del self.initial_weights
        self.built = True

    def get_output_shape_for(self, input_shape):
        """Return the output shape tuple for the given 5-D `input_shape`."""
        if self.dim_ordering == 'th':
            first_spatial = 2
        elif self.dim_ordering == 'tf':
            first_spatial = 1
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        # The three convolved axes are contiguous, starting at first_spatial.
        in_dims = [input_shape[first_spatial + axis] for axis in range(3)]
        kernel = (self.kernel_dim1, self.kernel_dim2, self.kernel_dim3)
        out_dims = tuple(
            conv_output_length(in_dims[axis], kernel[axis],
                               self.border_mode, self.subsample[axis])
            for axis in range(3)
        )

        if first_spatial == 2:
            return (input_shape[0], self.nb_filter) + out_dims
        return (input_shape[0],) + out_dims + (self.nb_filter,)

    def call(self, x, mask=None):
        """Rebuild the kernel from `A` and `S`, convolve, add bias, activate.

        `mask` is accepted but not used.
        """
        # The operand order of the dot product and the broadcast shape of
        # the bias both depend on where the channel axis sits.
        if self.dim_ordering == 'th':
            W = K.dot(self.A, self.S)
            bias_shape = (1, self.nb_filter, 1, 1, 1)
        elif self.dim_ordering == 'tf':
            W = K.dot(self.S, self.A)
            bias_shape = (1, 1, 1, 1, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        output = K.conv3d(
            x, W, strides=self.subsample,
            border_mode=self.border_mode,
            dim_ordering=self.dim_ordering,
            filter_shape=self.W_shape
        )
        if self.bias:
            output += K.reshape(self.b, bias_shape)
        return self.activation(output)

    def get_config(self):
        """Return the constructor arguments merged over the base layer config."""
        def describe(component):
            # Unset regularizers and constraints serialize as None.
            return component.get_config() if component else None

        config = {
            'nb_filter': self.nb_filter,
            'nb_seed_filter': self.nb_seed_filter,
            'kernel_dim1': self.kernel_dim1,
            'kernel_dim2': self.kernel_dim2,
            'kernel_dim3': self.kernel_dim3,
            'dim_ordering': self.dim_ordering,
            'init': self.init.__name__,
            'activation': self.activation.__name__,
            'border_mode': self.border_mode,
            'subsample': self.subsample,
        }
        for key in ('S_regularizer', 'A_regularizer', 'b_regularizer',
                    'activity_regularizer',
                    'S_constraint', 'A_constraint', 'b_constraint'):
            config[key] = describe(getattr(self, key))
        config['bias'] = self.bias

        # Start from the base config; this class's entries take precedence.
        merged = dict(super(SharedFilterConvolution3D, self).get_config())
        merged.update(config)
        return merged
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment