Last active
February 19, 2017 20:18
-
-
Save valtron/8f09f51bac33fcfe824e3a8dd7682009 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
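Filter-sharing convolution layers: https://arxiv.org/abs/1612.02575

Excluding the bias, each layer has M*N*P + P*S parameters (vs. M*N*S for a
regular convolution), where M = nb_filter, N = number of input channels,
P = nb_seed_filter and S = number of weights in one spatial kernel
(kernel_dim1 * kernel_dim2, times kernel_dim3 in the 3D layer).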
""" | |
import functools | |
from keras import backend as K | |
from keras import activations, initializations, regularizers, constraints | |
from keras.engine import Layer, InputSpec | |
from keras.utils.np_utils import conv_output_length, conv_input_length | |
class SharedFilterConvolution2D(Layer):
    """2D convolution whose kernel is assembled from shared seed filters.

    The layer does not learn the convolution kernel ``W`` directly. It learns
    two smaller tensors instead:

    * ``S`` -- ``nb_seed_filter`` spatial seed filters of size
      ``kernel_dim1 x kernel_dim2``;
    * ``A`` -- mixing coefficients with one entry per
      (input channel, seed filter, output filter) combination.

    On every call the two are combined with ``K.dot`` into a kernel that is
    handed to ``K.conv2d`` together with ``W_shape`` as its ``filter_shape``.

    # Arguments
        nb_filter: Number of output filters.
        nb_seed_filter: Number of shared seed filters.
        kernel_dim1: Kernel size along the first convolved axis.
        kernel_dim2: Kernel size along the second convolved axis.
        init: Initialization identifier resolved through
            ``initializations.get``; used for both ``S`` and ``A``.
        activation: Activation identifier resolved through
            ``activations.get``; applied to the layer output.
        weights: Optional initial weights, passed to ``set_weights`` at the
            end of ``build``.
        border_mode: ``'valid'``, ``'same'`` or ``'full'``.
        subsample: Strides, one entry per convolved axis.
        dim_ordering: ``'th'``, ``'tf'`` or ``'default'`` (which resolves to
            ``K.image_dim_ordering()``).
        S_regularizer, A_regularizer, b_regularizer: Regularizers for the
            seed filters, the mixing coefficients and the bias.
        activity_regularizer: Regularizer stored on the layer for its output.
        S_constraint, A_constraint, b_constraint: Constraints for the seed
            filters, the mixing coefficients and the bias.
        bias: Whether a per-filter bias is learned and added.

    # Input shape
        4D tensor: ``(samples, channels, dim1, dim2)`` for ``'th'`` or
        ``(samples, dim1, dim2, channels)`` for ``'tf'``.

    # Output shape
        4D tensor: ``(samples, nb_filter, new_dim1, new_dim2)`` for ``'th'``
        or ``(samples, new_dim1, new_dim2, nb_filter)`` for ``'tf'``.

    # Raises
        ValueError: If ``border_mode`` or ``dim_ordering`` is not one of the
            accepted values.
    """

    def __init__(
            self, nb_filter, nb_seed_filter, kernel_dim1, kernel_dim2,
            init='glorot_uniform', activation=None, weights=None,
            border_mode='valid', subsample=(1, 1), dim_ordering='default',
            S_regularizer=None, A_regularizer=None, b_regularizer=None, activity_regularizer=None,
            S_constraint=None, A_constraint=None, b_constraint=None,
            bias=True, **kwargs):
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        if border_mode not in {'valid', 'same', 'full'}:
            raise ValueError('Invalid border mode for SharedFilterConvolution2D:', border_mode)

        self.nb_filter = nb_filter
        self.nb_seed_filter = nb_seed_filter
        self.kernel_dim1 = kernel_dim1
        self.kernel_dim2 = kernel_dim2
        self.init = initializations.get(init)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)

        if dim_ordering not in {'tf', 'th'}:
            raise ValueError('dim_ordering must be in {tf, th}.')
        self.dim_ordering = dim_ordering

        self.S_regularizer = regularizers.get(S_regularizer)
        self.A_regularizer = regularizers.get(A_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.S_constraint = constraints.get(S_constraint)
        self.A_constraint = constraints.get(A_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.input_spec = [InputSpec(ndim=4)]
        # Kept until build(), where the values are loaded and the attribute removed.
        self.initial_weights = weights
        super(SharedFilterConvolution2D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the ``S``, ``A`` and optional ``b`` weights for ``input_shape``."""
        assert len(input_shape) == 4
        if self.dim_ordering == 'th':
            stack_size = input_shape[1]
            self.W_shape = (self.nb_filter, stack_size, self.kernel_dim1, self.kernel_dim2)
            # The seed-filter axis sits second to last so that K.dot(A, S) in
            # call() can pair it with the last axis of A.
            self.S_shape = (self.kernel_dim1, self.nb_seed_filter, self.kernel_dim2)
            self.A_shape = (self.nb_filter, stack_size, self.nb_seed_filter)
        elif self.dim_ordering == 'tf':
            # A 4D 'tf' input is (samples, dim1, dim2, channels), so the
            # channel count is at index 3; index 4 does not exist here.
            stack_size = input_shape[3]
            self.W_shape = (self.kernel_dim1, self.kernel_dim2, stack_size, self.nb_filter)
            self.S_shape = (self.kernel_dim1, self.kernel_dim2, self.nb_seed_filter)
            self.A_shape = (stack_size, self.nb_seed_filter, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        self.S = self.add_weight(
            self.S_shape,
            initializer=functools.partial(self.init, dim_ordering=self.dim_ordering),
            name='{}_S'.format(self.name),
            regularizer=self.S_regularizer,
            constraint=self.S_constraint
        )
        self.A = self.add_weight(
            self.A_shape,
            initializer=self.init,
            name='{}_A'.format(self.name),
            regularizer=self.A_regularizer,
            constraint=self.A_constraint
        )
        if self.bias:
            self.b = self.add_weight(
                (self.nb_filter,), initializer='zero',
                name='{}_b'.format(self.name),
                regularizer=self.b_regularizer,
                constraint=self.b_constraint
            )
        else:
            self.b = None

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def get_output_shape_for(self, input_shape):
        """Return the output shape produced for an input of ``input_shape``."""
        if self.dim_ordering == 'th':
            conv_dim1 = input_shape[2]
            conv_dim2 = input_shape[3]
        elif self.dim_ordering == 'tf':
            conv_dim1 = input_shape[1]
            conv_dim2 = input_shape[2]
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        conv_dim1 = conv_output_length(conv_dim1, self.kernel_dim1, self.border_mode, self.subsample[0])
        conv_dim2 = conv_output_length(conv_dim2, self.kernel_dim2, self.border_mode, self.subsample[1])

        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, conv_dim1, conv_dim2)
        elif self.dim_ordering == 'tf':
            return (input_shape[0], conv_dim1, conv_dim2, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

    def call(self, x, mask=None):
        """Assemble the kernel from ``A`` and ``S``, convolve, add bias, activate."""
        # The operand order follows the axis layout chosen in build() for
        # each dim_ordering.
        if self.dim_ordering == 'th':
            W = K.dot(self.A, self.S)
        elif self.dim_ordering == 'tf':
            W = K.dot(self.S, self.A)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        output = K.conv2d(
            x, W, strides=self.subsample,
            border_mode=self.border_mode,
            dim_ordering=self.dim_ordering,
            filter_shape=self.W_shape
        )
        if self.bias:
            # Reshape the bias so it lines up with the filter axis of the output.
            if self.dim_ordering == 'th':
                output += K.reshape(self.b, (1, self.nb_filter, 1, 1))
            elif self.dim_ordering == 'tf':
                output += K.reshape(self.b, (1, 1, 1, self.nb_filter))
            else:
                raise ValueError('Invalid dim_ordering:', self.dim_ordering)
        output = self.activation(output)
        return output

    def get_config(self):
        """Return the layer configuration merged over the base layer's config."""
        config = {
            'nb_filter': self.nb_filter,
            'nb_seed_filter': self.nb_seed_filter,
            'kernel_dim1': self.kernel_dim1,
            'kernel_dim2': self.kernel_dim2,
            'dim_ordering': self.dim_ordering,
            'init': self.init.__name__,
            'activation': self.activation.__name__,
            'border_mode': self.border_mode,
            'subsample': self.subsample,
            'S_regularizer': self.S_regularizer.get_config() if self.S_regularizer else None,
            'A_regularizer': self.A_regularizer.get_config() if self.A_regularizer else None,
            'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
            'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
            'S_constraint': self.S_constraint.get_config() if self.S_constraint else None,
            'A_constraint': self.A_constraint.get_config() if self.A_constraint else None,
            'b_constraint': self.b_constraint.get_config() if self.b_constraint else None,
            'bias': self.bias
        }
        base_config = super(SharedFilterConvolution2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class SharedFilterConvolution3D(Layer):
    """3D convolution whose kernel is assembled from shared seed filters.

    Two tensors are learned in place of a full kernel: ``S`` holds
    ``nb_seed_filter`` seed filters of size
    ``kernel_dim1 x kernel_dim2 x kernel_dim3``, and ``A`` holds the
    coefficients that mix them for every (input channel, output filter)
    pair. ``call`` combines them with ``K.dot`` and runs ``K.conv3d``,
    passing ``W_shape`` as the ``filter_shape``.

    # Arguments
        nb_filter: Number of output filters.
        nb_seed_filter: Number of shared seed filters.
        kernel_dim1, kernel_dim2, kernel_dim3: Kernel size along each of the
            three convolved axes.
        init: Initialization identifier (via ``initializations.get``) used
            for ``S`` and ``A``.
        activation: Activation identifier (via ``activations.get``) applied
            to the output.
        weights: Optional initial weights, given to ``set_weights`` at the
            end of ``build``.
        border_mode: ``'valid'``, ``'same'`` or ``'full'``.
        subsample: Strides, one entry per convolved axis.
        dim_ordering: ``'th'``, ``'tf'`` or ``'default'`` (resolved with
            ``K.image_dim_ordering()``).
        S_regularizer, A_regularizer, b_regularizer, activity_regularizer:
            Regularizer identifiers resolved with ``regularizers.get``.
        S_constraint, A_constraint, b_constraint: Constraint identifiers
            resolved with ``constraints.get``.
        bias: Whether a per-filter bias is learned and added.

    # Input shape
        5D tensor: ``(samples, channels, dim1, dim2, dim3)`` for ``'th'`` or
        ``(samples, dim1, dim2, dim3, channels)`` for ``'tf'``.

    # Output shape
        5D tensor with ``nb_filter`` on the channel axis and each convolved
        axis resized by ``conv_output_length``.
    """

    def __init__(
            self, nb_filter, nb_seed_filter, kernel_dim1, kernel_dim2, kernel_dim3,
            init='glorot_uniform', activation=None, weights=None,
            border_mode='valid', subsample=(1, 1, 1), dim_ordering='default',
            S_regularizer=None, A_regularizer=None, b_regularizer=None, activity_regularizer=None,
            S_constraint=None, A_constraint=None, b_constraint=None,
            bias=True, **kwargs):
        # Swap the 'default' placeholder for the backend's configured ordering.
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        if border_mode not in {'valid', 'same', 'full'}:
            raise ValueError('Invalid border mode for SharedFilterConvolution3D:', border_mode)

        # Plain hyper-parameters, stored as given.
        self.nb_filter, self.nb_seed_filter = nb_filter, nb_seed_filter
        self.kernel_dim1, self.kernel_dim2, self.kernel_dim3 = kernel_dim1, kernel_dim2, kernel_dim3
        self.border_mode = border_mode
        self.bias = bias
        # Consumed (and deleted) by build().
        self.initial_weights = weights

        # Identifiers resolved to callables / normalised containers.
        self.init = initializations.get(init)
        self.activation = activations.get(activation)
        self.subsample = tuple(subsample)

        if dim_ordering not in {'tf', 'th'}:
            raise ValueError('dim_ordering must be in {tf, th}.')
        self.dim_ordering = dim_ordering

        self.S_regularizer = regularizers.get(S_regularizer)
        self.A_regularizer = regularizers.get(A_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.S_constraint = constraints.get(S_constraint)
        self.A_constraint = constraints.get(A_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.input_spec = [InputSpec(ndim=5)]
        super(SharedFilterConvolution3D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the ``S``, ``A`` and optional ``b`` weights for ``input_shape``."""
        assert len(input_shape) == 5
        kernel = (self.kernel_dim1, self.kernel_dim2, self.kernel_dim3)

        if self.dim_ordering == 'th':
            channels = input_shape[1]
            self.W_shape = (self.nb_filter, channels) + kernel
            # Seed axis goes second to last, between kernel_dim2 and kernel_dim3.
            self.S_shape = kernel[:2] + (self.nb_seed_filter, kernel[2])
            self.A_shape = (self.nb_filter, channels, self.nb_seed_filter)
        elif self.dim_ordering == 'tf':
            channels = input_shape[4]
            self.W_shape = kernel + (channels, self.nb_filter)
            self.S_shape = kernel + (self.nb_seed_filter,)
            self.A_shape = (channels, self.nb_seed_filter, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        seed_initializer = functools.partial(self.init, dim_ordering=self.dim_ordering)
        self.S = self.add_weight(
            self.S_shape,
            initializer=seed_initializer,
            name='{}_S'.format(self.name),
            regularizer=self.S_regularizer,
            constraint=self.S_constraint
        )
        self.A = self.add_weight(
            self.A_shape,
            initializer=self.init,
            name='{}_A'.format(self.name),
            regularizer=self.A_regularizer,
            constraint=self.A_constraint
        )

        if not self.bias:
            self.b = None
        else:
            self.b = self.add_weight(
                (self.nb_filter,), initializer='zero',
                name='{}_b'.format(self.name),
                regularizer=self.b_regularizer,
                constraint=self.b_constraint
            )

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def get_output_shape_for(self, input_shape):
        """Return the output shape produced for an input of ``input_shape``."""
        if self.dim_ordering == 'th':
            first_spatial_axis = 2
        elif self.dim_ordering == 'tf':
            first_spatial_axis = 1
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        kernel = (self.kernel_dim1, self.kernel_dim2, self.kernel_dim3)
        new_dims = tuple(
            conv_output_length(
                input_shape[first_spatial_axis + axis],
                kernel[axis],
                self.border_mode,
                self.subsample[axis]
            )
            for axis in range(3)
        )

        samples = (input_shape[0],)
        if self.dim_ordering == 'th':
            return samples + (self.nb_filter,) + new_dims
        return samples + new_dims + (self.nb_filter,)

    def call(self, x, mask=None):
        """Assemble the kernel from ``A`` and ``S``, convolve, add bias, activate."""
        # Operand order and bias layout both depend on where the filter axis lives.
        if self.dim_ordering == 'th':
            kernel = K.dot(self.A, self.S)
            bias_shape = (1, self.nb_filter, 1, 1, 1)
        elif self.dim_ordering == 'tf':
            kernel = K.dot(self.S, self.A)
            bias_shape = (1, 1, 1, 1, self.nb_filter)
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

        output = K.conv3d(
            x, kernel,
            strides=self.subsample,
            border_mode=self.border_mode,
            dim_ordering=self.dim_ordering,
            filter_shape=self.W_shape
        )
        if self.bias:
            output += K.reshape(self.b, bias_shape)
        return self.activation(output)

    def get_config(self):
        """Return the layer configuration merged over the base layer's config."""
        def config_or_none(component):
            # Unset regularizers / constraints are serialised as None.
            return component.get_config() if component else None

        config = {
            'nb_filter': self.nb_filter,
            'nb_seed_filter': self.nb_seed_filter,
            'kernel_dim1': self.kernel_dim1,
            'kernel_dim2': self.kernel_dim2,
            'kernel_dim3': self.kernel_dim3,
            'dim_ordering': self.dim_ordering,
            'init': self.init.__name__,
            'activation': self.activation.__name__,
            'border_mode': self.border_mode,
            'subsample': self.subsample,
            'S_regularizer': config_or_none(self.S_regularizer),
            'A_regularizer': config_or_none(self.A_regularizer),
            'b_regularizer': config_or_none(self.b_regularizer),
            'activity_regularizer': config_or_none(self.activity_regularizer),
            'S_constraint': config_or_none(self.S_constraint),
            'A_constraint': config_or_none(self.A_constraint),
            'b_constraint': config_or_none(self.b_constraint),
            'bias': self.bias
        }
        merged = dict(super(SharedFilterConvolution3D, self).get_config())
        merged.update(config)
        return merged
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment