Created
November 15, 2018 08:11
-
-
Save jimmy15923/9c05b2064bc6de462d21df6285164026 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import six | |
def _bn_relu(input):
    """Apply BatchNormalization over the channel axis followed by a ReLU.

    Adapted from @raghakot's keras-resnet. Reads the module-level
    ``CHANNEL_AXIS`` global, which ``_handle_data_format`` must have set.

    Args:
        input: Keras tensor to normalize and activate.

    Returns:
        The activated Keras tensor.
    """
    normalized = tf.keras.layers.BatchNormalization(axis=CHANNEL_AXIS)(input)
    relu = tf.keras.layers.Activation("relu")
    return relu(normalized)
def _conv_bn_relu3D(**conv_params):
    """Build a Conv3D -> BatchNormalization -> ReLU unit.

    Keyword args:
        filters: Number of convolution filters (required).
        kernel_size: Convolution kernel size (required).
        strides: Convolution strides; defaults to (1, 1, 1).
        kernel_initializer: Defaults to "he_normal".
        padding: Defaults to "same".
        kernel_regularizer: Defaults to an L2 penalty of 1e-8.

    Any other keyword arguments are ignored.

    Returns:
        A function that applies the unit to a Keras tensor.
    """
    filters = conv_params["filters"]
    kernel_size = conv_params["kernel_size"]
    conv_kwargs = dict(
        filters=filters,
        kernel_size=kernel_size,
        strides=conv_params.get("strides", (1, 1, 1)),
        kernel_initializer=conv_params.get("kernel_initializer", "he_normal"),
        padding=conv_params.get("padding", "same"),
        kernel_regularizer=conv_params.get(
            "kernel_regularizer", tf.keras.regularizers.l2(1e-8)),
    )

    def f(input):
        convolved = tf.keras.layers.Conv3D(**conv_kwargs)(input)
        return _bn_relu(convolved)

    return f
def _static_dim(tensor, axis):
    """Return the static size of ``tensor`` along ``axis`` as a plain value.

    TF1 shapes hold ``Dimension`` objects that expose the size as ``.value``.
    TF2 shapes hold ints directly, which have no ``.value`` attribute.
    """
    dim = tensor.shape[axis]
    return getattr(dim, "value", dim)


def _ceil_div(numerator, denominator):
    """Integer ceiling division for non-negative ints (positive denominator)."""
    return -(-numerator // denominator)


def _shortcut3d(input, residual):
    """Merge ``input`` and ``residual`` by element-wise sum, projecting if needed.

    If the residual branch changed any spatial size or the channel count,
    ``input`` first goes through a 1x1x1 Conv3D ("valid" padding, he_normal
    init, fixed L2 factor 1e-4). Its strides and filter count are chosen so
    its output shape matches ``residual``.

    Reads the module-level ``DIM1_AXIS``/``DIM2_AXIS``/``DIM3_AXIS``/
    ``CHANNEL_AXIS`` globals set by ``_handle_data_format``. The spatial and
    channel sizes of both tensors must be statically known.

    Args:
        input: Keras tensor entering the residual unit.
        residual: Keras tensor produced by the unit's convolution branch.

    Returns:
        ``tf.keras.layers.add([shortcut, residual])``.
    """
    spatial_axes = (DIM1_AXIS, DIM2_AXIS, DIM3_AXIS)
    # Use ceil, not floor, division. A "same"-padded stride-2 conv maps
    # size n to ceil(n / 2), e.g. 3 -> 2. A "valid" 1x1x1 conv with stride s
    # maps n to ceil(n / s). With floor division, 3 // 2 == 1 would keep the
    # shortcut at size 3, and add() would fail for odd sizes.
    strides = tuple(
        _ceil_div(_static_dim(input, axis), _static_dim(residual, axis))
        for axis in spatial_axes
    )
    residual_channels = _static_dim(residual, CHANNEL_AXIS)
    equal_channels = residual_channels == _static_dim(input, CHANNEL_AXIS)

    shortcut = input
    if any(stride > 1 for stride in strides) or not equal_channels:
        shortcut = tf.keras.layers.Conv3D(
            filters=residual_channels,
            kernel_size=(1, 1, 1),
            strides=strides,
            kernel_initializer="he_normal",
            padding="valid",
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),
        )(input)
    return tf.keras.layers.add([shortcut, residual])
def _residual_block3d(block_function, filters, kernel_regularizer, repetitions,
                      is_first_layer=False):
    """Build one ResNet stage made of ``repetitions`` chained residual units.

    Args:
        block_function: Unit factory such as ``basic_block`` or ``bottleneck``.
            It is called with ``filters``, ``strides``, ``kernel_regularizer``
            and ``is_first_block_of_first_layer`` keyword arguments.
        filters: Filter count passed to every unit.
        kernel_regularizer: Regularizer passed to every unit.
        repetitions: Number of units to chain.
        is_first_layer: Whether this is the network's first stage. Only
            non-first stages downsample; they use strides (2, 2, 2) on their
            first unit. All other units use (1, 1, 1).

    Returns:
        A function mapping an input tensor to the stage's output tensor.
    """
    def f(input):
        output = input
        for index in range(repetitions):
            first_in_stage = index == 0
            downsample = first_in_stage and not is_first_layer
            unit = block_function(
                filters=filters,
                strides=(2, 2, 2) if downsample else (1, 1, 1),
                kernel_regularizer=kernel_regularizer,
                is_first_block_of_first_layer=(is_first_layer and first_in_stage),
            )
            output = unit(output)
        return output

    return f
def basic_block(filters, strides=(1, 1, 1), kernel_regularizer=tf.keras.regularizers.l2(1e-8),
                is_first_block_of_first_layer=False):
    """Build a basic 3D residual unit: two 3x3x3 convolutions plus a shortcut.

    Extended from raghakot's 2D implementation.

    Args:
        filters: Filter count of both 3x3x3 convolutions.
        strides: Strides of the first convolution. The second always uses
            (1, 1, 1).
        kernel_regularizer: Regularizer for both convolution kernels.
        is_first_block_of_first_layer: If true, the first convolution is a
            bare Conv3D with no BN -> ReLU after it.

    Returns:
        A function mapping an input tensor to the unit's output tensor.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            conv1 = _conv_bn_relu3D(filters=filters,
                                    kernel_size=(3, 3, 3),
                                    strides=strides,
                                    kernel_regularizer=kernel_regularizer)(input)
        else:
            # When built by Resnet3DBuilder.build, this input comes straight
            # from the stem (conv -> BN -> ReLU -> max-pool), so don't repeat
            # BN -> ReLU here.
            first_conv = tf.keras.layers.Conv3D(filters=filters,
                                                kernel_size=(3, 3, 3),
                                                strides=strides,
                                                padding="same",
                                                kernel_initializer="he_normal",
                                                kernel_regularizer=kernel_regularizer)
            conv1 = first_conv(input)
        second_conv = _conv_bn_relu3D(filters=filters,
                                      kernel_size=(3, 3, 3),
                                      kernel_regularizer=kernel_regularizer)
        return _shortcut3d(input, second_conv(conv1))

    return f
def bottleneck(filters, strides=(1, 1, 1), kernel_regularizer=tf.keras.regularizers.l2(1e-8),
               is_first_block_of_first_layer=False):
    """Build a bottleneck 3D residual unit: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions.

    The final 1x1x1 convolution expands to ``filters * 4`` channels before
    the result is merged with the (possibly projected) input by
    ``_shortcut3d``. Extended from raghakot's 2D implementation.

    Args:
        filters: Filter count of the first two convolutions. The last one
            uses ``filters * 4``.
        strides: Strides of the first 1x1x1 convolution. The others use
            (1, 1, 1).
        kernel_regularizer: Regularizer for all three convolution kernels.
        is_first_block_of_first_layer: If true, the first convolution is a
            bare Conv3D with no BN -> ReLU after it.

    Returns:
        A function mapping an input tensor to the unit's output tensor.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            reduce_unit = _conv_bn_relu3D(filters=filters,
                                          kernel_size=(1, 1, 1),
                                          strides=strides,
                                          kernel_regularizer=kernel_regularizer)
        else:
            # When built by Resnet3DBuilder.build, this input comes straight
            # from the stem (conv -> BN -> ReLU -> max-pool), so don't repeat
            # BN -> ReLU here.
            reduce_unit = tf.keras.layers.Conv3D(filters=filters,
                                                 kernel_size=(1, 1, 1),
                                                 strides=strides,
                                                 padding="same",
                                                 kernel_initializer="he_normal",
                                                 kernel_regularizer=kernel_regularizer)
        reduced = reduce_unit(input)
        spatial = _conv_bn_relu3D(filters=filters,
                                  kernel_size=(3, 3, 3),
                                  kernel_regularizer=kernel_regularizer)(reduced)
        expanded = _conv_bn_relu3D(filters=filters * 4,
                                   kernel_size=(1, 1, 1),
                                   kernel_regularizer=kernel_regularizer)(spatial)
        return _shortcut3d(input, expanded)

    return f
def _handle_data_format():
    """Set the module-level axis-index globals from Keras' image data format.

    Under 'channels_last' the three spatial axes are 1, 2, 3 and the channel
    axis is 4. Under any other format the channel axis is 1 and the spatial
    axes are 2, 3, 4. Axis 0 is left to the batch dimension in both cases.

    Writes the globals DIM1_AXIS, DIM2_AXIS, DIM3_AXIS and CHANNEL_AXIS,
    which the layer helpers in this module read.
    """
    global DIM1_AXIS, DIM2_AXIS, DIM3_AXIS, CHANNEL_AXIS
    channels_last = tf.keras.backend.image_data_format() == 'channels_last'
    if channels_last:
        DIM1_AXIS, DIM2_AXIS, DIM3_AXIS, CHANNEL_AXIS = 1, 2, 3, 4
    else:
        CHANNEL_AXIS, DIM1_AXIS, DIM2_AXIS, DIM3_AXIS = 1, 2, 3, 4
def _get_block(identifier):
    """Resolve ``identifier`` to a residual-unit factory.

    A string is looked up among this module's globals (e.g. "basic_block"
    or "bottleneck"). Any other value is returned unchanged and is assumed
    to already be a unit factory.

    Raises:
        ValueError: if a string does not name a callable defined in this
            module.
    """
    if isinstance(identifier, six.string_types):
        res = globals().get(identifier)
        # Reject names that resolve to non-callables, such as the "np"
        # module or an axis constant. Otherwise they would fail later, deep
        # inside model building, with an obscure TypeError.
        if not callable(res):
            raise ValueError('Invalid {}'.format(identifier))
        return res
    return identifier
class Resnet3DBuilder(object):
    """Factory for 3D ResNet models built with ``tf.keras``.

    Every builder returns an uncompiled ``tf.keras.models.Model``.
    """

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions, reg_factor):
        """Instantiate a vanilla ResNet3D keras model.

        # Arguments
            input_shape: Tuple of input shape, without the batch axis, in the
                format (conv_dim1, conv_dim2, conv_dim3, channels) when the
                Keras image data format is 'channels_last', or
                (channels, conv_dim1, conv_dim2, conv_dim3) otherwise.
            num_outputs: Number of units in the final Dense layer. With more
                than one output it uses softmax; otherwise it uses sigmoid.
            block_fn: Unit block to use: ``basic_block`` or ``bottleneck``,
                either as the function itself or as its name string.
            repetitions: Sequence giving the number of unit blocks per stage.
                Stages start at 64 filters and double the count each stage.
                Every stage after the first begins with a stride-2 unit.
            reg_factor: L2 factor for the stem, residual-unit and output
                kernels. Projection shortcuts use their own fixed factor.

        # Returns
            model: a 3D ResNet model that takes a 5D tensor (volumetric images
                in batch) as input and returns a 1D vector (prediction) as output.

        # Raises
            ValueError: if ``input_shape`` does not have exactly 4 entries, or
                ``block_fn`` is a string that names no block in this module.
        """
        _handle_data_format()
        if len(input_shape) != 4:
            raise ValueError("Input shape should be a tuple "
                             "(conv_dim1, conv_dim2, conv_dim3, channels) "
                             "for tensorflow as backend or "
                             "(channels, conv_dim1, conv_dim2, conv_dim3) "
                             "for theano as backend")

        block_fn = _get_block(block_fn)

        inputs = tf.keras.layers.Input(shape=input_shape)
        # Stem: conv -> BN -> ReLU, then a stride-2 max-pool.
        conv1 = _conv_bn_relu3D(filters=32, kernel_size=(3, 3, 3), padding="same",
                                strides=(1, 1, 1),
                                kernel_regularizer=tf.keras.regularizers.l2(reg_factor)
                                )(inputs)
        pool1 = tf.keras.layers.MaxPooling3D(strides=(2, 2, 2),
                                             padding="same")(conv1)

        # Residual stages.
        block = pool1
        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block3d(block_fn, filters=filters,
                                      kernel_regularizer=tf.keras.regularizers.l2(reg_factor),
                                      repetitions=r, is_first_layer=(i == 0)
                                      )(block)
            filters *= 2

        # Last activation.
        block_output = _bn_relu(block)

        # Average over all spatial positions, then classify. This equals a
        # "valid" AveragePooling3D over the full spatial extent followed by
        # Flatten, but it does not need static spatial sizes.
        pooled = tf.keras.layers.GlobalAveragePooling3D()(block_output)
        activation = "softmax" if num_outputs > 1 else "sigmoid"
        dense = tf.keras.layers.Dense(units=num_outputs,
                                      kernel_initializer="he_normal",
                                      activation=activation,
                                      kernel_regularizer=tf.keras.regularizers.l2(reg_factor)
                                      )(pooled)

        return tf.keras.models.Model(inputs=inputs, outputs=dense)

    @staticmethod
    def build_resnet_18(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 18 (basic blocks, stages [2, 2, 2, 2])."""
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
                                     [2, 2, 2, 2], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_34(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 34 (basic blocks, stages [3, 4, 6, 3])."""
        return Resnet3DBuilder.build(input_shape, num_outputs, basic_block,
                                     [3, 4, 6, 3], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_50(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 50 (bottleneck blocks, stages [3, 4, 6, 3])."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 4, 6, 3], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_101(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 101 (bottleneck blocks, stages [3, 4, 23, 3])."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 4, 23, 3], reg_factor=reg_factor)

    @staticmethod
    def build_resnet_152(input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 152 (bottleneck blocks, stages [3, 8, 36, 3])."""
        return Resnet3DBuilder.build(input_shape, num_outputs, bottleneck,
                                     [3, 8, 36, 3], reg_factor=reg_factor)
if __name__ == "__main__":
    # Generate fake data: binary voxels and random binary labels.
    # np.random.randint excludes `high`, so (0, 2) is needed to draw from
    # {0, 1}; the previous (0, 1) produced an all-zero volume.
    # Drawing as uint8 and then casting to float32 (Keras' default float
    # type) avoids allocating ~8 GB in the platform-default integer dtype.
    x = np.random.randint(0, 2, size=(1000, 256, 256, 16, 1),
                          dtype=np.uint8).astype(np.float32)
    y = np.random.choice([0, 1], size=(1000,))
    y = tf.keras.utils.to_categorical(y, 2)

    # Build and compile the model.
    res_model = Resnet3DBuilder.build_resnet_18(input_shape=(256, 256, 16, 1), num_outputs=2)
    res_model.compile(tf.keras.optimizers.Adam(), loss="categorical_crossentropy")
    res_model.summary()

    # Train the 3D ResNet model.
    res_model.fit(x, y, batch_size=16, epochs=10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment