Created January 19, 2017
V-Net in Keras and TensorFlow
"import numpy\n",
"import warnings\n",
"from keras.layers import Convolution3D, Input, merge, RepeatVector, Activation\n",
"from keras.models import Model\n",
"from keras.layers.advanced_activations import PReLU\n",
"from keras import activations, initializations, regularizers\n",
"from keras.engine import Layer, InputSpec\n",
"from keras.utils.np_utils import conv_output_length\n",
"from keras.optimizers import Adam\n",
"from keras.callbacks import ModelCheckpoint\n",
"import keras.backend as K\n",
"from keras.engine.topology import Layer\n",
"import functools\n",
"import tensorflow as tf\n",
"import pickle\n",
"import time"
"# NOTE: pickle can execute arbitrary code while loading; only open trusted files.\n",
"with open('../data/PROMISE2012/train_data.p3', 'rb') as f:\n",
"    X, y = pickle.load(f)\n",
"\n",
"# Append a trailing channel axis to the image volumes.\n",
"X = X.reshape(X.shape + (1,)).astype(numpy.float32)\n",
"# Cast to bool before negating: `~` is a logical NOT only on boolean arrays.\n",
"# On integer masks it is a bitwise NOT (0 -> -1, 1 -> -2); on floats it raises.\n",
"# Any non-zero label is treated as foreground.\n",
"y = y.reshape(y.shape + (1,)).astype(bool)\n",
"# Channel 0 holds the mask, channel 1 its complement.\n",
"y = numpy.concatenate([y, ~y], axis=4)\n",
"class Deconvolution3D(Layer):\n",
"    \"\"\"Transposed 3-D convolution with a per-filter bias (TensorFlow backend only).\n",
"\n",
"    Wraps ``tf.nn.conv3d_transpose`` for channels-last 5-D tensors.\n",
"\n",
"    Args:\n",
"        nb_filter: number of output channels.\n",
"        kernel_dims: 3-tuple with the spatial size of the kernel.\n",
"        output_shape: 5-tuple ``(batch, dim1, dim2, dim3, channels)`` of the result.\n",
"            The batch entry is ignored; the batch size is read from the input.\n",
"        subsample: 3-tuple of strides along the three spatial axes.\n",
"        **kwargs: forwarded to ``Layer`` (for example ``name``).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nb_filter, kernel_dims, output_shape, subsample, **kwargs):\n",
"        assert K.backend() == 'tensorflow', 'Deconvolution3D requires the TensorFlow backend'\n",
"        self.nb_filter = nb_filter\n",
"        # Coerce to tuples: a config restored from JSON carries lists, which\n",
"        # cannot be concatenated with the tuples used below.\n",
"        self.kernel_dims = tuple(kernel_dims)\n",
"        self.subsample = tuple(subsample)\n",
"        self.strides = (1,) + self.subsample + (1,)\n",
"        self.output_shape_ = tuple(output_shape)\n",
"        super(Deconvolution3D, self).__init__(**kwargs)\n",
"\n",
"    def build(self, input_shape):\n",
"        assert len(input_shape) == 5, 'Deconvolution3D expects a 5-D input'\n",
"        self.input_shape_ = input_shape\n",
"        # conv3d_transpose filters are laid out as kernel dims + (out, in).\n",
"        W_shape = self.kernel_dims + (self.nb_filter, input_shape[4])\n",
"        self.W = self.add_weight(W_shape,\n",
"                                 initializer=functools.partial(initializations.glorot_uniform,\n",
"                                                               dim_ordering='tf'),\n",
"                                 name='{}_W'.format(self.name))\n",
"        # Shaped (1, 1, 1, nb_filter) so it broadcasts over batch and space.\n",
"        self.b = self.add_weight((1, 1, 1, self.nb_filter),\n",
"                                 initializer='zero',\n",
"                                 name='{}_b'.format(self.name))\n",
"        self.built = True\n",
"\n",
"    def get_output_shape_for(self, input_shape):\n",
"        return (None,) + self.output_shape_[1:]\n",
"\n",
"    def call(self, x, mask=None):\n",
"        # Use the run-time batch size rather than the fixed first entry of\n",
"        # output_shape_, so the layer is not tied to one batch size.\n",
"        trailing_dims = list(self.output_shape_[1:])\n",
"        dynamic_shape = tf.stack([tf.shape(x)[0]] + trailing_dims)\n",
"        out = tf.nn.conv3d_transpose(x, self.W, output_shape=dynamic_shape,\n",
"                                     strides=self.strides, padding='SAME')\n",
"        # A dynamic output_shape can hide the static dimensions; restore them\n",
"        # for code that reads the tensor's static shape afterwards.\n",
"        out.set_shape([None] + trailing_dims)\n",
"        return out + self.b\n",
"\n",
"    def get_config(self):\n",
"        base_config = super(Deconvolution3D, self).get_config().copy()\n",
"        # Store every constructor argument so from_config() can rebuild the layer.\n",
"        base_config['nb_filter'] = self.nb_filter\n",
"        base_config['kernel_dims'] = self.kernel_dims\n",
"        base_config['output_shape'] = self.output_shape_\n",
"        base_config['subsample'] = self.subsample\n",
"        return base_config"
"from keras import backend as K\n",
"from keras.engine import Layer\n",
"class Softmax(Layer):\n",
"    \"\"\"Softmax activation along a single, configurable axis.\n",
"\n",
"    Args:\n",
"        axis: axis to normalise over (defaults to the last one).\n",
"        **kwargs: forwarded to ``Layer``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, axis=-1, **kwargs):\n",
"        self.axis = axis\n",
"        super(Softmax, self).__init__(**kwargs)\n",
"\n",
"    def build(self, input_shape):\n",
"        # Nothing to create: the layer has no weights.\n",
"        pass\n",
"\n",
"    def call(self, x, mask=None):\n",
"        # Subtracting the maximum does not change the result but keeps exp()\n",
"        # from overflowing on large activations.\n",
"        e = K.exp(x - K.max(x, axis=self.axis, keepdims=True))\n",
"        s = K.sum(e, axis=self.axis, keepdims=True)\n",
"        return e / s\n",
"\n",
"    def get_output_shape_for(self, input_shape):\n",
"        return input_shape\n",
"\n",
"    def get_config(self):\n",
"        # Persist `axis` so a reloaded model normalises over the same axis.\n",
"        config = super(Softmax, self).get_config().copy()\n",
"        config['axis'] = self.axis\n",
"        return config"
"def downward_layer(input_layer, n_convolutions, n_output_channels):\n",
"    \"\"\"Residual convolution stage followed by a stride-2 down-sampling convolution.\n",
"\n",
"    Args:\n",
"        input_layer: channels-last 5-D tensor; it is summed with the stage's last\n",
"            convolution, so it must already have n_output_channels // 2 channels.\n",
"        n_convolutions: number of 5x5x5 convolutions in the stage; all but the\n",
"            last are followed by a PReLU.\n",
"        n_output_channels: channel count after down-sampling.\n",
"\n",
"    Returns:\n",
"        Tuple ``(downsampled, residual)``: the PReLU-activated down-sampled\n",
"        tensor and the residual sum computed before down-sampling.\n",
"    \"\"\"\n",
"    n_stage_channels = n_output_channels // 2\n",
"    inl = input_layer\n",
"    for _ in range(n_convolutions - 1):\n",
"        conv = Convolution3D(n_stage_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
"        inl = PReLU()(conv)\n",
"    conv = Convolution3D(n_stage_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
"    add = merge([conv, input_layer], mode='sum')\n",
"    # dim_ordering is explicit here as on every other convolution, so the layer\n",
"    # does not depend on the global Keras image_dim_ordering setting.\n",
"    downsample = Convolution3D(n_output_channels, 2, 2, 2, subsample=(2, 2, 2), dim_ordering='tf')(add)\n",
"    prelu = PReLU()(downsample)\n",
"    return prelu, add"
"def upward_layer(input0, input1, n_convolutions, n_output_channels):\n",
"    \"\"\"Concatenate two tensors, refine them residually and up-sample by two.\n",
"\n",
"    Args:\n",
"        input0: first tensor of the channel-axis concatenation.\n",
"        input1: second tensor of the channel-axis concatenation.\n",
"        n_convolutions: number of 5x5x5 convolutions; all but the last are\n",
"            followed by a PReLU.\n",
"        n_output_channels: channel count after up-sampling. The convolutions\n",
"            use four times as many filters.\n",
"\n",
"    Returns:\n",
"        The PReLU-activated, up-sampled tensor.\n",
"    \"\"\"\n",
"    n_stage_channels = n_output_channels * 4\n",
"    skip_concat = merge([input0, input1], mode='concat', concat_axis=4)\n",
"\n",
"    features = skip_concat\n",
"    for _ in range(n_convolutions - 1):\n",
"        conv_out = Convolution3D(n_stage_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(features)\n",
"        features = PReLU()(conv_out)\n",
"    last_conv = Convolution3D(n_stage_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(features)\n",
"    residual = merge([last_conv, skip_concat], mode='sum')\n",
"\n",
"    # Double each of the three spatial dimensions for the transposed convolution.\n",
"    static_dims = residual.get_shape().as_list()\n",
"    doubled = tuple(dim * 2 for dim in static_dims[1:4])\n",
"    target_shape = (1,) + doubled + (n_output_channels,)\n",
"    upsampled = Deconvolution3D(n_output_channels, (4, 4, 4), target_shape, subsample=(2, 2, 2))(residual)\n",
"    return PReLU()(upsampled)"
"# Layer 1\n",
"input_layer = Input(shape=(128, 128, 64, 1), name='data')\n",
"conv_1 = Convolution3D(16, 5, 5, 5, border_mode='same', dim_ordering='tf')(input_layer)\n",
"# Tile the single input channel 16 times so it can be summed with conv_1.\n",
"repeat_1 = merge([input_layer] * 16, mode='concat')\n",
"add_1 = merge([conv_1, repeat_1], mode='sum')\n",
"prelu_1_1 = PReLU()(add_1)\n",
"# dim_ordering is explicit here as on the other convolutions, so the result\n",
"# does not depend on the global Keras image_dim_ordering setting.\n",
"downsample_1 = Convolution3D(32, 2, 2, 2, subsample=(2, 2, 2), dim_ordering='tf')(prelu_1_1)\n",
"prelu_1_2 = PReLU()(downsample_1)\n",
"\n",
"# Layer 2,3,4\n",
"out2, left2 = downward_layer(prelu_1_2, 2, 64)\n",
"out3, left3 = downward_layer(out2, 2, 128)\n",
"out4, left4 = downward_layer(out3, 2, 256)\n",
"\n",
"# Layer 5\n",
"conv_5_1 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(out4)\n",
"prelu_5_1 = PReLU()(conv_5_1)\n",
"conv_5_2 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_1)\n",
"prelu_5_2 = PReLU()(conv_5_2)\n",
"conv_5_3 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_2)\n",
"add_5 = merge([conv_5_3, out4], mode='sum')\n",
"# NOTE: prelu_5_1 and prelu_5_2 are rebound below; the earlier tensors have\n",
"# already been consumed, so the graph is unaffected.\n",
"prelu_5_1 = PReLU()(add_5)\n",
"# Despite its name this is a transposed convolution that up-samples.\n",
"downsample_5 = Deconvolution3D(128, (2, 2, 2), (1, 16, 16, 8, 128), subsample=(2, 2, 2))(prelu_5_1)\n",
"prelu_5_2 = PReLU()(downsample_5)\n",
"\n",
"#Layer 6,7,8\n",
"out6 = upward_layer(prelu_5_2, left4, 3, 64)\n",
"out7 = upward_layer(out6, left3, 3, 32)\n",
"out8 = upward_layer(out7, left2, 2, 16)\n",
"\n",
"#Layer 9\n",
"merged_9 = merge([out8, add_1], mode='concat', concat_axis=4)\n",
"conv_9_1 = Convolution3D(32, 5, 5, 5, border_mode='same', dim_ordering='tf')(merged_9)\n",
"add_9 = merge([conv_9_1, merged_9], mode='sum')\n",
"# 1x1x1 convolution down to two channels, then a softmax over them.\n",
"conv_9_2 = Convolution3D(2, 1, 1, 1, border_mode='same', dim_ordering='tf')(add_9)\n",
"softmax = Softmax()(conv_9_2)\n",
"\n",
"model = Model(input_layer, softmax)"
"Layer (type) Output Shape Param # Connected to \n",
"data (InputLayer) (None, 128, 128, 64, 1) 0 \n",
"convolution3d_1 (Convolution3D) (None, 128, 128, 64, 16) 2016 data[0][0] \n",
"merge_1 (Merge) (None, 128, 128, 64, 16) 0 data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
"merge_2 (Merge) (None, 128, 128, 64, 16) 0 convolution3d_1[0][0] \n",
" merge_1[0][0] \n",
"prelu_1 (PReLU) (None, 128, 128, 64, 16) 16777216 merge_2[0][0] \n",
"convolution3d_2 (Convolution3D) (None, 64, 64, 32, 32) 4128 prelu_1[0][0] \n",
"prelu_2 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_2[0][0] \n",
"convolution3d_3 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_2[0][0] \n",
"prelu_3 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_3[0][0] \n",
"convolution3d_4 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_3[0][0] \n",
"merge_3 (Merge) (None, 64, 64, 32, 32) 0 convolution3d_4[0][0] \n",
" prelu_2[0][0] \n",
"convolution3d_5 (Convolution3D) (None, 32, 32, 16, 64) 16448 merge_3[0][0] \n",
"prelu_4 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_5[0][0] \n",
"convolution3d_6 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_4[0][0] \n",
"prelu_5 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_6[0][0] \n",
"convolution3d_7 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_5[0][0] \n",
"merge_4 (Merge) (None, 32, 32, 16, 64) 0 convolution3d_7[0][0] \n",
" prelu_4[0][0] \n",
"convolution3d_8 (Convolution3D) (None, 16, 16, 8, 128) 65664 merge_4[0][0] \n",
"prelu_6 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_8[0][0] \n",
"convolution3d_9 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_6[0][0] \n",
"prelu_7 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_9[0][0] \n",
"convolution3d_10 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_7[0][0] \n",
"merge_5 (Merge) (None, 16, 16, 8, 128) 0 convolution3d_10[0][0] \n",
" prelu_6[0][0] \n",
"convolution3d_11 (Convolution3D) (None, 8, 8, 4, 256) 262400 merge_5[0][0] \n",
"prelu_8 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_11[0][0] \n",
"convolution3d_12 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_8[0][0] \n",
"prelu_9 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_12[0][0] \n",
"convolution3d_13 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_9[0][0] \n",
"prelu_10 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_13[0][0] \n",
"convolution3d_14 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_10[0][0] \n",
"merge_6 (Merge) (None, 8, 8, 4, 256) 0 convolution3d_14[0][0] \n",
" prelu_8[0][0] \n",
"prelu_11 (PReLU) (None, 8, 8, 4, 256) 65536 merge_6[0][0] \n",
"deconvolution3d_1 (Deconvolution3D) (None, 16, 16, 8, 128) 262272 prelu_11[0][0] \n",
"prelu_12 (PReLU) (None, 16, 16, 8, 128) 262144 deconvolution3d_1[0][0] \n",
"merge_7 (Merge) (None, 16, 16, 8, 256) 0 prelu_12[0][0] \n",
" merge_5[0][0] \n",
"convolution3d_15 (Convolution3D) (None, 16, 16, 8, 256) 8192256 merge_7[0][0] \n",
"prelu_13 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_15[0][0] \n",
"convolution3d_16 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_13[0][0] \n",
"prelu_14 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_16[0][0] \n",
"convolution3d_17 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_14[0][0] \n",
"merge_8 (Merge) (None, 16, 16, 8, 256) 0 convolution3d_17[0][0] \n",
" merge_7[0][0] \n",
"deconvolution3d_2 (Deconvolution3D) (None, 32, 32, 16, 64) 1048640 merge_8[0][0] \n",
"prelu_15 (PReLU) (None, 32, 32, 16, 64) 1048576 deconvolution3d_2[0][0] \n",
"merge_9 (Merge) (None, 32, 32, 16, 128) 0 prelu_15[0][0] \n",
" merge_4[0][0] \n",
"convolution3d_18 (Convolution3D) (None, 32, 32, 16, 128) 2048128 merge_9[0][0] \n",
"prelu_16 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_18[0][0] \n",
"convolution3d_19 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_16[0][0] \n",
"prelu_17 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_19[0][0] \n",
"convolution3d_20 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_17[0][0] \n",
"merge_10 (Merge) (None, 32, 32, 16, 128) 0 convolution3d_20[0][0] \n",
" merge_9[0][0] \n",
"deconvolution3d_3 (Deconvolution3D) (None, 64, 64, 32, 32) 262176 merge_10[0][0] \n",
"prelu_18 (PReLU) (None, 64, 64, 32, 32) 4194304 deconvolution3d_3[0][0] \n",
"merge_11 (Merge) (None, 64, 64, 32, 64) 0 prelu_18[0][0] \n",
" merge_3[0][0] \n",
"convolution3d_21 (Convolution3D) (None, 64, 64, 32, 64) 512064 merge_11[0][0] \n",
"prelu_19 (PReLU) (None, 64, 64, 32, 64) 8388608 convolution3d_21[0][0] \n",
"convolution3d_22 (Convolution3D) (None, 64, 64, 32, 64) 512064 prelu_19[0][0] \n",
"merge_12 (Merge) (None, 64, 64, 32, 64) 0 convolution3d_22[0][0] \n",
" merge_11[0][0] \n",
"deconvolution3d_4 (Deconvolution3D) (None, 128, 128, 64, 16) 65552 merge_12[0][0] \n",
"prelu_20 (PReLU) (None, 128, 128, 64, 16) 16777216 deconvolution3d_4[0][0] \n",
"merge_13 (Merge) (None, 128, 128, 64, 32) 0 prelu_20[0][0] \n",
" merge_2[0][0] \n",
"convolution3d_23 (Convolution3D) (None, 128, 128, 64, 32) 128032 merge_13[0][0] \n",
"merge_14 (Merge) (None, 128, 128, 64, 32) 0 convolution3d_23[0][0] \n",
" merge_13[0][0] \n",
"convolution3d_24 (Convolution3D) (None, 128, 128, 64, 2) 66 merge_14[0][0] \n",
"softmax_1 (Softmax) (None, 128, 128, 64, 2) 0 convolution3d_24[0][0] \n",
"Total params: 122,863,826\n",
"Trainable params: 122,863,826\n",
"Non-trainable params: 0\n",
"def dice_coef(y_true, y_pred):\n",
"    \"\"\"Overlap score between a mask and a two-channel prediction.\n",
"\n",
"    Args:\n",
"        y_true: target tensor whose channel 0 is the mask. A second channel,\n",
"            if present, is ignored.\n",
"        y_pred: prediction tensor with two values per voxel on the last axis.\n",
"            Column 0 is paired with the mask, column 1 with its complement.\n",
"\n",
"    Returns:\n",
"        Scalar tensor: twice the sum of the two mean products.\n",
"    \"\"\"\n",
"    # Use channel 0 only. Flattening a two-channel target whole would give\n",
"    # twice as many elements as one column of y_pred_f, so the products below\n",
"    # could not be formed.\n",
"    y_true_f = K.flatten(y_true[..., 0])\n",
"    y_pred_f = K.reshape(y_pred, (-1, 2))\n",
"    intersection = K.mean(y_true_f * y_pred_f[:, 0]) + K.mean((1.0 - y_true_f) * y_pred_f[:, 1])\n",
"\n",
"    return 2. * intersection\n",
"def dice_coef_loss(y_true, y_pred):\n",
"    \"\"\"Return the negated ``dice_coef`` score, so minimising it maximises the score.\"\"\"\n",
"    score = dice_coef(y_true, y_pred)\n",
"    return -score"
"# Train with Adam on the negated overlap score and report the score itself.\n",
"model.compile(\n",
"    optimizer=Adam(lr=1e-5),\n",
"    loss=dice_coef_loss,\n",
"    metrics=[dice_coef],\n",
")"
"# Time one forward pass on the first training volume.\n",
"t = time.time()\n",
"y_pred = model.predict(X[:1, :, :, :, :])\n",
"print(time.time() - t)"
"Epoch 1/20\n"
"# Keep the weights of the epoch with the lowest training loss.\n",
"model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='loss', save_best_only=True)\n",
"# The checkpoint only takes effect when it is passed to fit() as a callback.\n",
"model.fit(X, y, batch_size=50, nb_epoch=20, verbose=1, callbacks=[model_checkpoint])"
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
