Skip to content

Instantly share code, notes, and snippets.

@muyulin
Forked from ravnoor/Keras.ipynb
Created September 2, 2018 13:33
Show Gist options
  • Save muyulin/0cbeeb57237296c71082dba416f427fb to your computer and use it in GitHub Desktop.
Save muyulin/0cbeeb57237296c71082dba416f427fb to your computer and use it in GitHub Desktop.
V-Net in Keras and TensorFlow
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy\n",
"import warnings\n",
"from keras.layers import Convolution3D, Input, merge, RepeatVector, Activation\n",
"from keras.models import Model\n",
"from keras.layers.advanced_activations import PReLU\n",
"from keras import activations, initializations, regularizers\n",
"from keras.engine import Layer, InputSpec\n",
"from keras.utils.np_utils import conv_output_length\n",
"from keras.optimizers import Adam\n",
"from keras.callbacks import ModelCheckpoint\n",
"import keras.backend as K\n",
"from keras.engine.topology import Layer\n",
"import functools\n",
"import tensorflow as tf\n",
"import pickle\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.\n",
"with open('../data/PROMISE2012/train_data.p3', 'rb') as f:\n",
"    X, y = pickle.load(f)\n",
"\n",
"# Append a trailing channel axis of size 1 to the images and the labels.\n",
"X = X.reshape(X.shape + (1,)).astype(numpy.float32)\n",
"y = y.reshape(y.shape + (1,)).astype(numpy.float32)\n",
"# Two-channel target: channel 0 is the label mask y, channel 1 its complement.\n",
"# `1.0 - y` replaces `~y`: `~` is a logical NOT only for bool arrays; on integer\n",
"# masks it is a bitwise NOT (0 -> -1 for signed ints, 0 -> 255 for uint8), which\n",
"# corrupted the complement channel. Assumes y is a binary mask (bool or 0/1)\n",
"# -- TODO confirm against the PROMISE2012 pickle.\n",
"y = numpy.concatenate([y, 1.0 - y], axis=4)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Deconvolution3D(Layer):\n",
"    \"\"\"3-D transposed convolution (TensorFlow backend only).\n",
"\n",
"    Arguments:\n",
"        nb_filter: number of output channels.\n",
"        kernel_dims: kernel size as a 3-tuple.\n",
"        output_shape: full 5-D output shape (batch, dim1, dim2, dim3, nb_filter).\n",
"            Only the non-batch entries are used in call(); the batch size is read\n",
"            from the input at run time, so any batch size works.\n",
"        subsample: strides as a 3-tuple.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nb_filter, kernel_dims, output_shape, subsample, **kwargs):\n",
"        # Explicit check instead of assert, which is stripped under python -O.\n",
"        if K.backend() != 'tensorflow':\n",
"            raise ValueError('Deconvolution3D requires the TensorFlow backend.')\n",
"        self.nb_filter = nb_filter\n",
"        # tuple() so lists (e.g. from a JSON model config) work with the + below.\n",
"        self.kernel_dims = tuple(kernel_dims)\n",
"        self.subsample = tuple(subsample)\n",
"        self.strides = (1,) + self.subsample + (1,)\n",
"        self.output_shape_ = tuple(output_shape)\n",
"        super(Deconvolution3D, self).__init__(**kwargs)\n",
"\n",
"    def build(self, input_shape):\n",
"        assert len(input_shape) == 5\n",
"        self.input_shape_ = input_shape\n",
"        # Filter layout used by tf.nn.conv3d_transpose: (k1, k2, k3, out_channels, in_channels).\n",
"        W_shape = self.kernel_dims + (self.nb_filter, input_shape[4], )\n",
"        self.W = self.add_weight(W_shape,\n",
"                                 initializer=functools.partial(initializations.glorot_uniform, dim_ordering='tf'),\n",
"                                 name='{}_W'.format(self.name))\n",
"        self.b = self.add_weight((1, 1, 1, self.nb_filter,), initializer='zero', name='{}_b'.format(self.name))\n",
"        self.built = True\n",
"\n",
"    def get_output_shape_for(self, input_shape):\n",
"        return (None, ) + self.output_shape_[1:]\n",
"\n",
"    def call(self, x, mask=None):\n",
"        # Build the output shape with the dynamic batch size. Passing output_shape_\n",
"        # verbatim baked in the caller's batch entry (always 1 in this notebook), so\n",
"        # the layer failed for any other batch size, e.g. model.fit(..., batch_size=50).\n",
"        batch_size = tf.shape(x)[0]\n",
"        output_shape = tf.stack([batch_size] + list(self.output_shape_[1:]))\n",
"        return tf.nn.conv3d_transpose(x, self.W, output_shape=output_shape,\n",
"                                      strides=self.strides, padding='SAME', name=self.name) + self.b\n",
"\n",
"    def get_config(self):\n",
"        # Serialize every constructor argument so from_config() can rebuild the layer.\n",
"        config = {'nb_filter': self.nb_filter,\n",
"                  'kernel_dims': self.kernel_dims,\n",
"                  'output_shape': self.output_shape_,\n",
"                  'subsample': self.subsample}\n",
"        base_config = super(Deconvolution3D, self).get_config()\n",
"        return dict(list(base_config.items()) + list(config.items()))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras import backend as K\n",
"from keras.engine import Layer\n",
"\n",
"class Softmax(Layer):\n",
"    \"\"\"Numerically stable softmax over `axis` (default: the last axis).\"\"\"\n",
"\n",
"    def __init__(self, axis=-1, **kwargs):\n",
"        self.axis = axis\n",
"        super(Softmax, self).__init__(**kwargs)\n",
"\n",
"    def build(self, input_shape):\n",
"        # The layer has no weights.\n",
"        pass\n",
"\n",
"    def call(self, x, mask=None):\n",
"        # Subtracting the max along `axis` before exp() avoids overflow without\n",
"        # changing the normalized result.\n",
"        e = K.exp(x - K.max(x, axis=self.axis, keepdims=True))\n",
"        s = K.sum(e, axis=self.axis, keepdims=True)\n",
"        return e / s\n",
"\n",
"    def get_output_shape_for(self, input_shape):\n",
"        return input_shape\n",
"\n",
"    def get_config(self):\n",
"        # Include `axis` so a saved model restores the same normalization axis.\n",
"        config = {'axis': self.axis}\n",
"        base_config = super(Softmax, self).get_config()\n",
"        return dict(list(base_config.items()) + list(config.items()))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def downward_layer(input_layer, n_convolutions, n_output_channels):\n",
"    \"\"\"Build one V-Net compression stage.\n",
"\n",
"    Applies n_convolutions 5x5x5 convolutions (PReLU between them) with\n",
"    n_output_channels // 2 channels, adds input_layer back as a residual, then\n",
"    downsamples with a 2x2x2 convolution of stride 2 to n_output_channels\n",
"    channels followed by PReLU.\n",
"\n",
"    Arguments:\n",
"        input_layer: Keras tensor; the residual sum requires it to already have\n",
"            n_output_channels // 2 channels.\n",
"        n_convolutions: number of 5x5x5 convolutions.\n",
"        n_output_channels: channel count after downsampling.\n",
"\n",
"    Returns:\n",
"        (downsampled output, residual sum before downsampling).\n",
"    \"\"\"\n",
"    n_channels = n_output_channels // 2\n",
"    inl = input_layer\n",
"    for _ in range(n_convolutions - 1):\n",
"        inl = PReLU()(\n",
"            Convolution3D(n_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
"        )\n",
"    conv = Convolution3D(n_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
"    add = merge([conv, input_layer], mode='sum')\n",
"    # dim_ordering='tf' made explicit like the convolutions above, instead of\n",
"    # depending on the image_dim_ordering default in keras.json.\n",
"    downsample = Convolution3D(n_output_channels, 2, 2, 2, subsample=(2, 2, 2), dim_ordering='tf')(add)\n",
"    prelu = PReLU()(downsample)\n",
"    return prelu, add"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def upward_layer(input0, input1, n_convolutions, n_output_channels):\n",
"    \"\"\"Build one V-Net decompression stage.\n",
"\n",
"    input0 is concatenated with the skip connection input1 on the channel axis;\n",
"    n_convolutions 5x5x5 convolutions with n_output_channels * 4 channels\n",
"    (PReLU between them) are added back to that concatenation as a residual;\n",
"    a 4x4x4 transposed convolution with stride 2 then doubles each spatial\n",
"    dimension and emits n_output_channels channels, followed by PReLU.\n",
"    NOTE(review): the residual sum assumes the concatenation already has\n",
"    n_output_channels * 4 channels -- true for the calls in this notebook.\n",
"    \"\"\"\n",
"    merged = merge([input0, input1], mode='concat', concat_axis=4)\n",
"    n_channels = n_output_channels * 4\n",
"    features = merged\n",
"    for _ in range(n_convolutions - 1):\n",
"        features = PReLU()(\n",
"            Convolution3D(n_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(features)\n",
"        )\n",
"    conv = Convolution3D(n_channels, 5, 5, 5, border_mode='same', dim_ordering='tf')(features)\n",
"    residual = merge([conv, merged], mode='sum')\n",
"    # Target shape for the transposed conv: every spatial dim doubled; the leading 1\n",
"    # fills the batch slot of Deconvolution3D's output_shape argument.\n",
"    spatial = residual.get_shape().as_list()[1:4]\n",
"    target_shape = (1,) + tuple(2 * dim for dim in spatial) + (n_output_channels,)\n",
"    upsample = Deconvolution3D(n_output_channels, (4, 4, 4), target_shape,\n",
"                               subsample=(2, 2, 2))(residual)\n",
"    return PReLU()(upsample)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# Layer 1: a 5x5x5 conv added to the input repeated 16 times along the channel\n",
"# axis (residual), then a strided 2x2x2 conv that halves each spatial dimension.\n",
"input_layer = Input(shape=(128, 128, 64, 1), name='data')\n",
"conv_1 = Convolution3D(16, 5, 5, 5, border_mode='same', dim_ordering='tf')(input_layer)\n",
"repeat_1 = merge([input_layer] * 16, mode='concat')\n",
"add_1 = merge([conv_1, repeat_1], mode='sum')\n",
"prelu_1_1 = PReLU()(add_1)\n",
"# dim_ordering='tf' made explicit (as on every other conv) instead of relying on keras.json.\n",
"downsample_1 = Convolution3D(32, 2, 2, 2, subsample=(2, 2, 2), dim_ordering='tf')(prelu_1_1)\n",
"prelu_1_2 = PReLU()(downsample_1)\n",
"\n",
"# Layers 2, 3, 4 (compression path); left* are the residual sums kept for the skip connections.\n",
"out2, left2 = downward_layer(prelu_1_2, 2, 64)\n",
"out3, left3 = downward_layer(out2, 2, 128)\n",
"out4, left4 = downward_layer(out3, 2, 256)\n",
"\n",
"# Layer 5 (bottom): three convs with a residual connection, then a transposed conv.\n",
"# Every intermediate has its own name (prelu_5_1/prelu_5_2 used to be reassigned).\n",
"conv_5_1 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(out4)\n",
"prelu_5_1 = PReLU()(conv_5_1)\n",
"conv_5_2 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_1)\n",
"prelu_5_2 = PReLU()(conv_5_2)\n",
"conv_5_3 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_2)\n",
"add_5 = merge([conv_5_3, out4], mode='sum')\n",
"prelu_5_3 = PReLU()(add_5)\n",
"# Transposed conv with stride 2: this *up*samples (was misleadingly named downsample_5).\n",
"upsample_5 = Deconvolution3D(128, (2, 2, 2), (1, 16, 16, 8, 128), subsample=(2, 2, 2))(prelu_5_3)\n",
"prelu_5_4 = PReLU()(upsample_5)\n",
"\n",
"# Layers 6, 7, 8 (decompression path) with skip connections from layers 4, 3, 2.\n",
"out6 = upward_layer(prelu_5_4, left4, 3, 64)\n",
"out7 = upward_layer(out6, left3, 3, 32)\n",
"out8 = upward_layer(out7, left2, 2, 16)\n",
"\n",
"# Layer 9: concatenate with the layer-1 residual, residual conv, then a 1x1x1 conv to 2 classes.\n",
"merged_9 = merge([out8, add_1], mode='concat', concat_axis=4)\n",
"conv_9_1 = Convolution3D(32, 5, 5, 5, border_mode='same', dim_ordering='tf')(merged_9)\n",
"add_9 = merge([conv_9_1, merged_9], mode='sum')\n",
"conv_9_2 = Convolution3D(2, 1, 1, 1, border_mode='same', dim_ordering='tf')(add_9)\n",
"\n",
"# Per-voxel class probabilities over the channel axis.\n",
"softmax = Softmax()(conv_9_2)\n",
"\n",
"model = Model(input_layer, softmax)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"=================================================================================================================\n",
"data (InputLayer) (None, 128, 128, 64, 1) 0 \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_1 (Convolution3D) (None, 128, 128, 64, 16) 2016 data[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_1 (Merge) (None, 128, 128, 64, 16) 0 data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_2 (Merge) (None, 128, 128, 64, 16) 0 convolution3d_1[0][0] \n",
" merge_1[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_1 (PReLU) (None, 128, 128, 64, 16) 16777216 merge_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_2 (Convolution3D) (None, 64, 64, 32, 32) 4128 prelu_1[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_2 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_3 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_3 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_4 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_3 (Merge) (None, 64, 64, 32, 32) 0 convolution3d_4[0][0] \n",
" prelu_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_5 (Convolution3D) (None, 32, 32, 16, 64) 16448 merge_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_4 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_6 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_5 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_7 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_4 (Merge) (None, 32, 32, 16, 64) 0 convolution3d_7[0][0] \n",
" prelu_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_8 (Convolution3D) (None, 16, 16, 8, 128) 65664 merge_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_6 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_9 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_7 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_10 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_7[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_5 (Merge) (None, 16, 16, 8, 128) 0 convolution3d_10[0][0] \n",
" prelu_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_11 (Convolution3D) (None, 8, 8, 4, 256) 262400 merge_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_8 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_12 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_9 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_12[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_13 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_10 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_14 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_10[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_6 (Merge) (None, 8, 8, 4, 256) 0 convolution3d_14[0][0] \n",
" prelu_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_11 (PReLU) (None, 8, 8, 4, 256) 65536 merge_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_1 (Deconvolution3D) (None, 16, 16, 8, 128) 262272 prelu_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_12 (PReLU) (None, 16, 16, 8, 128) 262144 deconvolution3d_1[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_7 (Merge) (None, 16, 16, 8, 256) 0 prelu_12[0][0] \n",
" merge_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_15 (Convolution3D) (None, 16, 16, 8, 256) 8192256 merge_7[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_13 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_15[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_16 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_14 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_16[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_17 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_14[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_8 (Merge) (None, 16, 16, 8, 256) 0 convolution3d_17[0][0] \n",
" merge_7[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_2 (Deconvolution3D) (None, 32, 32, 16, 64) 1048640 merge_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_15 (PReLU) (None, 32, 32, 16, 64) 1048576 deconvolution3d_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_9 (Merge) (None, 32, 32, 16, 128) 0 prelu_15[0][0] \n",
" merge_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_18 (Convolution3D) (None, 32, 32, 16, 128) 2048128 merge_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_16 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_18[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_19 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_16[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_17 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_19[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_20 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_17[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_10 (Merge) (None, 32, 32, 16, 128) 0 convolution3d_20[0][0] \n",
" merge_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_3 (Deconvolution3D) (None, 64, 64, 32, 32) 262176 merge_10[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_18 (PReLU) (None, 64, 64, 32, 32) 4194304 deconvolution3d_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_11 (Merge) (None, 64, 64, 32, 64) 0 prelu_18[0][0] \n",
" merge_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_21 (Convolution3D) (None, 64, 64, 32, 64) 512064 merge_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_19 (PReLU) (None, 64, 64, 32, 64) 8388608 convolution3d_21[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_22 (Convolution3D) (None, 64, 64, 32, 64) 512064 prelu_19[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_12 (Merge) (None, 64, 64, 32, 64) 0 convolution3d_22[0][0] \n",
" merge_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_4 (Deconvolution3D) (None, 128, 128, 64, 16) 65552 merge_12[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_20 (PReLU) (None, 128, 128, 64, 16) 16777216 deconvolution3d_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_13 (Merge) (None, 128, 128, 64, 32) 0 prelu_20[0][0] \n",
" merge_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_23 (Convolution3D) (None, 128, 128, 64, 32) 128032 merge_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_14 (Merge) (None, 128, 128, 64, 32) 0 convolution3d_23[0][0] \n",
" merge_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_24 (Convolution3D) (None, 128, 128, 64, 2) 66 merge_14[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"softmax_1 (Softmax) (None, 128, 128, 64, 2) 0 convolution3d_24[0][0] \n",
"=================================================================================================================\n",
"Total params: 122,863,826\n",
"Trainable params: 122,863,826\n",
"Non-trainable params: 0\n",
"_________________________________________________________________________________________________________________\n"
]
}
],
"source": [
"# Print the layer table at 113 characters per line.\n",
"model.summary(line_length=113)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def dice_coef(y_true, y_pred, smooth=1e-5):\n",
"    \"\"\"Soft Dice coefficient of channel 0 (V-Net formulation).\n",
"\n",
"    Both tensors are expected to have 2 channels on the last axis, channel 0\n",
"    being the label mask y (as built in the data-loading cell) -- TODO confirm\n",
"    if the label layout changes. Computes\n",
"        (2 * sum(p * g) + smooth) / (sum(p ** 2) + sum(g ** 2) + smooth)\n",
"    over every voxel in the batch; `smooth` keeps the ratio defined when both\n",
"    masks are empty. The previous version returned\n",
"    2 * (mean(g * p0) + mean((1 - g) * p1)), which has no Dice normalization.\n",
"    \"\"\"\n",
"    # Reshape (not flatten) y_true: it has 2 channels, so K.flatten produced\n",
"    # twice as many elements as y_pred[:, 0] and the product could not be formed.\n",
"    y_true_f = K.reshape(y_true, (-1, 2))[:, 0]\n",
"    y_pred_f = K.reshape(y_pred, (-1, 2))[:, 0]\n",
"    intersection = K.sum(y_true_f * y_pred_f)\n",
"    denominator = K.sum(K.square(y_true_f)) + K.sum(K.square(y_pred_f))\n",
"    return (2. * intersection + smooth) / (denominator + smooth)\n",
"\n",
"def dice_coef_loss(y_true, y_pred):\n",
"    \"\"\"Training loss: the negated Dice coefficient (minimizing it maximizes overlap).\"\"\"\n",
"    return -dice_coef(y_true, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Adam with a small learning rate; the negated Dice is the loss and Dice is also reported.\n",
"optimizer = Adam(lr=1e-5)\n",
"model.compile(optimizer=optimizer, loss=dice_coef_loss, metrics=[dice_coef])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"# Time one forward pass on the first training volume.\n",
"start = time.time()\n",
"y_pred = model.predict(X[:1])\n",
"print(time.time() - start)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n"
]
}
],
"source": [
"# Save the model whenever the training loss reaches a new minimum.\n",
"model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='loss', save_best_only=True)\n",
"\n",
"# Pass the checkpoint to fit(); it was previously created but never used, so nothing was saved.\n",
"# NOTE(review): batch_size=50 full 128x128x64 volumes is likely to exceed GPU memory -- confirm.\n",
"model.fit(X, y, batch_size=50, nb_epoch=20, verbose=1, callbacks=[model_checkpoint])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment