Skip to content

Instantly share code, notes, and snippets.

@JudoWill
Last active July 13, 2020 20:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JudoWill/32bd49c646421c9717cc0fcd0d8fe37f to your computer and use it in GitHub Desktop.
Save JudoWill/32bd49c646421c9717cc0fcd0d8fe37f to your computer and use it in GitHub Desktop.
3D version of locally connected keras layer
import numpy as np
from keras import backend as K
from keras.legacy import interfaces
import keras
from keras.layers import Layer, InputLayer, Input
import tensorflow as tf
from keras.engine.topology import Node
from keras.utils import conv_utils
class LocallyConnected3D(Layer):
"""
code based on LocallyConnected3D from keras layers:
https://github.com/keras-team/keras/blob/master/keras/layers/local.py
Locally-connected layer for 3D inputs.
The `LocallyConnected3D` layer works similarly
to the `Conv3D` layer, except that weights are unshared,
that is, a different set of filters is applied at each
different patch of the input.
# Examples
```python
# apply a 3x3x3 unshared weights convolution with 64 output filters on a 32x32x32 image
# with `data_format="channels_last"`:
model = Sequential()
model.add(LocallyConnected3D(64, (3, 3, 3), input_shape=(32, 32, 32, 1)))
# now model.output_shape == (None, 30, 30, 30, 64)
# notice that this layer will consume (30*30*30)*(3*3*3*1*64) + (30*30*30)*64 parameters
# add a 3x3x3 unshared weights convolution on top, with 32 output filters:
model.add(LocallyConnected3D(32, (3, 3, 3)))
# now model.output_shape == (None, 28, 28, 28, 32)
```
# Arguments
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
width and height of the 3D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
specifying the strides of the convolution along the width and height.
Can be a single integer to specify the same value for
all spatial dimensions.
padding: Currently only support `"valid"` (case-insensitive).
`"same"` will be supported in future.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
activation: Activation function to use
(see [activations](../activations.md)).
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix
(see [initializers](../initializers.md)).
bias_initializer: Initializer for the bias vector
(see [initializers](../initializers.md)).
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix
(see [regularizer](../regularizers.md)).
bias_regularizer: Regularizer function applied to the bias vector
(see [regularizer](../regularizers.md)).
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation").
(see [regularizer](../regularizers.md)).
kernel_constraint: Constraint function applied to the kernel matrix
(see [constraints](../constraints.md)).
bias_constraint: Constraint function applied to the bias vector
(see [constraints](../constraints.md)).
# Input shape
4D tensor with shape:
`(samples, channels, rows, cols)` if data_format='channels_first'
or 4D tensor with shape:
`(samples, rows, cols, channels)` if data_format='channels_last'.
# Output shape
4D tensor with shape:
`(samples, filters, new_rows, new_cols)` if data_format='channels_first'
or 4D tensor with shape:
`(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
`rows` and `cols` values might have changed due to padding.
"""
@interfaces.legacy_conv3d_support
def __init__(self, filters,
kernel_size,
strides=(1, 1, 1),
padding='valid',
data_format=None,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(LocallyConnected3D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(
kernel_size, 3, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
self.padding = conv_utils.normalize_padding(padding)
if self.padding != 'valid':
raise ValueError('Invalid border mode for LocallyConnected3D '
'(only "valid" is supported): ' + padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=5)
def build(self, input_shape):
if self.data_format == 'channels_last':
input_row, input_col, input_z = input_shape[1:-1]
input_filter = input_shape[4]
else:
input_row, input_col, input_z = input_shape[2:]
input_filter = input_shape[1]
if input_row is None or input_col is None:
raise ValueError('The spatial dimensions of the inputs to '
' a LocallyConnected3D layer '
'should be fully-defined, but layer received '
'the inputs shape ' + str(input_shape))
output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
self.padding, self.strides[0])
output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
self.padding, self.strides[1])
output_z = conv_utils.conv_output_length(input_z, self.kernel_size[2],
self.padding, self.strides[2])
self.output_row = output_row
self.output_col = output_col
self.output_z = output_z
self.kernel_shape = (output_row * output_col * output_z,
self.kernel_size[0] *
self.kernel_size[1] *
self.kernel_size[2] * input_filter,
self.filters)
self.kernel = self.add_weight(shape=self.kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(shape=(output_row, output_col, output_z, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
if self.data_format == 'channels_first':
self.input_spec = InputSpec(ndim=5, axes={1: input_filter})
else:
self.input_spec = InputSpec(ndim=5, axes={-1: input_filter})
self.built = True
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
rows = input_shape[2]
cols = input_shape[3]
z = input_shape[4]
elif self.data_format == 'channels_last':
rows = input_shape[1]
cols = input_shape[2]
z = input_shape[3]
rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
self.padding, self.strides[0])
cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
self.padding, self.strides[1])
z = conv_utils.conv_output_length(z, self.kernel_size[2],
self.padding, self.strides[2])
if self.data_format == 'channels_first':
return (input_shape[0], self.filters, rows, cols, z)
elif self.data_format == 'channels_last':
return (input_shape[0], rows, cols, z, self.filters)
def call(self, inputs):
output = self.local_conv3d(inputs,
self.kernel,
self.kernel_size,
self.strides,
(self.output_row, self.output_col, self.output_z),
self.data_format)
if self.use_bias:
output = K.bias_add(output, self.bias,
data_format=self.data_format)
output = self.activation(output)
return output
def get_config(self):
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'data_format': self.data_format,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer': regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(
LocallyConnected3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def local_conv3d(self, inputs, kernel, kernel_size, strides, output_shape, data_format=None):
"""Apply 3D conv with un-shared weights.
# Arguments
inputs: 4D tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
kernel: the unshared weight for convolution,
with shape (output_items, feature_dim, filters)
kernel_size: a tuple of 2 integers, specifying the
width and height of the 3D convolution window.
strides: a tuple of 2 integers, specifying the strides
of the convolution along the width and height.
output_shape: a tuple with (output_row, output_col)
data_format: the data format, channels_first or channels_last
# Returns
A 4d tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
# Raises
ValueError: if `data_format` is neither
`channels_last` or `channels_first`.
"""
if data_format is None:
data_format = K.image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
stride_row, stride_col, stride_z = strides
output_row, output_col, output_z = output_shape
kernel_shape = K.int_shape(kernel)
_, feature_dim, filters = kernel_shape
xs = []
for i in range(output_row):
for j in range(output_col):
for k in range(output_z):
slice_row = slice(i * stride_row,
i * stride_row + kernel_size[0])
slice_col = slice(j * stride_col,
j * stride_col + kernel_size[1])
slice_z = slice(k * stride_z,
k * stride_z + kernel_size[2])
if data_format == 'channels_first':
xs.append(K.reshape(inputs[:, :, slice_row, slice_col, slice_z],
(1, -1, feature_dim)))
else:
xs.append(K.reshape(inputs[:, slice_row, slice_col, slice_z, :],
(1, -1, feature_dim)))
x_aggregate = K.concatenate(xs, axis=0)
output = K.batch_dot(x_aggregate, kernel)
output = K.reshape(output,
(output_row, output_col, output_z, -1, filters))
if data_format == 'channels_first':
output = K.permute_dimensions(output, (3, 4, 0, 1, 2))
else:
output = K.permute_dimensions(output, (3, 0, 1, 2, 4))
return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment