Skip to content

Instantly share code, notes, and snippets.

@ravnoor
Forked from prhbrt/Keras.ipynb
Created April 24, 2017 02:38
Show Gist options
  • Save ravnoor/a8d26c485cd39c1d9dd21af7c27ac232 to your computer and use it in GitHub Desktop.
Save ravnoor/a8d26c485cd39c1d9dd21af7c27ac232 to your computer and use it in GitHub Desktop.
V-Net in Keras and TensorFlow
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy\n",
"import warnings\n",
"from keras.layers import Convolution3D, Input, merge, RepeatVector, Activation\n",
"from keras.models import Model\n",
"from keras.layers.advanced_activations import PReLU\n",
"from keras import activations, initializations, regularizers\n",
"from keras.engine import Layer, InputSpec\n",
"from keras.utils.np_utils import conv_output_length\n",
"from keras.optimizers import Adam\n",
"from keras.callbacks import ModelCheckpoint\n",
"import keras.backend as K\n",
"from keras.engine.topology import Layer\n",
"import functools\n",
"import tensorflow as tf\n",
"import pickle\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"with open('../data/PROMISE2012/train_data.p3', 'rb') as f:\n",
" X, y = pickle.load(f)\n",
" \n",
"X = X.reshape(X.shape + (1,)).astype(numpy.float32)\n",
"y = y.reshape(y.shape + (1,))\n",
"y = numpy.concatenate([y, ~y], axis=4)\n",
"y=y.astype(numpy.float32)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Deconvolution3D(Layer):\n",
" def __init__(self, nb_filter, kernel_dims, output_shape, subsample):\n",
" self.nb_filter = nb_filter\n",
" self.kernel_dims = kernel_dims\n",
" self.strides = (1,) + subsample + (1,)\n",
" self.output_shape_ = output_shape\n",
" assert K.backend() == 'tensorflow'\n",
" super(Deconvolution3D, self).__init__()\n",
" \n",
" def build(self, input_shape):\n",
" assert len(input_shape) == 5\n",
" self.input_shape_ = input_shape\n",
" W_shape = self.kernel_dims + (self.nb_filter, input_shape[4], )\n",
" self.W = self.add_weight(W_shape,\n",
" initializer=functools.partial(initializations.glorot_uniform,dim_ordering='tf'),\n",
" name='{}_W'.format(self.name))\n",
" self.b = self.add_weight((1,1,1,self.nb_filter,), initializer='zero', name='{}_b'.format(self.name))\n",
" self.built = True\n",
"\n",
" def get_output_shape_for(self, input_shape):\n",
" return (None, ) + self.output_shape_[1:]\n",
"\n",
" def call(self, x, mask=None):\n",
" return tf.nn.conv3d_transpose(x, self.W, output_shape=self.output_shape_,\n",
" strides=self.strides, padding='SAME', name=self.name) + self.b\n",
"\n",
" def get_config(self):\n",
" base_config = super(Deconvolution3D, self).get_config().copy()\n",
" base_config['output_shape'] = self.output_shape_\n",
" return base_config"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras import backend as K\n",
"from keras.engine import Layer\n",
"\n",
"class Softmax(Layer):\n",
" def __init__(self, axis=-1,**kwargs):\n",
" self.axis=axis\n",
" super(Softmax, self).__init__(**kwargs)\n",
"\n",
" def build(self,input_shape):\n",
" pass\n",
"\n",
" def call(self, x,mask=None):\n",
" e = K.exp(x - K.max(x, axis=self.axis, keepdims=True))\n",
" s = K.sum(e, axis=self.axis, keepdims=True)\n",
" return e / s\n",
"\n",
" def get_output_shape_for(self, input_shape):\n",
" return input_shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def downward_layer(input_layer, n_convolutions, n_output_channels):\n",
" inl = input_layer\n",
" for _ in range(n_convolutions-1):\n",
" inl = PReLU()(\n",
" Convolution3D(n_output_channels // 2, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
" )\n",
" conv = Convolution3D(n_output_channels // 2, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
" add = merge([conv, input_layer], mode='sum')\n",
" downsample = Convolution3D(n_output_channels, 2,2,2, subsample=(2,2,2))(add)\n",
" prelu = PReLU()(downsample)\n",
" return prelu, add"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def upward_layer(input0 ,input1, n_convolutions, n_output_channels):\n",
" merged = merge([input0, input1], mode='concat', concat_axis=4)\n",
" inl = merged\n",
" for _ in range(n_convolutions-1):\n",
" inl = PReLU()(\n",
" Convolution3D(n_output_channels * 4, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
" )\n",
" conv = Convolution3D(n_output_channels * 4, 5, 5, 5, border_mode='same', dim_ordering='tf')(inl)\n",
" add = merge([conv, merged], mode='sum')\n",
" shape = add.get_shape().as_list()\n",
" new_shape = (1, shape[1] * 2, shape[2] * 2, shape[3] * 2, n_output_channels)\n",
" upsample = Deconvolution3D(n_output_channels, (4,4,4), new_shape, subsample=(2,2,2))(add)\n",
" return PReLU()(upsample)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"# Layer 1\n",
"input_layer = Input(shape=(128, 128, 64, 1), name='data')\n",
"conv_1 = Convolution3D(16, 5, 5, 5, border_mode='same', dim_ordering='tf')(input_layer)\n",
"repeat_1 = merge([input_layer] * 16, mode='concat')\n",
"add_1 = merge([conv_1, repeat_1], mode='sum')\n",
"prelu_1_1 = PReLU()(add_1)\n",
"downsample_1 = Convolution3D(32, 2,2,2, subsample=(2,2,2))(prelu_1_1)\n",
"prelu_1_2 = PReLU()(downsample_1)\n",
"\n",
"# Layer 2,3,4\n",
"out2, left2 = downward_layer(prelu_1_2, 2, 64)\n",
"out3, left3 = downward_layer(out2, 2, 128)\n",
"out4, left4 = downward_layer(out3, 2, 256)\n",
"\n",
"# Layer 5\n",
"conv_5_1 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(out4)\n",
"prelu_5_1 = PReLU()(conv_5_1)\n",
"conv_5_2 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_1)\n",
"prelu_5_2 = PReLU()(conv_5_2)\n",
"conv_5_3 = Convolution3D(256, 5, 5, 4, border_mode='same', dim_ordering='tf')(prelu_5_2)\n",
"add_5 = merge([conv_5_3, out4], mode='sum')\n",
"prelu_5_1 = PReLU()(add_5)\n",
"downsample_5 = Deconvolution3D(128, (2,2,2), (1, 16, 16, 8, 128), subsample=(2,2,2))(prelu_5_1)\n",
"prelu_5_2 = PReLU()(downsample_5)\n",
"\n",
"#Layer 6,7,8\n",
"out6 = upward_layer(prelu_5_2, left4, 3, 64)\n",
"out7 = upward_layer(out6, left3, 3, 32)\n",
"out8 = upward_layer(out7, left2, 2, 16)\n",
"\n",
"#Layer 9\n",
"merged_9 = merge([out8, add_1], mode='concat', concat_axis=4)\n",
"conv_9_1 = Convolution3D(32, 5, 5, 5, border_mode='same', dim_ordering='tf')(merged_9)\n",
"add_9 = merge([conv_9_1, merged_9], mode='sum')\n",
"conv_9_2 = Convolution3D(2, 1, 1, 1, border_mode='same', dim_ordering='tf')(add_9)\n",
"\n",
"softmax = Softmax()(conv_9_2)\n",
"\n",
"model = Model(input_layer, softmax)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"=================================================================================================================\n",
"data (InputLayer) (None, 128, 128, 64, 1) 0 \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_1 (Convolution3D) (None, 128, 128, 64, 16) 2016 data[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_1 (Merge) (None, 128, 128, 64, 16) 0 data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
" data[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_2 (Merge) (None, 128, 128, 64, 16) 0 convolution3d_1[0][0] \n",
" merge_1[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_1 (PReLU) (None, 128, 128, 64, 16) 16777216 merge_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_2 (Convolution3D) (None, 64, 64, 32, 32) 4128 prelu_1[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_2 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_3 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_3 (PReLU) (None, 64, 64, 32, 32) 4194304 convolution3d_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_4 (Convolution3D) (None, 64, 64, 32, 32) 128032 prelu_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_3 (Merge) (None, 64, 64, 32, 32) 0 convolution3d_4[0][0] \n",
" prelu_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_5 (Convolution3D) (None, 32, 32, 16, 64) 16448 merge_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_4 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_6 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_5 (PReLU) (None, 32, 32, 16, 64) 1048576 convolution3d_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_7 (Convolution3D) (None, 32, 32, 16, 64) 512064 prelu_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_4 (Merge) (None, 32, 32, 16, 64) 0 convolution3d_7[0][0] \n",
" prelu_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_8 (Convolution3D) (None, 16, 16, 8, 128) 65664 merge_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_6 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_9 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_7 (PReLU) (None, 16, 16, 8, 128) 262144 convolution3d_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_10 (Convolution3D) (None, 16, 16, 8, 128) 2048128 prelu_7[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_5 (Merge) (None, 16, 16, 8, 128) 0 convolution3d_10[0][0] \n",
" prelu_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_11 (Convolution3D) (None, 8, 8, 4, 256) 262400 merge_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_8 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_12 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_9 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_12[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_13 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_10 (PReLU) (None, 8, 8, 4, 256) 65536 convolution3d_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_14 (Convolution3D) (None, 8, 8, 4, 256) 6553856 prelu_10[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_6 (Merge) (None, 8, 8, 4, 256) 0 convolution3d_14[0][0] \n",
" prelu_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_11 (PReLU) (None, 8, 8, 4, 256) 65536 merge_6[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_1 (Deconvolution3D) (None, 16, 16, 8, 128) 262272 prelu_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_12 (PReLU) (None, 16, 16, 8, 128) 262144 deconvolution3d_1[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_7 (Merge) (None, 16, 16, 8, 256) 0 prelu_12[0][0] \n",
" merge_5[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_15 (Convolution3D) (None, 16, 16, 8, 256) 8192256 merge_7[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_13 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_15[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_16 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_14 (PReLU) (None, 16, 16, 8, 256) 524288 convolution3d_16[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_17 (Convolution3D) (None, 16, 16, 8, 256) 8192256 prelu_14[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_8 (Merge) (None, 16, 16, 8, 256) 0 convolution3d_17[0][0] \n",
" merge_7[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_2 (Deconvolution3D) (None, 32, 32, 16, 64) 1048640 merge_8[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_15 (PReLU) (None, 32, 32, 16, 64) 1048576 deconvolution3d_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_9 (Merge) (None, 32, 32, 16, 128) 0 prelu_15[0][0] \n",
" merge_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_18 (Convolution3D) (None, 32, 32, 16, 128) 2048128 merge_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_16 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_18[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_19 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_16[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_17 (PReLU) (None, 32, 32, 16, 128) 2097152 convolution3d_19[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_20 (Convolution3D) (None, 32, 32, 16, 128) 2048128 prelu_17[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_10 (Merge) (None, 32, 32, 16, 128) 0 convolution3d_20[0][0] \n",
" merge_9[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_3 (Deconvolution3D) (None, 64, 64, 32, 32) 262176 merge_10[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_18 (PReLU) (None, 64, 64, 32, 32) 4194304 deconvolution3d_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_11 (Merge) (None, 64, 64, 32, 64) 0 prelu_18[0][0] \n",
" merge_3[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_21 (Convolution3D) (None, 64, 64, 32, 64) 512064 merge_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_19 (PReLU) (None, 64, 64, 32, 64) 8388608 convolution3d_21[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_22 (Convolution3D) (None, 64, 64, 32, 64) 512064 prelu_19[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_12 (Merge) (None, 64, 64, 32, 64) 0 convolution3d_22[0][0] \n",
" merge_11[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"deconvolution3d_4 (Deconvolution3D) (None, 128, 128, 64, 16) 65552 merge_12[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"prelu_20 (PReLU) (None, 128, 128, 64, 16) 16777216 deconvolution3d_4[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_13 (Merge) (None, 128, 128, 64, 32) 0 prelu_20[0][0] \n",
" merge_2[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_23 (Convolution3D) (None, 128, 128, 64, 32) 128032 merge_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"merge_14 (Merge) (None, 128, 128, 64, 32) 0 convolution3d_23[0][0] \n",
" merge_13[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"convolution3d_24 (Convolution3D) (None, 128, 128, 64, 2) 66 merge_14[0][0] \n",
"_________________________________________________________________________________________________________________\n",
"softmax_1 (Softmax) (None, 128, 128, 64, 2) 0 convolution3d_24[0][0] \n",
"=================================================================================================================\n",
"Total params: 122,863,826\n",
"Trainable params: 122,863,826\n",
"Non-trainable params: 0\n",
"_________________________________________________________________________________________________________________\n"
]
}
],
"source": [
"model.summary(line_length=113)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def dice_coef(y_true, y_pred):\n",
" y_true_f = K.flatten(y_true)\n",
" y_pred_f = K.reshape(y_pred, (-1, 2))\n",
" intersection = K.mean(y_true_f * y_pred_f[:,0]) + K.mean((1.0 - y_true_f) * y_pred_f[:,1])\n",
" \n",
" return 2. * intersection;\n",
"\n",
"def dice_coef_loss(y_true, y_pred):\n",
" return -dice_coef(y_true, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"t=time.time()\n",
"y_pred = model.predict(X[:1,:,:,:,:])\n",
"print(time.time() - t)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n"
]
}
],
"source": [
"model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='loss', save_best_only=True)\n",
"\n",
"model.fit(X, y, batch_size=50, nb_epoch=20, verbose=1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@MinuteswithMetrics
Copy link

What dataset was your model used on?

@pathak-ashutosh
Copy link

It's the PROMISE12 dataset.

@hakanbulu
Copy link

Hello,
Thanks for your implementation.

When running:
model.fit(X, y, batch_size=4, epochs=5, verbose=1)
I get the following exception:
InvalidArgumentError (see above for traceback): Incompatible shapes: [8388608] vs. [1048576]

  • I am using same dataset with yours.
  • Load data.ipynb works fine.
  • I have tried different batch sizes, e.g. 4, 5, 10, 50.
  • The model's summary is exactly the same as yours.

Do you have any idea why I am getting this exception?

Thank you, regards.

  • Hakan

@jizhang02
Copy link

hello,
I am training a U-Net with a Dice loss, but the predicted images are all white. Do you know what's wrong?
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of the batch."""
    smooth = 1
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    """Return 1 - dice_coef, so a perfect binary overlap gives a loss of 0."""
    print("dice loss")
    return 1 - dice_coef(y_true, y_pred)
# ...
model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=['accuracy'])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment