@baldassarreFe
Created March 12, 2018 16:36
Mixed-Scale Dense Network
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mixed-Scale Dense Network\n",
"\n",
"PNAS paper: [A mixed-scale dense convolutional neural network for image analysis](https://slidecam-camera.lbl.gov/static/asset/PNAS.pdf)\n",
"\n",
"> Deep convolutional neural networks have been successfully applied to many image-processing problems in recent works. Popular network architectures often add additional operations and connections to the standard architecture to enable training deeper networks. To achieve accurate results in practice, a large number of trainable parameters are often required. Here, we introduce a network architecture based on using dilated convolutions to capture features at different image scales and densely connecting all feature maps with each other. The resulting architecture is able to achieve accurate results with relatively few parameters and consists of a single set of operations, making it easier to implement, train, and apply in practice, and automatically adapts to different problems. We compare results of the proposed network architecture with popular existing architectures for several segmentation problems, showing that the proposed architecture is able to achieve accurate results with fewer parameters, with a reduced risk of overfitting the training data."
]
},
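{
"cell_type": "markdown",
"metadata": {},
"source": [
"The implementation below follows the usual dilated-convolution arithmetic: a `k x k` kernel with dilation `d` spans `(k - 1) * d + 1` pixels per side, so a padding of `((k - 1) * d + 1) // 2` (simply `d` when `k = 3`) keeps the spatial size unchanged. Each mixed-scale dense layer appends one single-channel feature map per dilation to the running stack, so its output has `in_channels + len(dilations)` channels, and the dilations cycle through 1..10 across the network, as in the paper."
]
},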
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Class definitions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable\n",
"from torch.nn import Module, Sequential, Conv2d\n",
"\n",
"\n",
"class MixedScaleDenseLayer(Module):\n",
" def __init__(self, in_channels, dilations, kernel_size=3):\n",
" super(MixedScaleDenseLayer, self).__init__()\n",
"\n",
" if type(dilations) == int:\n",
" dilations = [j % 10 + 1 for j in range(dilations)]\n",
"\n",
" self.kernel_size = kernel_size\n",
" self.in_channels = in_channels\n",
" self.out_channels = in_channels + len(dilations)\n",
"\n",
" for j, dilation in enumerate(dilations):\n",
" # Equal to: kernel_size + (kernel_size - 1) * (dilation - 1)\n",
" dilated_kernel_size = (kernel_size - 1) * dilation + 1\n",
" padding = dilated_kernel_size // 2\n",
" self.add_module(f'conv_{j}', Conv2d(\n",
" in_channels, 1,\n",
" kernel_size=kernel_size, dilation=dilation, padding=padding\n",
" ))\n",
"\n",
" def forward(self, x):\n",
" return torch.cat((x,) + tuple(c(x) for c in self.children()), dim=1)\n",
"\n",
"\n",
"class MixedScaleDenseNetwork(Sequential):\n",
" def __init__(self, in_channels, out_channels, num_layers, growth_rate, kernel_size=3):\n",
" super(MixedScaleDenseNetwork, self).__init__()\n",
"\n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
"\n",
" current_channels = in_channels\n",
" for i in range(num_layers):\n",
" dilations = [((i * growth_rate + j) % 10) +\n",
" 1 for j in range(growth_rate)]\n",
" l = MixedScaleDenseLayer(current_channels, dilations, kernel_size)\n",
" current_channels = l.out_channels\n",
" self.add_module(f'layer_{i}', l)\n",
"\n",
" self.add_module('last', Conv2d(\n",
" current_channels, out_channels, kernel_size=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utils"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def count_parameters(module: Module):\n",
" return sum(p.numel() for p in module.parameters() if p.requires_grad)\n",
"\n",
"\n",
"def count_conv2d(module: Module):\n",
" return len([m for m in module.modules() if isinstance(m, Conv2d)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters: 2204\n",
"Layers: 21\n",
"MixedScaleDenseNetwork(\n",
" (layer_0): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (conv_1): Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
" )\n",
" (layer_1): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
" (conv_1): Conv2d(5, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
" )\n",
" (layer_2): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(7, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5))\n",
" (conv_1): Conv2d(7, 1, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))\n",
" )\n",
" (layer_3): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(9, 1, kernel_size=(3, 3), stride=(1, 1), padding=(7, 7), dilation=(7, 7))\n",
" (conv_1): Conv2d(9, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))\n",
" )\n",
" (layer_4): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(11, 1, kernel_size=(3, 3), stride=(1, 1), padding=(9, 9), dilation=(9, 9))\n",
" (conv_1): Conv2d(11, 1, kernel_size=(3, 3), stride=(1, 1), padding=(10, 10), dilation=(10, 10))\n",
" )\n",
" (layer_5): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(13, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (conv_1): Conv2d(13, 1, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))\n",
" )\n",
" (layer_6): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(15, 1, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3))\n",
" (conv_1): Conv2d(15, 1, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4))\n",
" )\n",
" (layer_7): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(17, 1, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5))\n",
" (conv_1): Conv2d(17, 1, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))\n",
" )\n",
" (layer_8): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(19, 1, kernel_size=(3, 3), stride=(1, 1), padding=(7, 7), dilation=(7, 7))\n",
" (conv_1): Conv2d(19, 1, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(8, 8))\n",
" )\n",
" (layer_9): MixedScaleDenseLayer(\n",
" (conv_0): Conv2d(21, 1, kernel_size=(3, 3), stride=(1, 1), padding=(9, 9), dilation=(9, 9))\n",
" (conv_1): Conv2d(21, 1, kernel_size=(3, 3), stride=(1, 1), padding=(10, 10), dilation=(10, 10))\n",
" )\n",
" (last): Conv2d(23, 1, kernel_size=(1, 1), stride=(1, 1))\n",
")\n"
]
}
],
"source": [
"net = MixedScaleDenseNetwork(3, 1, num_layers=10, growth_rate=2)\n",
"\n",
"print('Parameters:', count_parameters(net))\n",
"print('Layers:', count_conv2d(net))\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 823 ms, sys: 44.4 ms, total: 867 ms\n",
"Wall time: 438 ms\n"
]
}
],
"source": [
"%%time\n",
"x = Variable(torch.rand(32, net.in_channels, 64, 64), volatile=True)\n",
"y = net(x)"
]
},
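{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above uses the pre-0.4 PyTorch API (`Variable` with `volatile=True`). On newer PyTorch versions, an equivalent inference-only forward pass would look roughly like the sketch below (not executed here):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough modern-PyTorch equivalent of the timing cell above (PyTorch >= 0.4):\n",
"# tensors are used directly and gradient tracking is disabled with no_grad()\n",
"x = torch.rand(32, net.in_channels, 64, 64)\n",
"with torch.no_grad():\n",
"    y = net(x)"
]
},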
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x: torch.Size([32, 3, 64, 64])\n",
"y: torch.Size([32, 1, 64, 64])\n"
]
}
],
"source": [
"print('x:', x.shape)\n",
"print('y:', y.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparison with similar network #1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters: 2204\n",
"Layers: 11\n",
"Net1(\n",
" (layer_0): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_1): Conv2d(5, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_2): Conv2d(7, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_3): Conv2d(9, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_4): Conv2d(11, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_5): Conv2d(13, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_6): Conv2d(15, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_7): Conv2d(17, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_8): Conv2d(19, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_9): Conv2d(21, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (last): Conv2d(23, 1, kernel_size=(1, 1), stride=(1, 1))\n",
")\n"
]
}
],
"source": [
"class Net1(Module):\n",
" def __init__(self, in_channels, out_channels, num_layers, growth_rate, kernel_size=3):\n",
" super(Net1, self).__init__()\n",
"\n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
"\n",
" current_channels = in_channels\n",
" for i in range(num_layers):\n",
" conv = Conv2d(current_channels, growth_rate,\n",
" kernel_size, padding=kernel_size // 2)\n",
" self.add_module(f'layer_{i}', conv)\n",
" current_channels += conv.out_channels\n",
"\n",
" self.add_module('last', Conv2d(\n",
" current_channels, out_channels, kernel_size=1))\n",
"\n",
" def forward(self, x):\n",
" res = x\n",
" for c in self.children():\n",
" res = torch.cat((res, c(res)), dim=1)\n",
" return res\n",
"\n",
"\n",
"net = Net1(3, 1, num_layers=10, growth_rate=2)\n",
"print('Parameters:', count_parameters(net))\n",
"print('Layers:', count_conv2d(net))\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 210 ms, sys: 192 ms, total: 402 ms\n",
"Wall time: 204 ms\n"
]
}
],
"source": [
"%%time\n",
"x = Variable(torch.rand(32, net.in_channels, 64, 64), volatile=True)\n",
"y = net(x)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x: torch.Size([32, 3, 64, 64])\n",
"y: torch.Size([32, 24, 64, 64])\n"
]
}
],
"source": [
"print('x:', x.shape)\n",
"print('y:', y.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparison with similar network #2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters: 18254\n",
"Layers: 11\n",
"Net2(\n",
" (layer_0): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_1): Conv2d(5, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_2): Conv2d(7, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_3): Conv2d(9, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_4): Conv2d(11, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_5): Conv2d(13, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_6): Conv2d(15, 17, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_7): Conv2d(17, 19, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_8): Conv2d(19, 21, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (layer_9): Conv2d(21, 23, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (last): Conv2d(23, 1, kernel_size=(1, 1), stride=(1, 1))\n",
")\n"
]
}
],
"source": [
"class Net2(Sequential):\n",
" def __init__(self, in_channels, out_channels, num_layers, growth_rate, kernel_size=3):\n",
" super(Net2, self).__init__()\n",
"\n",
" self.in_channels = in_channels\n",
" self.out_channels = out_channels\n",
"\n",
" current_channels = in_channels\n",
" for i in range(num_layers):\n",
" conv = Conv2d(current_channels, current_channels + growth_rate,\n",
" kernel_size, padding=kernel_size // 2)\n",
" self.add_module(f'layer_{i}', conv)\n",
" current_channels = conv.out_channels\n",
"\n",
" self.add_module('last', Conv2d(\n",
" current_channels, out_channels, kernel_size=1))\n",
"\n",
"\n",
"net = Net2(3, 1, num_layers=10, growth_rate=2)\n",
"print('Parameters:', count_parameters(net))\n",
"print('Layers:', count_conv2d(net))\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 263 ms, sys: 152 ms, total: 414 ms\n",
"Wall time: 213 ms\n"
]
}
],
"source": [
"%%time\n",
"x = Variable(torch.rand(32, net.in_channels, 64, 64), volatile=True)\n",
"y = net(x)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x: torch.Size([32, 3, 64, 64])\n",
"y: torch.Size([32, 1, 64, 64])\n"
]
}
],
"source": [
"print('x:', x.shape)\n",
"print('y:', y.shape)"
]
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/3dd1745be9229ecdc076eeafff58a7e4"
},
"gist": {
"data": {
"description": "Mixed-Scale Dense Network",
"public": true
},
"id": "3dd1745be9229ecdc076eeafff58a7e4"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@EelcoHoogendoorn

Interesting example! Let me see if I interpret it correctly: forward evaluation of the MixedScaleDense network is roughly a factor of two slower than for comparable alternative networks.

However, alternative networks can easily have many more parameters; even if they are more efficient to evaluate despite having more parameters, that does not mean they are more efficient to train or converge. So the MixedScaleDense network might still have an edge in training.

What I am after is some context for the efficiency claim in the original paper. While this comparison does demonstrate the point the authors tried to make, they do not make a quantitative claim, and this example suggests to me that the inefficiency penalty incurred in existing frameworks is manageable, so one hardly needs their proprietary code to make this technique useful in practice.

There is some reading between the lines of their paper that I still haven't quite figured out. Generally, I would say it is a good thing if you can implement your technique in existing frameworks. So did they make this remark in their paper to justify to themselves writing their own neural network framework in pycuda from scratch? Or are there actual real-world examples where we see more than a factor-of-two performance difference?
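
For anyone who wants to reproduce the forward-pass timings on their own hardware, a minimal sketch of how one might measure them (assuming a PyTorch version that provides torch.no_grad; the notebook above uses the older Variable/volatile API, and times will of course vary with hardware):

    import time
    import torch

    def mean_forward_time(net, x, repeats=10):
        # Average wall-clock time of an inference-only forward pass
        net.eval()
        with torch.no_grad():
            start = time.perf_counter()
            for _ in range(repeats):
                net(x)
        return (time.perf_counter() - start) / repeats

    # e.g. mean_forward_time(MixedScaleDenseNetwork(3, 1, num_layers=10, growth_rate=2),
    #                        torch.rand(32, 3, 64, 64))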

@wohe157 commented Dec 2, 2019

@EelcoHoogendoorn Indeed, this example seems to compare the evaluation time of the MSDNet and a conventional CNN, and I agree with you that it suggests a conventional CNN with a similar number of parameters is faster, but such a CNN would usually need (a lot!) more parameters than the MSDNet.

However, I don't really agree with your point about efficiency. I have tried both this example and the code from D. M. Pelt on a database of roughly 1,500 images of size 1x256x256 for denoising. The same network (w=1 and d=20) took 5 min/epoch with this PyTorch implementation, but only 50 s/epoch with the original code.

Moreover, I'm limited to batch sizes of 8 images with this PyTorch implementation because of the GPU memory usage (6 GB), while the original implementation doesn't even use 100 MB and only uses 55% of my GPU. Therefore I believe the original code could be optimized further to run even faster.

Lastly, I noticed that both networks converge at about the same speed (in epochs) to a similar loss (an MSE of 0.00018 in my case), but when the PyTorch network stops reducing the loss, the original network continues on to 0.00010 and better. Because of this, I'm not even sure this PyTorch implementation is the same as the one from the paper.

@RyanPlt commented Apr 22, 2021

@wohe157 It's indeed not the same as in the paper, since it's missing a crucial component of any network: a nonlinearity! In the paper they briefly mention using ReLU, so if you add that, it would probably perform much better.
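
For reference, a minimal sketch of that fix, reusing the MixedScaleDenseLayer defined in the notebook above and simply applying a ReLU to each dilated convolution before the concatenation (the subclass name is illustrative, not from the paper or the gist):

    import torch
    import torch.nn.functional as F

    class MixedScaleDenseLayerWithReLU(MixedScaleDenseLayer):
        def forward(self, x):
            # Same dense concatenation as above, but each new feature map passes through a ReLU
            return torch.cat((x,) + tuple(F.relu(c(x)) for c in self.children()), dim=1)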
