Tiramisu keras implementation
@czotti · Last active November 28, 2018
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tiramisu\n",
"\n",
"This is the implementation of the tiramisu <a href=\"https://arxiv.org/pdf/1611.09326.pdf\">paper</a>.\n",
"\n",
"You can check the original code <a href=\"https://github.com/SimJeg/FC-DenseNet/blob/master/FC-DenseNet.py\">here</a>."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n",
"Using cuDNN version 5110 on context None\n",
"Mapped name None to device cuda0: GeForce GTX TITAN X (0000:02:00.0)\n"
]
}
],
"source": [
"from keras.layers import (\n",
" Input,\n",
" Conv2D,\n",
" Conv2DTranspose,\n",
" Activation,\n",
" BatchNormalization,\n",
" SpatialDropout2D,\n",
" UpSampling2D,\n",
" MaxPooling2D,\n",
")\n",
"from keras.models import Model\n",
"from keras.layers.merge import Add, Concatenate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def layer(x, nb_feature_maps, working_axis=1, kernel=(3, 3)):\n",
" bn = BatchNormalization(axis=working_axis)(x)\n",
" relu = Activation(\"relu\")(bn)\n",
" conv = Conv2D(nb_feature_maps, kernel, padding=\"same\",\n",
" kernel_initializer=\"he_uniform\")(relu)\n",
" drop = SpatialDropout2D(0.25)(conv)\n",
" return drop\n",
"\n",
"def dense_block(x, steps, nb_feature_maps, working_axis=1):\n",
" connections = []\n",
" x_stack = x\n",
" for i in range(steps):\n",
" l = layer(x_stack, nb_feature_maps, working_axis=1)\n",
" connections.append(l)\n",
" x_stack = Concatenate(axis=1)([x_stack, l])\n",
" return x_stack, connections\n",
" \n",
"\n",
"def transition_down(x, nb_feature_maps=16, working_axis=1):\n",
" l = layer(x, nb_feature_maps, working_axis, (1, 1))\n",
" l = MaxPooling2D((2, 2))(l)\n",
" return l\n",
"\n",
"def transition_up(skip, blocks, nb_feature_maps=16, working_axis=1):\n",
" l = Concatenate(axis=working_axis)(blocks)\n",
" l = Conv2DTranspose(nb_feature_maps, (3, 3), strides=(2, 2),\n",
" padding=\"same\", kernel_initializer=\"he_uniform\")(l)\n",
" l = Concatenate(axis=working_axis)([l, skip])\n",
" return l"
]
},
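{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original gist): `dense_block` should add `steps * nb_feature_maps` channels to its input, and `transition_down` should halve the spatial resolution. The shapes below assume the channels-first `image_data_format` used in the original run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical shape check of the helpers on a small input.\n",
"x = Input((48, 64, 64))\n",
"stack, connections = dense_block(x, 4, 16)\n",
"print(stack._keras_shape)   # (None, 48 + 4 * 16, 64, 64) = (None, 112, 64, 64)\n",
"print(len(connections))     # 4 layer outputs of 16 feature maps each\n",
"down = transition_down(stack, stack._keras_shape[1])\n",
"print(down._keras_shape)    # (None, 112, 32, 32)"
]
},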
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Architecture"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"init_feature_maps = 48\n",
"growth_rate = 16\n",
"\n",
"steps = [4, 5, 7, 10, 12]\n",
"last_step = 15\n",
"\n",
"inp = Input((1, 256, 256))\n",
"inp._keras_shape\n",
"\n",
"stack = Conv2D(init_feature_maps, (3, 3), padding=\"same\", kernel_initializer=\"he_uniform\")(inp)\n",
"\n",
"# Encode part\n",
"skip_connection_list = []\n",
"for s in steps:\n",
" # Dense Block\n",
" stack, _ = dense_block(stack,s , growth_rate, working_axis=1)\n",
" skip_connection_list.append(stack)\n",
"\n",
" # Transition Down\n",
" stack = transition_down(stack, stack._keras_shape[1])\n",
"\n",
"skip_connection_list = skip_connection_list[::-1]\n",
"\n",
"# Encoded filtering\n",
"block_to_upsample = []\n",
"for i in range(last_step):\n",
" l = layer(stack, growth_rate, working_axis=1)\n",
" block_to_upsample.append(l)\n",
" stack = Concatenate(axis=1)([stack, l])\n",
" \n",
"# Decode path\n",
"x_stack = stack\n",
"x_block_to_upsample = block_to_upsample\n",
"n_layers_per_block = [last_step, ] + steps[::-1]\n",
"for n_layers, s, skip in zip(n_layers_per_block, steps[::-1], skip_connection_list):\n",
" # Transition Up ( Upsampling + concatenation with the skip connection)\n",
" n_filters_keep = growth_rate * n_layers\n",
" x_stack = transition_up(skip, x_block_to_upsample, n_filters_keep)\n",
"\n",
" # Dense Block\n",
" x_stack, x_block_to_upsample = dense_block(x_stack, s, growth_rate, working_axis=1)\n",
" \n",
"# output layers\n",
"out = Conv2D(4, (1, 1), kernel_initializer=\"he_uniform\", padding=\"same\")(x_stack)\n",
"out = Activation(\"softmax\")(out)"
]
},
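{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (not in the original gist): after five transitions down and five transitions up, the decoder should restore the input resolution, so `out` should have shape `(None, 4, 256, 256)`, one 256x256 score map per class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sanity check of the final output shape.\n",
"print(out._keras_shape)  # expected: (None, 4, 256, 256)"
]
},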
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model = Model(inputs=[inp,], outputs=[out,])"
]
},
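{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the graph assembled, the model can be compiled and trained as usual. The snippet below is only a sketch: the optimizer and loss follow the paper's training setup (RMSprop with a cross-entropy objective), and `x_train` / `y_train` are placeholder arrays, not data from this gist."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical training setup; x_train / y_train are placeholders.\n",
"# model.compile(optimizer=\"rmsprop\", loss=\"categorical_crossentropy\")\n",
"# model.fit(x_train, y_train, batch_size=4, epochs=10)"
]
},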
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Layer kernel_shape output_shape \n",
"conv2d_1 (1, 48, 3, 3) (None, 48, 256, 256) \n",
"conv2d_2 (48, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_3 (64, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_4 (80, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_5 (96, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_6 (112, 112, 1, 1) (None, 112, 256, 256) \n",
"conv2d_7 (112, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_8 (128, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_9 (144, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_10 (160, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_11 (176, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_12 (192, 192, 1, 1) (None, 192, 128, 128) \n",
"conv2d_13 (192, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_14 (208, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_15 (224, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_16 (240, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_17 (256, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_18 (272, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_19 (288, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_20 (304, 304, 1, 1) (None, 304, 64, 64) \n",
"conv2d_21 (304, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_22 (320, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_23 (336, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_24 (352, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_25 (368, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_26 (384, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_27 (400, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_28 (416, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_29 (432, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_30 (448, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_31 (464, 464, 1, 1) (None, 464, 32, 32) \n",
"conv2d_32 (464, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_33 (480, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_34 (496, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_35 (512, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_36 (528, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_37 (544, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_38 (560, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_39 (576, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_40 (592, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_41 (608, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_42 (624, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_43 (640, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_44 (656, 656, 1, 1) (None, 656, 16, 16) \n",
"conv2d_45 (656, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_46 (672, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_47 (688, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_48 (704, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_49 (720, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_50 (736, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_51 (752, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_52 (768, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_53 (784, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_54 (800, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_55 (816, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_56 (832, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_57 (848, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_58 (864, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_59 (880, 16, 3, 3) (None, 16, 8, 8) \n",
"conv2d_transpose_1 (240, 240, 3, 3) (None, 240, 16, 16) \n",
"conv2d_60 (896, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_61 (912, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_62 (928, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_63 (944, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_64 (960, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_65 (976, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_66 (992, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_67 (1008, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_68 (1024, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_69 (1040, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_70 (1056, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_71 (1072, 16, 3, 3) (None, 16, 16, 16) \n",
"conv2d_transpose_2 (192, 192, 3, 3) (None, 192, 32, 32) \n",
"conv2d_72 (656, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_73 (672, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_74 (688, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_75 (704, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_76 (720, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_77 (736, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_78 (752, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_79 (768, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_80 (784, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_81 (800, 16, 3, 3) (None, 16, 32, 32) \n",
"conv2d_transpose_3 (160, 160, 3, 3) (None, 160, 64, 64) \n",
"conv2d_82 (464, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_83 (480, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_84 (496, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_85 (512, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_86 (528, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_87 (544, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_88 (560, 16, 3, 3) (None, 16, 64, 64) \n",
"conv2d_transpose_4 (112, 112, 3, 3) (None, 112, 128, 128) \n",
"conv2d_89 (304, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_90 (320, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_91 (336, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_92 (352, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_93 (368, 16, 3, 3) (None, 16, 128, 128) \n",
"conv2d_transpose_5 (80, 80, 3, 3) (None, 80, 256, 256) \n",
"conv2d_94 (192, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_95 (208, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_96 (224, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_97 (240, 16, 3, 3) (None, 16, 256, 256) \n",
"conv2d_98 (256, 4, 1, 1) (None, 4, 256, 256) \n"
]
}
],
"source": [
"layers = [layer for layer in model.layers\n",
" if layer.name.find('conv2d') >= 0]\n",
"\n",
"format = \"{:<25} {:<25} {:25}\"\n",
"print(format.format(\"Layer\", \"kernel_shape\", \"output_shape\"))\n",
"for layer in layers:\n",
" print(format.format(layer.name,\n",
" str(layer.get_weights()[0].transpose(2, 3, 0, 1).shape),\n",
" str(layer.get_output_shape_at(0))))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9,422,500\n"
]
}
],
"source": [
"print(\"{:,}\".format(model.count_params()))"
]
},
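{
"cell_type": "markdown",
"metadata": {},
"source": [
"The total of 9,422,500 trainable parameters is consistent with the roughly 9.4M parameters the paper reports for the FC-DenseNet103 configuration (the one built here: growth rate 16, blocks of 4/5/7/10/12 layers and a 15-layer bottleneck)."
]
}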
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
@lemairecarl commented May 18, 2017

Nice work, thanks!

Question: At the end of dense_block(), don't x_stack and connections contain the same data? And in transition_up(), you take the list of tensors coming from connections and concatenate them. Couldn't you just use x_stack?

@lemairecarl
Another question: the "Encoded filtering" part looks like a dense block. It seems like you could have used dense_block() there. Am I missing something?
