
@fepegar
Last active April 11, 2024 06:30
Understanding U-Net shape constraints
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Understanding U-Net shape constraints\n",
"\n",
"When I first read the [constraints of U-Net in NiftyNet](https://github.com/NifTK/NiftyNet/tree/dev/niftynet/network#unet), I was new to the world of CNNs and didn't know why they existed. Going through the U-Net papers ([original 2D](https://arxiv.org/abs/1505.04597) and [3D version](https://arxiv.org/abs/1606.06650)) and [this guide to convolution arithmetic](https://arxiv.org/abs/1603.07285) helped me understand them.\n",
"\n",
"If you found this tutorial useful, or if you have any questions or comments, **please leave a comment** on this [gist](https://gist.github.com/fepegar/1fb865494cb44ac043c3189ec415d411)!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## U-Net\n",
"\n",
"U-Net is composed of downsampling and upsampling blocks, followed by two convolutions before the output prediction. Each block is composed of two unpadded convolutional layers with kernel size 3 and a downsampling (max pooling) or upsampling (up-convolution) layer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Convolutional layer\n",
"Let's first simulate the effect of a convolutional layer on the shape of the input tensor (called **patch** or **window** in NiftyNet):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"class LessThanOneError(ValueError):\n",
"    \"\"\"Raised when the input is smaller than the kernel.\"\"\"\n",
"\n",
"def convolve(input_shape, kernel_shape=3):\n",
"    # Unpadded ('valid') convolution with stride 1:\n",
"    # each dimension shrinks by kernel_shape - 1\n",
"    input_shape = np.array(input_shape)\n",
"    if np.any(input_shape < kernel_shape):\n",
"        raise LessThanOneError(f'Input tensor shape ({input_shape})'\n",
"                               f' is smaller than kernel shape ({kernel_shape})')\n",
"    output_shape = (input_shape - kernel_shape) + 1\n",
"    return output_shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And test it:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (132, 132, 116)\n",
"Output shape: [130 130 114]\n",
"\n",
"Input shape: (3, 3, 3)\n",
"Output shape: [1 1 1]\n",
"\n",
"Input shape: (2, 2, 2)\n",
"Shape is too small\n",
"\n"
]
}
],
"source": [
"test_shapes = (\n",
"    (132, 132, 116),\n",
"    (3, 3, 3),\n",
"    (2, 2, 2),\n",
")\n",
"\n",
"for shape in test_shapes:\n",
"    print('Input shape:', shape)\n",
"    try:\n",
"        print('Output shape:', convolve(shape))\n",
"    except LessThanOneError:\n",
"        print('Shape is too small')\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Downsampling block\n",
"\n",
"Let's now define the downsampling block:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class NotEvenError(ValueError):\n",
"    \"\"\"Raised when a shape cannot be halved exactly by max pooling.\"\"\"\n",
"\n",
"def downsample(input_shape):\n",
"    # Max pooling with stride 2 halves each dimension; halving is only\n",
"    # exact for even sizes, so odd sizes are rejected\n",
"    if np.any(input_shape % 2):\n",
"        raise NotEvenError(f'Input shape ({input_shape}) must be even')\n",
"    output_shape = input_shape // 2\n",
"    return output_shape\n",
"\n",
"def downsample_block(input_shape):\n",
"    # Two convolutions followed by max pooling\n",
"    shape = convolve(input_shape)\n",
"    shape = convolve(shape)\n",
"    output_shape = downsample(shape)\n",
"    return output_shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (132, 132, 116)\n",
"Output shape: [64 64 56]\n",
"\n",
"Input shape: (4, 4, 4)\n",
"Too small\n",
"\n",
"Input shape: (5, 5, 5)\n",
"Not compatible\n",
"\n",
"Input shape: (6, 6, 6)\n",
"Output shape: [1 1 1]\n",
"\n"
]
}
],
"source": [
"test_shapes = (\n",
"    (132, 132, 116),\n",
"    (4, 4, 4),\n",
"    (5, 5, 5),\n",
"    (6, 6, 6),\n",
")\n",
"\n",
"for shape in test_shapes:\n",
"    print('Input shape:', shape)\n",
"    try:\n",
"        print('Output shape:', downsample_block(shape))\n",
"    except LessThanOneError:\n",
"        print('Too small')\n",
"    except NotEvenError:\n",
"        print('Not compatible')\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Upsampling block"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def upsample(input_shape):\n",
"    # Up-convolution with stride 2 doubles each dimension\n",
"    output_shape = 2 * input_shape\n",
"    return output_shape\n",
"\n",
"def upsample_block(input_shape):\n",
"    # Two convolutions followed by an up-convolution\n",
"    shape = convolve(input_shape)\n",
"    shape = convolve(shape)\n",
"    output_shape = upsample(shape)\n",
"    return output_shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### U-Net architecture\n",
"\n",
"And finally, the actual U-Net:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def unet(input_shape, N, debug=False):\n",
"    shape = np.array(input_shape)\n",
"    if debug:\n",
"        print('Input shape:', shape)\n",
"    # Contracting path\n",
"    for i in range(N):\n",
"        shape = downsample_block(shape)\n",
"        if debug:\n",
"            print(f'Output shape of downsample block {i+1}: {shape}')\n",
"    # Expanding path (the first block includes the bottleneck convolutions)\n",
"    for i in range(N):\n",
"        shape = upsample_block(shape)\n",
"        if debug:\n",
"            print(f'Output shape of upsample block {i+1}: {shape}')\n",
"    # Last two convolutions before the output prediction\n",
"    shape = convolve(shape)\n",
"    shape = convolve(shape)\n",
"\n",
"    output_shape = shape\n",
"    if debug:\n",
"        print('Output shape:', output_shape)\n",
"\n",
"    difference = np.array(input_shape) - output_shape\n",
"    if debug:\n",
"        print('Input - output:', difference)\n",
"\n",
"    border = difference // 2\n",
"    if debug:\n",
"        print('Border:', border)\n",
"\n",
"    return output_shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### U-Net depth\n",
"\n",
"The argument `N` represents the number of downsampling blocks of the network.\n",
"\n",
"For example, for 2D U-Net, `N = 4`:\n",
"<img src=\"https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png\" style=\"width: 600px;\"/>\n",
"\n",
"And for 3D U-Net, `N = 3`:\n",
"<img src=\"https://lmb.informatik.uni-freiburg.de/Publications/2016/CABR16/figureMICCAI2016small.png\" style=\"width: 600px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use the input shapes reported in the papers for both versions:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"args_2d = dict(\n",
"    input_shape=(572, 572),\n",
"    N=4,\n",
")\n",
"\n",
"args_3d = dict(\n",
"    input_shape=(132, 132, 116),\n",
"    N=3,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shapes of the 2D U-Net tensors agree with the figure above:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2D U-Net:\n",
"Input shape: [572 572]\n",
"Output shape of downsample block 1: [284 284]\n",
"Output shape of downsample block 2: [140 140]\n",
"Output shape of downsample block 3: [68 68]\n",
"Output shape of downsample block 4: [32 32]\n",
"Output shape of upsample block 1: [56 56]\n",
"Output shape of upsample block 2: [104 104]\n",
"Output shape of upsample block 3: [200 200]\n",
"Output shape of upsample block 4: [392 392]\n",
"Output shape: [388 388]\n",
"Input - output: [184 184]\n",
"Border: [92 92]\n"
]
}
],
"source": [
"print('2D U-Net:')\n",
"output_shape = unet(**args_2d, debug=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What about the 3D version? The paper does not report the intermediate shapes, but we can compute them:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3D U-Net:\n",
"Input shape: [132 132 116]\n",
"Output shape of downsample block 1: [64 64 56]\n",
"Output shape of downsample block 2: [30 30 26]\n",
"Output shape of downsample block 3: [13 13 11]\n",
"Output shape of upsample block 1: [18 18 14]\n",
"Output shape of upsample block 2: [28 28 20]\n",
"Output shape of upsample block 3: [48 48 32]\n",
"Output shape: [44 44 28]\n",
"Input - output: [88 88 88]\n",
"Border: [44 44 44]\n"
]
}
],
"source": [
"print('3D U-Net:')\n",
"output_shape = unet(**args_3d, debug=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
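"Let's use brute force to try many different input sizes. Each axis is processed independently, so it is enough to test a single dimension (here with `N = 3`):"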
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def try_shapes(test_shapes, N=3):\n",
"    # Test scalar (single-axis) sizes\n",
"    for input_shape in test_shapes:\n",
"        print(f'Input shape {input_shape:3}: ', end='')\n",
"        try:\n",
"            print(unet(input_shape, N=N))\n",
"        except LessThanOneError:\n",
"            print('Too small')\n",
"        except NotEvenError:\n",
"            print('not compatible')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape 0: Too small\n",
"Input shape 1: Too small\n",
"Input shape 2: Too small\n",
"Input shape 3: Too small\n",
"Input shape 4: Too small\n",
"Input shape 5: not compatible\n",
"Input shape 6: Too small\n",
"Input shape 7: not compatible\n",
"Input shape 8: Too small\n",
"Input shape 9: not compatible\n",
"Input shape 10: Too small\n",
"Input shape 11: not compatible\n",
"Input shape 12: Too small\n",
"Input shape 13: not compatible\n",
"Input shape 14: not compatible\n",
"Input shape 15: not compatible\n",
"Input shape 16: Too small\n",
"Input shape 17: not compatible\n",
"Input shape 18: not compatible\n",
"Input shape 19: not compatible\n",
"Input shape 20: Too small\n",
"Input shape 21: not compatible\n",
"Input shape 22: not compatible\n",
"Input shape 23: not compatible\n",
"Input shape 24: Too small\n",
"Input shape 25: not compatible\n",
"Input shape 26: not compatible\n",
"Input shape 27: not compatible\n",
"Input shape 28: Too small\n",
"Input shape 29: not compatible\n",
"Input shape 30: not compatible\n",
"Input shape 31: not compatible\n",
"Input shape 32: not compatible\n",
"Input shape 33: not compatible\n",
"Input shape 34: not compatible\n",
"Input shape 35: not compatible\n",
"Input shape 36: Too small\n",
"Input shape 37: not compatible\n",
"Input shape 38: not compatible\n",
"Input shape 39: not compatible\n",
"Input shape 40: not compatible\n",
"Input shape 41: not compatible\n",
"Input shape 42: not compatible\n",
"Input shape 43: not compatible\n",
"Input shape 44: Too small\n",
"Input shape 45: not compatible\n",
"Input shape 46: not compatible\n",
"Input shape 47: not compatible\n",
"Input shape 48: not compatible\n",
"Input shape 49: not compatible\n",
"Input shape 50: not compatible\n",
"Input shape 51: not compatible\n",
"Input shape 52: Too small\n",
"Input shape 53: not compatible\n",
"Input shape 54: not compatible\n",
"Input shape 55: not compatible\n",
"Input shape 56: not compatible\n",
"Input shape 57: not compatible\n",
"Input shape 58: not compatible\n",
"Input shape 59: not compatible\n",
"Input shape 60: Too small\n",
"Input shape 61: not compatible\n",
"Input shape 62: not compatible\n",
"Input shape 63: not compatible\n",
"Input shape 64: not compatible\n",
"Input shape 65: not compatible\n",
"Input shape 66: not compatible\n",
"Input shape 67: not compatible\n",
"Input shape 68: Too small\n",
"Input shape 69: not compatible\n",
"Input shape 70: not compatible\n",
"Input shape 71: not compatible\n",
"Input shape 72: not compatible\n",
"Input shape 73: not compatible\n",
"Input shape 74: not compatible\n",
"Input shape 75: not compatible\n",
"Input shape 76: Too small\n",
"Input shape 77: not compatible\n",
"Input shape 78: not compatible\n",
"Input shape 79: not compatible\n",
"Input shape 80: not compatible\n",
"Input shape 81: not compatible\n",
"Input shape 82: not compatible\n",
"Input shape 83: not compatible\n",
"Input shape 84: Too small\n",
"Input shape 85: not compatible\n",
"Input shape 86: not compatible\n",
"Input shape 87: not compatible\n",
"Input shape 88: not compatible\n",
"Input shape 89: not compatible\n",
"Input shape 90: not compatible\n",
"Input shape 91: not compatible\n",
"Input shape 92: 4\n",
"Input shape 93: not compatible\n",
"Input shape 94: not compatible\n",
"Input shape 95: not compatible\n",
"Input shape 96: not compatible\n",
"Input shape 97: not compatible\n",
"Input shape 98: not compatible\n",
"Input shape 99: not compatible\n",
"Input shape 100: 12\n",
"Input shape 101: not compatible\n",
"Input shape 102: not compatible\n",
"Input shape 103: not compatible\n",
"Input shape 104: not compatible\n",
"Input shape 105: not compatible\n",
"Input shape 106: not compatible\n",
"Input shape 107: not compatible\n",
"Input shape 108: 20\n",
"Input shape 109: not compatible\n",
"Input shape 110: not compatible\n",
"Input shape 111: not compatible\n",
"Input shape 112: not compatible\n",
"Input shape 113: not compatible\n",
"Input shape 114: not compatible\n",
"Input shape 115: not compatible\n",
"Input shape 116: 28\n",
"Input shape 117: not compatible\n",
"Input shape 118: not compatible\n",
"Input shape 119: not compatible\n",
"Input shape 120: not compatible\n",
"Input shape 121: not compatible\n",
"Input shape 122: not compatible\n",
"Input shape 123: not compatible\n",
"Input shape 124: 36\n",
"Input shape 125: not compatible\n",
"Input shape 126: not compatible\n",
"Input shape 127: not compatible\n",
"Input shape 128: not compatible\n",
"Input shape 129: not compatible\n",
"Input shape 130: not compatible\n",
"Input shape 131: not compatible\n",
"Input shape 132: 44\n"
]
}
],
"source": [
"try_shapes(range(133))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the output and the previous calculations, we can see that:\n",
"1. Since the output is 88 voxels smaller than the input, the input size must be larger than 88 to get a non-empty output\n",
"2. The smallest valid input size is actually 92, because sizes 89 to 91 are not compatible with the pooling layers\n",
"3. Valid sizes are spaced $2^3 = 8$ apart (one factor of 2 per pooling layer), so the input size must be `92 + M * 8`, where `M >= 0`"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape 92: 4\n",
"Input shape 100: 12\n",
"Input shape 108: 20\n",
"Input shape 116: 28\n",
"Input shape 124: 36\n",
"Input shape 132: 44\n",
"Input shape 140: 52\n",
"Input shape 148: 60\n",
"Input shape 156: 68\n",
"Input shape 164: 76\n",
"Input shape 172: 84\n",
"Input shape 180: 92\n",
"Input shape 188: 100\n",
"Input shape 196: 108\n",
"Input shape 204: 116\n",
"Input shape 212: 124\n",
"Input shape 220: 132\n",
"Input shape 228: 140\n",
"Input shape 236: 148\n",
"Input shape 244: 156\n",
"Input shape 252: 164\n",
"Input shape 260: 172\n",
"Input shape 268: 180\n",
"Input shape 276: 188\n",
"Input shape 284: 196\n",
"Input shape 292: 204\n"
]
}
],
"source": [
"try_shapes(range(92, 300, 8))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"Let:\n",
"```python\n",
"I = input_shape                 # constrained by the model\n",
"D = input_shape - output_shape  # imposed by the model\n",
"B = D // 2                      # imposed by the model\n",
"O = output_shape = I - D        # inferred from the previous\n",
"```\n",
"\n",
"In the NiftyNet configuration file, for the 3D U-Net (`N = 3`):\n",
"* In the [Input data source](https://niftynet.readthedocs.io/en/dev/config_spec.html#input-data-source-section) section, [`spatial_window_size`](https://niftynet.readthedocs.io/en/dev/config_spec.html#spatial-window-size) must be `I = 92 + M * 8`, where `M >= 0`\n",
"* [`spatial_window_size`](https://niftynet.readthedocs.io/en/dev/config_spec.html#spatial-window-size) in the [`INFERENCE`](https://niftynet.readthedocs.io/en/dev/config_spec.html#inference) section has the same constraints as in the input section\n",
"* The difference `D` between the input `I` and the output `O` is 88, so `B` is 44. Therefore:\n",
"    * [`volume_padding_size`](https://niftynet.readthedocs.io/en/dev/config_spec.html#volume-padding-size) should be (at least) `B == 44`\n",
"    * [`border`](https://niftynet.readthedocs.io/en/dev/config_spec.html#border) should be (at least) `B == 44`\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
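"The numbers for the 2D version (`N = 4`) can be computed in the same way: `D = 184` and `B = 92` (see the 2D output above), and valid input sizes are `I = 188 + M * 16`. The input size used in the paper, 572, corresponds to `M = 24`."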
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## NiftyNet implementation\n",
"\n",
"In NiftyNet, the convolutional layers of U-Net use [`SAME` padding](https://github.com/NifTK/NiftyNet/blob/a81586f217d5bd938e933652375483288f569aca/niftynet/layer/convolution.py#L121), so the convolutions do not change the tensor shape. The only remaining constraint comes from the $N$ pooling layers: `I` just needs to be divisible by $2^N$."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
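The brute-force search in the notebook can be cross-checked against a closed form. Each block's two unpadded convolutions remove 4 voxels, and the block then halves or doubles the size. Composing these steps gives, for a depth `N >= 2`, `D = 12 * 2**N - 8` and valid input sizes `I = (12 + M) * 2**N - 4` with `M >= 0`. For `N = 3` this reduces to `92 + M * 8`. This closed form is derived here from the block formulas, not taken from the papers or from NiftyNet. The sketch below therefore tests it against a simulation instead of taking it on trust. `simulate` and `closed_form` are hypothetical helpers that mirror the notebook's `unet` along a single axis in plain Python.

```python
def simulate(size, n):
    """Output size of a U-Net with unpadded 3x3 convolutions along one axis,
    or None if the input size is too small or not compatible."""

    def conv_conv(s):
        # Two valid convolutions with kernel size 3: each needs s >= 3
        # and removes 2 voxels, so the pair needs s >= 5
        return s - 4 if s >= 5 else None

    s = size
    for _ in range(n):  # contracting path: conv, conv, max pooling
        s = conv_conv(s)
        if s is None or s % 2:
            return None
        s //= 2
    for _ in range(n):  # expanding path: conv, conv, up-convolution
        s = conv_conv(s)
        if s is None:
            return None
        s *= 2
    return conv_conv(s)  # final two convolutions


def closed_form(n):
    """Hypothetical helper, derived by hand for n >= 2: returns
    (smallest valid input, step between valid inputs, D = input - output)."""
    return 12 * 2**n - 4, 2**n, 12 * 2**n - 8


# Compare the closed form with the simulation for a few depths
for n in range(2, 6):
    smallest, step, diff = closed_form(n)
    valid = [s for s in range(2000) if simulate(s, n) is not None]
    assert valid == list(range(smallest, 2000, step))
    assert all(simulate(s, n) == s - diff for s in valid)
```

Under this formula the border is `B = D / 2 = 6 * 2**N - 4`. That gives 44 for the 3D network and 92 for the 2D one, matching the notebook outputs. The smallest-size formula assumes `N >= 2`. For `N = 1` the bottleneck can be smaller, so use the simulation instead.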
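The notebook's last cell notes that NiftyNet's U-Net uses `SAME` padding. The sketch below makes three simplifying assumptions. Every convolution is `SAME`-padded, so it leaves the size unchanged. Max pooling needs an even size. Each up-convolution exactly doubles the size. Under these assumptions only the `N` pooling layers constrain the input. `simulate_same` is a hypothetical helper, not NiftyNet code. The sketch checks that the valid sizes are then exactly the positive multiples of `2**N`. It also checks that, in this simplified model, the output has the same size as the input.

```python
def simulate_same(size, n):
    """Output size along one axis when all convolutions use SAME padding,
    or None if max pooling would receive an odd size."""
    if size < 1:
        return None
    s = size
    for _ in range(n):  # SAME convolutions keep s; max pooling halves it
        if s % 2:
            return None
        s //= 2
    for _ in range(n):  # up-convolutions double s back
        s *= 2
    return s


# Valid sizes are the positive multiples of 2**n, and output == input
for n in range(1, 6):
    valid = [s for s in range(1, 500) if simulate_same(s, n) is not None]
    assert valid == list(range(2**n, 500, 2**n))
    assert all(simulate_same(s, n) == s for s in valid)
```

Because the output matches the input here, `D` and `B` are both zero in this model. Whether a real configuration still needs a border for other reasons is not covered by this sketch.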
@helwilliams

This is an extremely helpful explanation with clear examples, thanks! Makes a lot more sense now.

@NicoYuCN

NicoYuCN commented Aug 5, 2019

interesting ~

@hzzhangqf0558

six six six

@fepegar
Author

fepegar commented Oct 27, 2020

six six six

🤔

@kavi-47

kavi-47 commented Apr 11, 2024

how to do input 228 without padding
