Last active
November 28, 2018 06:03
-
-
Save czotti/b1e34c23a92e64490be83f3b8908bdbe to your computer and use it in GitHub Desktop.
Tiramisu keras implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Tiramisu\n", | |
"\n", | |
"This is the implementation of the tiramisu <a href=\"https://arxiv.org/pdf/1611.09326.pdf\">paper</a>.\n", | |
"\n", | |
"You can check the original code <a href=\"https://github.com/SimJeg/FC-DenseNet/blob/master/FC-DenseNet.py\">here</a>." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using Theano backend.\n", | |
"Using cuDNN version 5110 on context None\n", | |
"Mapped name None to device cuda0: GeForce GTX TITAN X (0000:02:00.0)\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.layers import (\n", | |
" Input,\n", | |
" Conv2D,\n", | |
" Conv2DTranspose,\n", | |
" Activation,\n", | |
" BatchNormalization,\n", | |
" SpatialDropout2D,\n", | |
" UpSampling2D,\n", | |
" MaxPooling2D,\n", | |
")\n", | |
"from keras.models import Model\n", | |
"from keras.layers.merge import Add, Concatenate" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Helper functions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
def layer(x, nb_feature_maps, working_axis=1, kernel=(3, 3)):
    """Basic DenseNet layer: BatchNorm -> ReLU -> Conv2D -> SpatialDropout.

    x: input tensor; nb_feature_maps: number of conv filters;
    working_axis: channel axis for batch norm (1 = channels-first);
    kernel: conv kernel size. Returns the dropped-out conv output.
    """
    normed = BatchNormalization(axis=working_axis)(x)
    activated = Activation("relu")(normed)
    convolved = Conv2D(
        nb_feature_maps,
        kernel,
        padding="same",
        kernel_initializer="he_uniform",
    )(activated)
    return SpatialDropout2D(0.25)(convolved)
"\n", | |
def dense_block(x, steps, nb_feature_maps, working_axis=1):
    """Stack `steps` DenseNet layers, concatenating feature maps as we go.

    x: input tensor; steps: number of layers in the block;
    nb_feature_maps: growth rate (filters added per layer);
    working_axis: channel axis used for batch norm and concatenation.

    Returns (x_stack, connections): the full concatenated stack (input plus
    every new layer) and the list of the newly created layer outputs only.
    """
    connections = []
    x_stack = x
    for _ in range(steps):
        # Bug fix: forward `working_axis` instead of hard-coding axis=1 in
        # both the layer call and the concatenation, so non-default channel
        # axes work. Default callers (working_axis=1) are unaffected.
        l = layer(x_stack, nb_feature_maps, working_axis=working_axis)
        connections.append(l)
        x_stack = Concatenate(axis=working_axis)([x_stack, l])
    return x_stack, connections
"\n", | |
def transition_down(x, nb_feature_maps=16, working_axis=1):
    """Transition-down: 1x1 conv layer then 2x2 max-pool (halves H and W)."""
    compressed = layer(x, nb_feature_maps, working_axis, (1, 1))
    return MaxPooling2D((2, 2))(compressed)
"\n", | |
def transition_up(skip, blocks, nb_feature_maps=16, working_axis=1):
    """Transition-up: upsample the dense-block outputs, then merge the skip.

    skip: skip-connection tensor from the down path;
    blocks: list of layer outputs from the previous dense block (only the new
    feature maps are upsampled, as in the original FC-DenseNet).
    """
    merged = Concatenate(axis=working_axis)(blocks)
    upsampled = Conv2DTranspose(
        nb_feature_maps,
        (3, 3),
        strides=(2, 2),
        padding="same",
        kernel_initializer="he_uniform",
    )(merged)
    return Concatenate(axis=working_axis)([upsampled, skip])
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Architecture" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Hyper-parameters from the Tiramisu (FC-DenseNet) paper.
init_feature_maps = 48
growth_rate = 16

steps = [4, 5, 7, 10, 12]  # layers per dense block on the down path
last_step = 15             # layers in the bottleneck dense block

# Channels-first input: (channels, height, width).
inp = Input((1, 256, 256))

# Initial convolution before the first dense block.
stack = Conv2D(init_feature_maps, (3, 3), padding="same",
               kernel_initializer="he_uniform")(inp)

# Encode (down) path.
skip_connection_list = []
for s in steps:
    # Dense block; the concatenated stack becomes a skip connection.
    stack, _ = dense_block(stack, s, growth_rate, working_axis=1)
    skip_connection_list.append(stack)

    # Transition down, keeping the current channel count.
    stack = transition_down(stack, stack._keras_shape[1])

# Reverse so skips are consumed deepest-first on the way up.
skip_connection_list = skip_connection_list[::-1]

# Bottleneck dense block (previously an inlined copy of dense_block;
# dense_block returns exactly the (stack, new-layer-list) pair needed here).
stack, block_to_upsample = dense_block(stack, last_step, growth_rate,
                                       working_axis=1)

# Decode (up) path.
x_stack = stack
x_block_to_upsample = block_to_upsample
n_layers_per_block = [last_step] + steps[::-1]
for n_layers, s, skip in zip(n_layers_per_block, steps[::-1],
                             skip_connection_list):
    # Transition Up (upsampling + concatenation with the skip connection).
    # Only the new feature maps of the previous block are upsampled.
    n_filters_keep = growth_rate * n_layers
    x_stack = transition_up(skip, x_block_to_upsample, n_filters_keep)

    # Dense block on the up path.
    x_stack, x_block_to_upsample = dense_block(x_stack, s, growth_rate,
                                               working_axis=1)

# Output layers: 1x1 conv down to 4 classes, then softmax.
# NOTE(review): Activation("softmax") normalizes over the LAST axis by
# default; with channels-first data confirm the class axis is intended.
out = Conv2D(4, (1, 1), kernel_initializer="he_uniform",
             padding="same")(x_stack)
out = Activation("softmax")(out)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Wrap the symbolic graph into a trainable Model.
model = Model(inputs=[inp], outputs=[out])
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Layer kernel_shape output_shape \n", | |
"conv2d_1 (1, 48, 3, 3) (None, 48, 256, 256) \n", | |
"conv2d_2 (48, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_3 (64, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_4 (80, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_5 (96, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_6 (112, 112, 1, 1) (None, 112, 256, 256) \n", | |
"conv2d_7 (112, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_8 (128, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_9 (144, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_10 (160, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_11 (176, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_12 (192, 192, 1, 1) (None, 192, 128, 128) \n", | |
"conv2d_13 (192, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_14 (208, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_15 (224, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_16 (240, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_17 (256, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_18 (272, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_19 (288, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_20 (304, 304, 1, 1) (None, 304, 64, 64) \n", | |
"conv2d_21 (304, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_22 (320, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_23 (336, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_24 (352, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_25 (368, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_26 (384, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_27 (400, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_28 (416, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_29 (432, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_30 (448, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_31 (464, 464, 1, 1) (None, 464, 32, 32) \n", | |
"conv2d_32 (464, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_33 (480, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_34 (496, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_35 (512, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_36 (528, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_37 (544, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_38 (560, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_39 (576, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_40 (592, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_41 (608, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_42 (624, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_43 (640, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_44 (656, 656, 1, 1) (None, 656, 16, 16) \n", | |
"conv2d_45 (656, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_46 (672, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_47 (688, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_48 (704, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_49 (720, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_50 (736, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_51 (752, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_52 (768, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_53 (784, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_54 (800, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_55 (816, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_56 (832, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_57 (848, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_58 (864, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_59 (880, 16, 3, 3) (None, 16, 8, 8) \n", | |
"conv2d_transpose_1 (240, 240, 3, 3) (None, 240, 16, 16) \n", | |
"conv2d_60 (896, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_61 (912, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_62 (928, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_63 (944, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_64 (960, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_65 (976, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_66 (992, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_67 (1008, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_68 (1024, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_69 (1040, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_70 (1056, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_71 (1072, 16, 3, 3) (None, 16, 16, 16) \n", | |
"conv2d_transpose_2 (192, 192, 3, 3) (None, 192, 32, 32) \n", | |
"conv2d_72 (656, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_73 (672, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_74 (688, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_75 (704, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_76 (720, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_77 (736, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_78 (752, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_79 (768, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_80 (784, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_81 (800, 16, 3, 3) (None, 16, 32, 32) \n", | |
"conv2d_transpose_3 (160, 160, 3, 3) (None, 160, 64, 64) \n", | |
"conv2d_82 (464, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_83 (480, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_84 (496, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_85 (512, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_86 (528, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_87 (544, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_88 (560, 16, 3, 3) (None, 16, 64, 64) \n", | |
"conv2d_transpose_4 (112, 112, 3, 3) (None, 112, 128, 128) \n", | |
"conv2d_89 (304, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_90 (320, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_91 (336, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_92 (352, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_93 (368, 16, 3, 3) (None, 16, 128, 128) \n", | |
"conv2d_transpose_5 (80, 80, 3, 3) (None, 80, 256, 256) \n", | |
"conv2d_94 (192, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_95 (208, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_96 (224, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_97 (240, 16, 3, 3) (None, 16, 256, 256) \n", | |
"conv2d_98 (256, 4, 1, 1) (None, 4, 256, 256) \n" | |
] | |
} | |
], | |
"source": [ | |
# Summarize every convolutional layer: name, kernel shape, output shape.
conv_layers = [l for l in model.layers if "conv2d" in l.name]

# Renamed from `format` to avoid shadowing the builtin of the same name.
row_fmt = "{:<25} {:<25} {:25}"
print(row_fmt.format("Layer", "kernel_shape", "output_shape"))
for conv in conv_layers:
    # Transpose the kernel for display as (in_ch, out_ch, kh, kw).
    kernel_shape = conv.get_weights()[0].transpose(2, 3, 0, 1).shape
    print(row_fmt.format(conv.name,
                         str(kernel_shape),
                         str(conv.get_output_shape_at(0))))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"9,422,500\n" | |
] | |
} | |
], | |
"source": [ | |
# Total parameter count, printed with thousands separators.
print(format(model.count_params(), ","))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Other question: in the "Encoded filtering" part, it looks like a dense block. It seems like you could have used dense_block()
there. Am I missing something?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice work, thanks!
Question: At the end of dense_block(), don't x_stack and connections contain the same data? And in transition_up(), you take the list of tensors coming from connections and you concatenate them together. Couldn't you just use x_stack?