A 3D version of the Keras locally-connected layer
import numpy as np
from keras import backend as K
from keras import activations, constraints, initializers, regularizers
from keras.legacy import interfaces
import keras
from keras.layers import Layer, InputLayer, Input
from keras.engine import InputSpec
import tensorflow as tf
from keras.engine.topology import Node
from keras.utils import conv_utils
class LocallyConnected3D(Layer):
"""
    Code based on the LocallyConnected2D layer from Keras:
    https://github.com/keras-team/keras/blob/master/keras/layers/local.py
Locally-connected layer for 3D inputs.
The `LocallyConnected3D` layer works similarly
to the `Conv3D` layer, except that weights are unshared,
that is, a different set of filters is applied at each
different patch of the input.
# Examples
```python
    # apply a 3x3x3 unshared weights convolution with 64 output filters
    # on a 32x32x32 volume with one channel, with `data_format="channels_last"`:
model = Sequential()
model.add(LocallyConnected3D(64, (3, 3, 3), input_shape=(32, 32, 32, 1)))
# now model.output_shape == (None, 30, 30, 30, 64)
# notice that this layer will consume (30*30*30)*(3*3*3*1*64) + (30*30*30)*64 parameters
# add a 3x3x3 unshared weights convolution on top, with 32 output filters:
model.add(LocallyConnected3D(32, (3, 3, 3)))
# now model.output_shape == (None, 28, 28, 28, 32)
```
# Arguments
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
        kernel_size: An integer or tuple/list of 3 integers, specifying the
            depth, height and width of the 3D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 3 integers,
            specifying the strides of the convolution along each spatial dimension.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        padding: Currently only `"valid"` is supported (case-insensitive).
            `"same"` may be supported in the future.
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, rows, cols, z, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, rows, cols, z)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
activation: Activation function to use
(see [activations](../activations.md)).
            If you don't specify anything, no activation is applied
            (i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix
(see [initializers](../initializers.md)).
bias_initializer: Initializer for the bias vector
(see [initializers](../initializers.md)).
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix
(see [regularizer](../regularizers.md)).
bias_regularizer: Regularizer function applied to the bias vector
(see [regularizer](../regularizers.md)).
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation").
(see [regularizer](../regularizers.md)).
kernel_constraint: Constraint function applied to the kernel matrix
(see [constraints](../constraints.md)).
bias_constraint: Constraint function applied to the bias vector
(see [constraints](../constraints.md)).
    # Input shape
        5D tensor with shape:
        `(samples, channels, rows, cols, z)` if data_format='channels_first'
        or 5D tensor with shape:
        `(samples, rows, cols, z, channels)` if data_format='channels_last'.
    # Output shape
        5D tensor with shape:
        `(samples, filters, new_rows, new_cols, new_z)` if data_format='channels_first'
        or 5D tensor with shape:
        `(samples, new_rows, new_cols, new_z, filters)` if data_format='channels_last'.
        `rows`, `cols` and `z` values might have changed due to padding.
"""
@interfaces.legacy_conv3d_support
def __init__(self, filters,
kernel_size,
strides=(1, 1, 1),
padding='valid',
data_format=None,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(LocallyConnected3D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(
kernel_size, 3, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
self.padding = conv_utils.normalize_padding(padding)
if self.padding != 'valid':
raise ValueError('Invalid border mode for LocallyConnected3D '
'(only "valid" is supported): ' + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=5)
def build(self, input_shape):
if self.data_format == 'channels_last':
input_row, input_col, input_z = input_shape[1:-1]
input_filter = input_shape[4]
else:
input_row, input_col, input_z = input_shape[2:]
input_filter = input_shape[1]
        if input_row is None or input_col is None or input_z is None:
            raise ValueError('The spatial dimensions of the inputs to '
                             'a LocallyConnected3D layer '
                             'should be fully-defined, but layer received '
                             'the inputs shape ' + str(input_shape))
output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
self.padding, self.strides[0])
output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
self.padding, self.strides[1])
output_z = conv_utils.conv_output_length(input_z, self.kernel_size[2],
self.padding, self.strides[2])
self.output_row = output_row
self.output_col = output_col
self.output_z = output_z
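        # One flattened kernel per output position:
        # (output positions, kernel elements * input channels, filters).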
self.kernel_shape = (output_row * output_col * output_z,
self.kernel_size[0] *
self.kernel_size[1] *
self.kernel_size[2] * input_filter,
self.filters)
self.kernel = self.add_weight(shape=self.kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(shape=(output_row, output_col, output_z, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
if self.data_format == 'channels_first':
self.input_spec = InputSpec(ndim=5, axes={1: input_filter})
else:
self.input_spec = InputSpec(ndim=5, axes={-1: input_filter})
self.built = True
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
rows = input_shape[2]
cols = input_shape[3]
z = input_shape[4]
elif self.data_format == 'channels_last':
rows = input_shape[1]
cols = input_shape[2]
z = input_shape[3]
rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
self.padding, self.strides[0])
cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
self.padding, self.strides[1])
z = conv_utils.conv_output_length(z, self.kernel_size[2],
self.padding, self.strides[2])
if self.data_format == 'channels_first':
return (input_shape[0], self.filters, rows, cols, z)
elif self.data_format == 'channels_last':
return (input_shape[0], rows, cols, z, self.filters)
def call(self, inputs):
output = self.local_conv3d(inputs,
self.kernel,
self.kernel_size,
self.strides,
(self.output_row, self.output_col, self.output_z),
self.data_format)
if self.use_bias:
output = K.bias_add(output, self.bias,
data_format=self.data_format)
output = self.activation(output)
return output
def get_config(self):
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'data_format': self.data_format,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer': regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(
LocallyConnected3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def local_conv3d(self, inputs, kernel, kernel_size, strides, output_shape, data_format=None):
"""Apply 3D conv with un-shared weights.
# Arguments
inputs: 4D tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
kernel: the unshared weight for convolution,
with shape (output_items, feature_dim, filters)
kernel_size: a tuple of 2 integers, specifying the
width and height of the 3D convolution window.
strides: a tuple of 2 integers, specifying the strides
of the convolution along the width and height.
output_shape: a tuple with (output_row, output_col)
data_format: the data format, channels_first or channels_last
# Returns
A 4d tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
# Raises
ValueError: if `data_format` is neither
`channels_last` or `channels_first`.
"""
if data_format is None:
data_format = K.image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
stride_row, stride_col, stride_z = strides
output_row, output_col, output_z = output_shape
kernel_shape = K.int_shape(kernel)
_, feature_dim, filters = kernel_shape
xs = []
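        # Gather the receptive field for every output position, flatten each
        # patch to `feature_dim`, and stack the patches along a new leading axis.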
for i in range(output_row):
for j in range(output_col):
for k in range(output_z):
slice_row = slice(i * stride_row,
i * stride_row + kernel_size[0])
slice_col = slice(j * stride_col,
j * stride_col + kernel_size[1])
slice_z = slice(k * stride_z,
k * stride_z + kernel_size[2])
if data_format == 'channels_first':
xs.append(K.reshape(inputs[:, :, slice_row, slice_col, slice_z],
(1, -1, feature_dim)))
else:
xs.append(K.reshape(inputs[:, slice_row, slice_col, slice_z, :],
(1, -1, feature_dim)))
x_aggregate = K.concatenate(xs, axis=0)
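        # One independent matrix multiply per output position:
        # (positions, batch, feature_dim) . (positions, feature_dim, filters)
        # -> (positions, batch, filters).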
output = K.batch_dot(x_aggregate, kernel)
output = K.reshape(output,
(output_row, output_col, output_z, -1, filters))
if data_format == 'channels_first':
output = K.permute_dimensions(output, (3, 4, 0, 1, 2))
else:
output = K.permute_dimensions(output, (3, 0, 1, 2, 4))
return output
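
if __name__ == '__main__':
    # Minimal usage sketch (assumes a TensorFlow backend and the Keras
    # Sequential API); mirrors the example from the class docstring.
    from keras.models import Sequential

    model = Sequential()
    # 3x3x3 unshared-weights convolution with 64 filters on a 32x32x32 volume
    # with a single channel (`data_format="channels_last"`).
    model.add(LocallyConnected3D(64, (3, 3, 3), input_shape=(32, 32, 32, 1)))
    # Each of the 30*30*30 output positions has its own 3*3*3*1*64 kernel
    # weights plus 64 biases.
    print(model.output_shape)    # (None, 30, 30, 30, 64)
    print(model.count_params())  # (30**3) * (3*3*3*1*64 + 64) = 48384000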