Last active
October 24, 2022 16:01
-
-
Save brookisme/52079e106255f75c996d8595cd3988b0 to your computer and use it in GitHub Desktop.
UNET with Squeeze and Excitation Blocks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from torchsummary import summary" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([1, 64, 568, 568])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" AdaptiveAvgPool2d-1 [-1, 64, 1, 1] 0\n", | |
" Linear-2 [-1, 4] 260\n", | |
" ReLU-3 [-1, 4] 0\n", | |
" Linear-4 [-1, 64] 320\n", | |
" Sigmoid-5 [-1, 64] 0\n", | |
"================================================================\n", | |
"Total params: 580\n", | |
"Trainable params: 580\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"class SqueezeExcitation(nn.Module):\n", | |
" def __init__(self, nb_channels, reduction=16):\n", | |
" super(SqueezeExcitation, self).__init__()\n", | |
" self.nb_channels=nb_channels\n", | |
" self.avg_pool=nn.AdaptiveAvgPool2d(1)\n", | |
" self.fc=nn.Sequential(\n", | |
" nn.Linear(nb_channels, nb_channels // reduction),\n", | |
" nn.ReLU(inplace=True),\n", | |
" nn.Linear(nb_channels // reduction, nb_channels),\n", | |
" nn.Sigmoid())\n", | |
"\n", | |
" \n", | |
" def forward(self, x):\n", | |
" y = self.avg_pool(x).view(-1,self.nb_channels)\n", | |
" y = self.fc(y).view(-1,self.nb_channels,1,1)\n", | |
" return x * y\n", | |
" \n", | |
"\n", | |
"print(SqueezeExcitation(64)(torch.rand(64,568,568)).shape)\n", | |
"summary(SqueezeExcitation(64),input_size=(64,568,568))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(568, 64)\n", | |
"torch.Size([1, 64, 568, 568])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 570, 570] 640\n", | |
" Conv2d-2 [-1, 64, 568, 568] 36,928\n", | |
" BatchNorm2d-3 [-1, 64, 568, 568] 128\n", | |
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
" Linear-5 [-1, 4] 260\n", | |
" ReLU-6 [-1, 4] 0\n", | |
" Linear-7 [-1, 64] 320\n", | |
" Sigmoid-8 [-1, 64] 0\n", | |
" SqueezeExcitation-9 [-1, 64, 568, 568] 0\n", | |
"================================================================\n", | |
"Total params: 38,276\n", | |
"Trainable params: 38,276\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"class ConvBlock(nn.Module):\n", | |
"\n", | |
" def __init__(self,\n", | |
" in_ch,\n", | |
" in_size,\n", | |
" depth=2, \n", | |
" kernel_size=3, \n", | |
" stride=1, \n", | |
" padding=0, \n", | |
" out_ch=None,\n", | |
" bn=True,\n", | |
" se=True,\n", | |
" act='relu',\n", | |
" act_kwargs={}):\n", | |
" super(ConvBlock, self).__init__()\n", | |
" self.out_ch=out_ch or 2*in_ch\n", | |
" self._set_post_processes(self.out_ch,bn,se,act,act_kwargs)\n", | |
" self._set_conv_layers(\n", | |
" depth,\n", | |
" in_ch,\n", | |
" kernel_size,\n", | |
" stride,\n", | |
" padding)\n", | |
" self.out_size=in_size-depth*2*((kernel_size-1)/2-padding)\n", | |
"\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x=self.conv_layers(x)\n", | |
" if self.bn:\n", | |
" x=self.bn(x)\n", | |
" if self.act:\n", | |
" x=self._activation(x)\n", | |
" if self.se:\n", | |
" x=self.se(x)\n", | |
" return x\n", | |
"\n", | |
" \n", | |
" def _set_post_processes(self,out_channels,bn,se,act,act_kwargs):\n", | |
" if bn:\n", | |
" self.bn=nn.BatchNorm2d(out_channels)\n", | |
" else:\n", | |
" self.bn=False\n", | |
" if se:\n", | |
" self.se=SqueezeExcitation(out_channels)\n", | |
" else:\n", | |
" self.se=False\n", | |
" self.act=act\n", | |
" self.act_kwargs=act_kwargs\n", | |
"\n", | |
" \n", | |
" def _set_conv_layers(\n", | |
" self,\n", | |
" depth,\n", | |
" in_ch,\n", | |
" kernel_size,\n", | |
" stride,\n", | |
" padding):\n", | |
" layers=[]\n", | |
" for index in range(depth):\n", | |
" if index!=0:\n", | |
" in_ch=self.out_ch\n", | |
" layers.append(\n", | |
" nn.Conv2d(\n", | |
" in_channels=in_ch,\n", | |
" out_channels=self.out_ch,\n", | |
" kernel_size=kernel_size,\n", | |
" stride=stride,\n", | |
" padding=padding))\n", | |
" self.conv_layers=nn.Sequential(*layers)\n", | |
"\n", | |
" \n", | |
" def _activation(self,x):\n", | |
" return getattr(F,self.act,**self.act_kwargs)(x)\n", | |
"\n", | |
" \n", | |
"conv_block=ConvBlock(1,572,out_ch=64)\n", | |
"print(conv_block.out_size,conv_block.out_ch)\n", | |
"print(conv_block(torch.rand(1,1,572,572)).shape)\n", | |
"summary(ConvBlock(1,572,out_ch=64),input_size=(1,572,572))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(276, 128)\n", | |
"torch.Size([1, 128, 276, 276])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" MaxPool2d-1 [-1, 64, 284, 284] 0\n", | |
" Conv2d-2 [-1, 128, 282, 282] 73,856\n", | |
" Conv2d-3 [-1, 128, 280, 280] 147,584\n", | |
" Conv2d-4 [-1, 128, 278, 278] 147,584\n", | |
" Conv2d-5 [-1, 128, 276, 276] 147,584\n", | |
" BatchNorm2d-6 [-1, 128, 276, 276] 256\n", | |
" AdaptiveAvgPool2d-7 [-1, 128, 1, 1] 0\n", | |
" Linear-8 [-1, 8] 1,032\n", | |
" ReLU-9 [-1, 8] 0\n", | |
" Linear-10 [-1, 128] 1,152\n", | |
" Sigmoid-11 [-1, 128] 0\n", | |
"SqueezeExcitation-12 [-1, 128, 276, 276] 0\n", | |
" ConvBlock-13 [-1, 128, 276, 276] 0\n", | |
"================================================================\n", | |
"Total params: 519,048\n", | |
"Trainable params: 519,048\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"class DownBlock(nn.Module):\n", | |
" \n", | |
" def __init__(self,\n", | |
" in_ch,\n", | |
" in_size,\n", | |
" out_ch=None,\n", | |
" depth=2,\n", | |
" padding=0,\n", | |
" bn=True,\n", | |
" se=True,\n", | |
" act='relu',\n", | |
" act_kwargs={}):\n", | |
" super(DownBlock, self).__init__()\n", | |
" self.out_size=(in_size//2)-depth*(1-padding)*2\n", | |
" self.out_ch=out_ch or in_ch*2\n", | |
" self.down=nn.MaxPool2d(kernel_size=2)\n", | |
" self.conv_block=ConvBlock(\n", | |
" in_ch=in_ch,\n", | |
" out_ch=self.out_ch,\n", | |
" in_size=in_size//2,\n", | |
" depth=depth,\n", | |
" padding=padding,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
"\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x=self.down(x)\n", | |
" return self.conv_block(x)\n", | |
"\n", | |
" \n", | |
"db_out=DownBlock(64,568,depth=4)\n", | |
"print(db_out.out_size,db_out.out_ch)\n", | |
"print(db_out(torch.rand(1,64,568,568)).shape)\n", | |
"summary(db_out,input_size=(64,568,568))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(196, 128)\n", | |
"torch.Size([1, 128, 196, 196])\n" | |
] | |
} | |
], | |
"source": [ | |
"class UpBlock(nn.Module):\n", | |
" \n", | |
" @staticmethod\n", | |
" def cropping(skip_size,size):\n", | |
" return (skip_size-size)//2\n", | |
" \n", | |
" \n", | |
" def __init__(self,\n", | |
" in_ch,\n", | |
" in_size,\n", | |
" out_ch=None,\n", | |
" bilinear=False,\n", | |
" crop=None,\n", | |
" depth=2,\n", | |
" padding=0,\n", | |
" bn=True,\n", | |
" se=True,\n", | |
" act='relu',\n", | |
" act_kwargs={}):\n", | |
" super(UpBlock, self).__init__()\n", | |
" self.crop=crop\n", | |
" self.padding=padding\n", | |
" self.out_size=(in_size*2)-depth*(1-padding)*2\n", | |
" self.out_ch=out_ch or in_ch//2\n", | |
" if bilinear:\n", | |
" self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n", | |
" else:\n", | |
" self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)\n", | |
" self.conv_block=ConvBlock(\n", | |
" in_ch,\n", | |
" self.out_size,\n", | |
" out_ch=self.out_ch,\n", | |
" depth=depth,\n", | |
" padding=padding,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
" \n", | |
" \n", | |
" def forward(self, x, skip):\n", | |
" x = self.up(x)\n", | |
" skip = self._crop(skip,x)\n", | |
" x = torch.cat([skip, x], dim=1)\n", | |
" x = self.conv_block(x)\n", | |
" return x\n", | |
"\n", | |
" \n", | |
" def _crop(self,skip,x):\n", | |
" if self.padding is 0:\n", | |
" if self.crop is None:\n", | |
" self.crop=self.cropping(skip.size()[-1],x.size()[-1])\n", | |
" skip=skip[:,:,self.crop:-self.crop,self.crop:-self.crop]\n", | |
" return skip\n", | |
"\n", | |
" \n", | |
"db_out=UpBlock(256,100)\n", | |
"print(db_out.out_size,db_out.out_ch)\n", | |
"print(db_out(torch.rand(1,256,100,100),torch.rand(1,128,280,280)).shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(4, 2)\n", | |
"(388, 2)\n", | |
"torch.Size([1, 2, 388, 388])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 570, 570] 640\n", | |
" Conv2d-2 [-1, 64, 568, 568] 36,928\n", | |
" BatchNorm2d-3 [-1, 64, 568, 568] 128\n", | |
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
" Linear-5 [-1, 4] 260\n", | |
" ReLU-6 [-1, 4] 0\n", | |
" Linear-7 [-1, 64] 320\n", | |
" Sigmoid-8 [-1, 64] 0\n", | |
" SqueezeExcitation-9 [-1, 64, 568, 568] 0\n", | |
" ConvBlock-10 [-1, 64, 568, 568] 0\n", | |
" MaxPool2d-11 [-1, 64, 284, 284] 0\n", | |
" Conv2d-12 [-1, 128, 282, 282] 73,856\n", | |
" Conv2d-13 [-1, 128, 280, 280] 147,584\n", | |
" BatchNorm2d-14 [-1, 128, 280, 280] 256\n", | |
"AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n", | |
" Linear-16 [-1, 8] 1,032\n", | |
" ReLU-17 [-1, 8] 0\n", | |
" Linear-18 [-1, 128] 1,152\n", | |
" Sigmoid-19 [-1, 128] 0\n", | |
"SqueezeExcitation-20 [-1, 128, 280, 280] 0\n", | |
" ConvBlock-21 [-1, 128, 280, 280] 0\n", | |
" DownBlock-22 [-1, 128, 280, 280] 0\n", | |
" MaxPool2d-23 [-1, 128, 140, 140] 0\n", | |
" Conv2d-24 [-1, 256, 138, 138] 295,168\n", | |
" Conv2d-25 [-1, 256, 136, 136] 590,080\n", | |
" BatchNorm2d-26 [-1, 256, 136, 136] 512\n", | |
"AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n", | |
" Linear-28 [-1, 16] 4,112\n", | |
" ReLU-29 [-1, 16] 0\n", | |
" Linear-30 [-1, 256] 4,352\n", | |
" Sigmoid-31 [-1, 256] 0\n", | |
"SqueezeExcitation-32 [-1, 256, 136, 136] 0\n", | |
" ConvBlock-33 [-1, 256, 136, 136] 0\n", | |
" DownBlock-34 [-1, 256, 136, 136] 0\n", | |
" MaxPool2d-35 [-1, 256, 68, 68] 0\n", | |
" Conv2d-36 [-1, 512, 66, 66] 1,180,160\n", | |
" Conv2d-37 [-1, 512, 64, 64] 2,359,808\n", | |
" BatchNorm2d-38 [-1, 512, 64, 64] 1,024\n", | |
"AdaptiveAvgPool2d-39 [-1, 512, 1, 1] 0\n", | |
" Linear-40 [-1, 32] 16,416\n", | |
" ReLU-41 [-1, 32] 0\n", | |
" Linear-42 [-1, 512] 16,896\n", | |
" Sigmoid-43 [-1, 512] 0\n", | |
"SqueezeExcitation-44 [-1, 512, 64, 64] 0\n", | |
" ConvBlock-45 [-1, 512, 64, 64] 0\n", | |
" DownBlock-46 [-1, 512, 64, 64] 0\n", | |
" MaxPool2d-47 [-1, 512, 32, 32] 0\n", | |
" Conv2d-48 [-1, 1024, 30, 30] 4,719,616\n", | |
" Conv2d-49 [-1, 1024, 28, 28] 9,438,208\n", | |
" BatchNorm2d-50 [-1, 1024, 28, 28] 2,048\n", | |
"AdaptiveAvgPool2d-51 [-1, 1024, 1, 1] 0\n", | |
" Linear-52 [-1, 64] 65,600\n", | |
" ReLU-53 [-1, 64] 0\n", | |
" Linear-54 [-1, 1024] 66,560\n", | |
" Sigmoid-55 [-1, 1024] 0\n", | |
"SqueezeExcitation-56 [-1, 1024, 28, 28] 0\n", | |
" ConvBlock-57 [-1, 1024, 28, 28] 0\n", | |
" DownBlock-58 [-1, 1024, 28, 28] 0\n", | |
" ConvTranspose2d-59 [-1, 512, 56, 56] 2,097,664\n", | |
" Conv2d-60 [-1, 512, 54, 54] 4,719,104\n", | |
" Conv2d-61 [-1, 512, 52, 52] 2,359,808\n", | |
" BatchNorm2d-62 [-1, 512, 52, 52] 1,024\n", | |
"AdaptiveAvgPool2d-63 [-1, 512, 1, 1] 0\n", | |
" Linear-64 [-1, 32] 16,416\n", | |
" ReLU-65 [-1, 32] 0\n", | |
" Linear-66 [-1, 512] 16,896\n", | |
" Sigmoid-67 [-1, 512] 0\n", | |
"SqueezeExcitation-68 [-1, 512, 52, 52] 0\n", | |
" ConvBlock-69 [-1, 512, 52, 52] 0\n", | |
" UpBlock-70 [-1, 512, 52, 52] 0\n", | |
" ConvTranspose2d-71 [-1, 256, 104, 104] 524,544\n", | |
" Conv2d-72 [-1, 256, 102, 102] 1,179,904\n", | |
" Conv2d-73 [-1, 256, 100, 100] 590,080\n", | |
" BatchNorm2d-74 [-1, 256, 100, 100] 512\n", | |
"AdaptiveAvgPool2d-75 [-1, 256, 1, 1] 0\n", | |
" Linear-76 [-1, 16] 4,112\n", | |
" ReLU-77 [-1, 16] 0\n", | |
" Linear-78 [-1, 256] 4,352\n", | |
" Sigmoid-79 [-1, 256] 0\n", | |
"SqueezeExcitation-80 [-1, 256, 100, 100] 0\n", | |
" ConvBlock-81 [-1, 256, 100, 100] 0\n", | |
" UpBlock-82 [-1, 256, 100, 100] 0\n", | |
" ConvTranspose2d-83 [-1, 128, 200, 200] 131,200\n", | |
" Conv2d-84 [-1, 128, 198, 198] 295,040\n", | |
" Conv2d-85 [-1, 128, 196, 196] 147,584\n", | |
" BatchNorm2d-86 [-1, 128, 196, 196] 256\n", | |
"AdaptiveAvgPool2d-87 [-1, 128, 1, 1] 0\n", | |
" Linear-88 [-1, 8] 1,032\n", | |
" ReLU-89 [-1, 8] 0\n", | |
" Linear-90 [-1, 128] 1,152\n", | |
" Sigmoid-91 [-1, 128] 0\n", | |
"SqueezeExcitation-92 [-1, 128, 196, 196] 0\n", | |
" ConvBlock-93 [-1, 128, 196, 196] 0\n", | |
" UpBlock-94 [-1, 128, 196, 196] 0\n", | |
" ConvTranspose2d-95 [-1, 64, 392, 392] 32,832\n", | |
" Conv2d-96 [-1, 64, 390, 390] 73,792\n", | |
" Conv2d-97 [-1, 64, 388, 388] 36,928\n", | |
" BatchNorm2d-98 [-1, 64, 388, 388] 128\n", | |
"AdaptiveAvgPool2d-99 [-1, 64, 1, 1] 0\n", | |
" Linear-100 [-1, 4] 260\n", | |
" ReLU-101 [-1, 4] 0\n", | |
" Linear-102 [-1, 64] 320\n", | |
" Sigmoid-103 [-1, 64] 0\n", | |
"SqueezeExcitation-104 [-1, 64, 388, 388] 0\n", | |
" ConvBlock-105 [-1, 64, 388, 388] 0\n", | |
" UpBlock-106 [-1, 64, 388, 388] 0\n", | |
" Conv2d-107 [-1, 2, 388, 388] 130\n", | |
"================================================================\n", | |
"Total params: 31,257,786\n", | |
"Trainable params: 31,257,786\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"class UNet(nn.Module):\n", | |
"\n", | |
" def __init__(self,\n", | |
" network_depth=4,\n", | |
" conv_depth=2,\n", | |
" in_size=572,\n", | |
" in_ch=1,\n", | |
" out_ch=2,\n", | |
" init_ch=64,\n", | |
" padding=0,\n", | |
" bn=True,\n", | |
" se=True,\n", | |
" act='relu',\n", | |
" act_kwargs={}):\n", | |
" super(UNet, self).__init__()\n", | |
" self.network_depth=network_depth\n", | |
" self.conv_depth=conv_depth\n", | |
" self.out_ch=out_ch\n", | |
" self.padding=padding\n", | |
" self.input_conv=ConvBlock(\n", | |
" in_ch=in_ch,\n", | |
" in_size=in_size,\n", | |
" out_ch=init_ch,\n", | |
" depth=self.conv_depth,\n", | |
" padding=padding,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
" down_layers=self._down_layers(\n", | |
" self.input_conv.out_ch,\n", | |
" self.input_conv.out_size,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
" self.down_blocks=nn.ModuleList(down_layers)\n", | |
" up_layers=self._up_layers(\n", | |
" down_layers,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
" self.up_blocks=nn.ModuleList(up_layers)\n", | |
" self.out_size=self.up_blocks[-1].out_size\n", | |
" self.output_conv=self._output_layer(out_ch)\n", | |
"\n", | |
" \n", | |
" def forward(self, x):\n", | |
" x=self.input_conv(x)\n", | |
" skips=[x]\n", | |
" for block in self.down_blocks:\n", | |
" x=block(x)\n", | |
" skips.append(x)\n", | |
" skips.pop()\n", | |
" skips=skips[::-1]\n", | |
" for skip,block in zip(skips,self.up_blocks):\n", | |
" x=block(x,skip)\n", | |
" x=self.output_conv(x)\n", | |
" return x\n", | |
" \n", | |
" \n", | |
" def _down_layers(self,in_ch,in_size,bn,se,act,act_kwargs):\n", | |
" layers=[]\n", | |
" for index in range(1,self.network_depth+1):\n", | |
" layer=DownBlock(\n", | |
" in_ch,\n", | |
" in_size,\n", | |
" depth=self.conv_depth,\n", | |
" padding=self.padding,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
" in_ch=layer.out_ch\n", | |
" in_size=layer.out_size\n", | |
" layers.append(layer)\n", | |
" return layers\n", | |
"\n", | |
" \n", | |
" def _up_layers(self,down_layers,bn,se,act,act_kwargs):\n", | |
" down_layers=down_layers[::-1]\n", | |
" down_layers.append(self.input_conv)\n", | |
" first=down_layers.pop(0)\n", | |
" in_ch=first.out_ch\n", | |
" in_size=first.out_size\n", | |
" layers=[]\n", | |
" for down_layer in down_layers:\n", | |
" crop=UpBlock.cropping(down_layer.out_size,2*in_size)\n", | |
" layer=UpBlock(\n", | |
" in_ch,\n", | |
" in_size,\n", | |
" depth=self.conv_depth,\n", | |
" crop=crop,\n", | |
" padding=self.padding,\n", | |
" bn=bn,\n", | |
" se=se,\n", | |
" act=act,\n", | |
" act_kwargs=act_kwargs)\n", | |
" in_ch=layer.out_ch\n", | |
" in_size=layer.out_size\n", | |
" layers.append(layer)\n", | |
" return layers\n", | |
"\n", | |
" \n", | |
" def _output_layer(self,out_ch):\n", | |
" return nn.Conv2d(\n", | |
" in_channels=64,\n", | |
" out_channels=out_ch,\n", | |
" kernel_size=1,\n", | |
" stride=1,\n", | |
" padding=0)\n", | |
" \n", | |
" \n", | |
"unet=UNet(in_size=572,network_depth=4,conv_depth=2)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,572,572)).shape)\n", | |
"summary(unet,input_size=(1,572,572))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(2, 4)\n", | |
"(492, 2)\n", | |
"torch.Size([1, 2, 492, 492])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 570, 570] 640\n", | |
" Conv2d-2 [-1, 64, 568, 568] 36,928\n", | |
" Conv2d-3 [-1, 64, 566, 566] 36,928\n", | |
" Conv2d-4 [-1, 64, 564, 564] 36,928\n", | |
" BatchNorm2d-5 [-1, 64, 564, 564] 128\n", | |
" AdaptiveAvgPool2d-6 [-1, 64, 1, 1] 0\n", | |
" Linear-7 [-1, 4] 260\n", | |
" ReLU-8 [-1, 4] 0\n", | |
" Linear-9 [-1, 64] 320\n", | |
" Sigmoid-10 [-1, 64] 0\n", | |
"SqueezeExcitation-11 [-1, 64, 564, 564] 0\n", | |
" ConvBlock-12 [-1, 64, 564, 564] 0\n", | |
" MaxPool2d-13 [-1, 64, 282, 282] 0\n", | |
" Conv2d-14 [-1, 128, 280, 280] 73,856\n", | |
" Conv2d-15 [-1, 128, 278, 278] 147,584\n", | |
" Conv2d-16 [-1, 128, 276, 276] 147,584\n", | |
" Conv2d-17 [-1, 128, 274, 274] 147,584\n", | |
" BatchNorm2d-18 [-1, 128, 274, 274] 256\n", | |
"AdaptiveAvgPool2d-19 [-1, 128, 1, 1] 0\n", | |
" Linear-20 [-1, 8] 1,032\n", | |
" ReLU-21 [-1, 8] 0\n", | |
" Linear-22 [-1, 128] 1,152\n", | |
" Sigmoid-23 [-1, 128] 0\n", | |
"SqueezeExcitation-24 [-1, 128, 274, 274] 0\n", | |
" ConvBlock-25 [-1, 128, 274, 274] 0\n", | |
" DownBlock-26 [-1, 128, 274, 274] 0\n", | |
" MaxPool2d-27 [-1, 128, 137, 137] 0\n", | |
" Conv2d-28 [-1, 256, 135, 135] 295,168\n", | |
" Conv2d-29 [-1, 256, 133, 133] 590,080\n", | |
" Conv2d-30 [-1, 256, 131, 131] 590,080\n", | |
" Conv2d-31 [-1, 256, 129, 129] 590,080\n", | |
" BatchNorm2d-32 [-1, 256, 129, 129] 512\n", | |
"AdaptiveAvgPool2d-33 [-1, 256, 1, 1] 0\n", | |
" Linear-34 [-1, 16] 4,112\n", | |
" ReLU-35 [-1, 16] 0\n", | |
" Linear-36 [-1, 256] 4,352\n", | |
" Sigmoid-37 [-1, 256] 0\n", | |
"SqueezeExcitation-38 [-1, 256, 129, 129] 0\n", | |
" ConvBlock-39 [-1, 256, 129, 129] 0\n", | |
" DownBlock-40 [-1, 256, 129, 129] 0\n", | |
" ConvTranspose2d-41 [-1, 128, 258, 258] 131,200\n", | |
" Conv2d-42 [-1, 128, 256, 256] 295,040\n", | |
" Conv2d-43 [-1, 128, 254, 254] 147,584\n", | |
" Conv2d-44 [-1, 128, 252, 252] 147,584\n", | |
" Conv2d-45 [-1, 128, 250, 250] 147,584\n", | |
" BatchNorm2d-46 [-1, 128, 250, 250] 256\n", | |
"AdaptiveAvgPool2d-47 [-1, 128, 1, 1] 0\n", | |
" Linear-48 [-1, 8] 1,032\n", | |
" ReLU-49 [-1, 8] 0\n", | |
" Linear-50 [-1, 128] 1,152\n", | |
" Sigmoid-51 [-1, 128] 0\n", | |
"SqueezeExcitation-52 [-1, 128, 250, 250] 0\n", | |
" ConvBlock-53 [-1, 128, 250, 250] 0\n", | |
" UpBlock-54 [-1, 128, 250, 250] 0\n", | |
" ConvTranspose2d-55 [-1, 64, 500, 500] 32,832\n", | |
" Conv2d-56 [-1, 64, 498, 498] 73,792\n", | |
" Conv2d-57 [-1, 64, 496, 496] 36,928\n", | |
" Conv2d-58 [-1, 64, 494, 494] 36,928\n", | |
" Conv2d-59 [-1, 64, 492, 492] 36,928\n", | |
" BatchNorm2d-60 [-1, 64, 492, 492] 128\n", | |
"AdaptiveAvgPool2d-61 [-1, 64, 1, 1] 0\n", | |
" Linear-62 [-1, 4] 260\n", | |
" ReLU-63 [-1, 4] 0\n", | |
" Linear-64 [-1, 64] 320\n", | |
" Sigmoid-65 [-1, 64] 0\n", | |
"SqueezeExcitation-66 [-1, 64, 492, 492] 0\n", | |
" ConvBlock-67 [-1, 64, 492, 492] 0\n", | |
" UpBlock-68 [-1, 64, 492, 492] 0\n", | |
" Conv2d-69 [-1, 2, 492, 492] 130\n", | |
"================================================================\n", | |
"Total params: 3,795,242\n", | |
"Trainable params: 3,795,242\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"unet=UNet(in_size=572,network_depth=2,conv_depth=4)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,572,572)).shape)\n", | |
"summary(unet,input_size=(1,572,572))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(2, 2)\n", | |
"(216, 2)\n", | |
"torch.Size([1, 2, 216, 216])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 254, 254] 640\n", | |
" Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
" BatchNorm2d-3 [-1, 64, 252, 252] 128\n", | |
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
" Linear-5 [-1, 4] 260\n", | |
" ReLU-6 [-1, 4] 0\n", | |
" Linear-7 [-1, 64] 320\n", | |
" Sigmoid-8 [-1, 64] 0\n", | |
" SqueezeExcitation-9 [-1, 64, 252, 252] 0\n", | |
" ConvBlock-10 [-1, 64, 252, 252] 0\n", | |
" MaxPool2d-11 [-1, 64, 126, 126] 0\n", | |
" Conv2d-12 [-1, 128, 124, 124] 73,856\n", | |
" Conv2d-13 [-1, 128, 122, 122] 147,584\n", | |
" BatchNorm2d-14 [-1, 128, 122, 122] 256\n", | |
"AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n", | |
" Linear-16 [-1, 8] 1,032\n", | |
" ReLU-17 [-1, 8] 0\n", | |
" Linear-18 [-1, 128] 1,152\n", | |
" Sigmoid-19 [-1, 128] 0\n", | |
"SqueezeExcitation-20 [-1, 128, 122, 122] 0\n", | |
" ConvBlock-21 [-1, 128, 122, 122] 0\n", | |
" DownBlock-22 [-1, 128, 122, 122] 0\n", | |
" MaxPool2d-23 [-1, 128, 61, 61] 0\n", | |
" Conv2d-24 [-1, 256, 59, 59] 295,168\n", | |
" Conv2d-25 [-1, 256, 57, 57] 590,080\n", | |
" BatchNorm2d-26 [-1, 256, 57, 57] 512\n", | |
"AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n", | |
" Linear-28 [-1, 16] 4,112\n", | |
" ReLU-29 [-1, 16] 0\n", | |
" Linear-30 [-1, 256] 4,352\n", | |
" Sigmoid-31 [-1, 256] 0\n", | |
"SqueezeExcitation-32 [-1, 256, 57, 57] 0\n", | |
" ConvBlock-33 [-1, 256, 57, 57] 0\n", | |
" DownBlock-34 [-1, 256, 57, 57] 0\n", | |
" ConvTranspose2d-35 [-1, 128, 114, 114] 131,200\n", | |
" Conv2d-36 [-1, 128, 112, 112] 295,040\n", | |
" Conv2d-37 [-1, 128, 110, 110] 147,584\n", | |
" BatchNorm2d-38 [-1, 128, 110, 110] 256\n", | |
"AdaptiveAvgPool2d-39 [-1, 128, 1, 1] 0\n", | |
" Linear-40 [-1, 8] 1,032\n", | |
" ReLU-41 [-1, 8] 0\n", | |
" Linear-42 [-1, 128] 1,152\n", | |
" Sigmoid-43 [-1, 128] 0\n", | |
"SqueezeExcitation-44 [-1, 128, 110, 110] 0\n", | |
" ConvBlock-45 [-1, 128, 110, 110] 0\n", | |
" UpBlock-46 [-1, 128, 110, 110] 0\n", | |
" ConvTranspose2d-47 [-1, 64, 220, 220] 32,832\n", | |
" Conv2d-48 [-1, 64, 218, 218] 73,792\n", | |
" Conv2d-49 [-1, 64, 216, 216] 36,928\n", | |
" BatchNorm2d-50 [-1, 64, 216, 216] 128\n", | |
"AdaptiveAvgPool2d-51 [-1, 64, 1, 1] 0\n", | |
" Linear-52 [-1, 4] 260\n", | |
" ReLU-53 [-1, 4] 0\n", | |
" Linear-54 [-1, 64] 320\n", | |
" Sigmoid-55 [-1, 64] 0\n", | |
"SqueezeExcitation-56 [-1, 64, 216, 216] 0\n", | |
" ConvBlock-57 [-1, 64, 216, 216] 0\n", | |
" UpBlock-58 [-1, 64, 216, 216] 0\n", | |
" Conv2d-59 [-1, 2, 216, 216] 130\n", | |
"================================================================\n", | |
"Total params: 1,877,034\n", | |
"Trainable params: 1,877,034\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"SIZE=256\n", | |
"unet=UNet(in_size=SIZE,network_depth=2)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
"summary(unet,input_size=(1,SIZE,SIZE))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(5, 2)\n", | |
"(256, 2)\n", | |
"torch.Size([1, 2, 256, 256])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 256, 256] 640\n", | |
" Conv2d-2 [-1, 64, 256, 256] 36,928\n", | |
" BatchNorm2d-3 [-1, 64, 256, 256] 128\n", | |
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
" Linear-5 [-1, 4] 260\n", | |
" ReLU-6 [-1, 4] 0\n", | |
" Linear-7 [-1, 64] 320\n", | |
" Sigmoid-8 [-1, 64] 0\n", | |
" SqueezeExcitation-9 [-1, 64, 256, 256] 0\n", | |
" ConvBlock-10 [-1, 64, 256, 256] 0\n", | |
" MaxPool2d-11 [-1, 64, 128, 128] 0\n", | |
" Conv2d-12 [-1, 128, 128, 128] 73,856\n", | |
" Conv2d-13 [-1, 128, 128, 128] 147,584\n", | |
" BatchNorm2d-14 [-1, 128, 128, 128] 256\n", | |
"AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n", | |
" Linear-16 [-1, 8] 1,032\n", | |
" ReLU-17 [-1, 8] 0\n", | |
" Linear-18 [-1, 128] 1,152\n", | |
" Sigmoid-19 [-1, 128] 0\n", | |
"SqueezeExcitation-20 [-1, 128, 128, 128] 0\n", | |
" ConvBlock-21 [-1, 128, 128, 128] 0\n", | |
" DownBlock-22 [-1, 128, 128, 128] 0\n", | |
" MaxPool2d-23 [-1, 128, 64, 64] 0\n", | |
" Conv2d-24 [-1, 256, 64, 64] 295,168\n", | |
" Conv2d-25 [-1, 256, 64, 64] 590,080\n", | |
" BatchNorm2d-26 [-1, 256, 64, 64] 512\n", | |
"AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n", | |
" Linear-28 [-1, 16] 4,112\n", | |
" ReLU-29 [-1, 16] 0\n", | |
" Linear-30 [-1, 256] 4,352\n", | |
" Sigmoid-31 [-1, 256] 0\n", | |
"SqueezeExcitation-32 [-1, 256, 64, 64] 0\n", | |
" ConvBlock-33 [-1, 256, 64, 64] 0\n", | |
" DownBlock-34 [-1, 256, 64, 64] 0\n", | |
" MaxPool2d-35 [-1, 256, 32, 32] 0\n", | |
" Conv2d-36 [-1, 512, 32, 32] 1,180,160\n", | |
" Conv2d-37 [-1, 512, 32, 32] 2,359,808\n", | |
" BatchNorm2d-38 [-1, 512, 32, 32] 1,024\n", | |
"AdaptiveAvgPool2d-39 [-1, 512, 1, 1] 0\n", | |
" Linear-40 [-1, 32] 16,416\n", | |
" ReLU-41 [-1, 32] 0\n", | |
" Linear-42 [-1, 512] 16,896\n", | |
" Sigmoid-43 [-1, 512] 0\n", | |
"SqueezeExcitation-44 [-1, 512, 32, 32] 0\n", | |
" ConvBlock-45 [-1, 512, 32, 32] 0\n", | |
" DownBlock-46 [-1, 512, 32, 32] 0\n", | |
" MaxPool2d-47 [-1, 512, 16, 16] 0\n", | |
" Conv2d-48 [-1, 1024, 16, 16] 4,719,616\n", | |
" Conv2d-49 [-1, 1024, 16, 16] 9,438,208\n", | |
" BatchNorm2d-50 [-1, 1024, 16, 16] 2,048\n", | |
"AdaptiveAvgPool2d-51 [-1, 1024, 1, 1] 0\n", | |
" Linear-52 [-1, 64] 65,600\n", | |
" ReLU-53 [-1, 64] 0\n", | |
" Linear-54 [-1, 1024] 66,560\n", | |
" Sigmoid-55 [-1, 1024] 0\n", | |
"SqueezeExcitation-56 [-1, 1024, 16, 16] 0\n", | |
" ConvBlock-57 [-1, 1024, 16, 16] 0\n", | |
" DownBlock-58 [-1, 1024, 16, 16] 0\n", | |
" MaxPool2d-59 [-1, 1024, 8, 8] 0\n", | |
" Conv2d-60 [-1, 2048, 8, 8] 18,876,416\n", | |
" Conv2d-61 [-1, 2048, 8, 8] 37,750,784\n", | |
" BatchNorm2d-62 [-1, 2048, 8, 8] 4,096\n", | |
"AdaptiveAvgPool2d-63 [-1, 2048, 1, 1] 0\n", | |
" Linear-64 [-1, 128] 262,272\n", | |
" ReLU-65 [-1, 128] 0\n", | |
" Linear-66 [-1, 2048] 264,192\n", | |
" Sigmoid-67 [-1, 2048] 0\n", | |
"SqueezeExcitation-68 [-1, 2048, 8, 8] 0\n", | |
" ConvBlock-69 [-1, 2048, 8, 8] 0\n", | |
" DownBlock-70 [-1, 2048, 8, 8] 0\n", | |
" ConvTranspose2d-71 [-1, 1024, 16, 16] 8,389,632\n", | |
" Conv2d-72 [-1, 1024, 16, 16] 18,875,392\n", | |
" Conv2d-73 [-1, 1024, 16, 16] 9,438,208\n", | |
" BatchNorm2d-74 [-1, 1024, 16, 16] 2,048\n", | |
"AdaptiveAvgPool2d-75 [-1, 1024, 1, 1] 0\n", | |
" Linear-76 [-1, 64] 65,600\n", | |
" ReLU-77 [-1, 64] 0\n", | |
" Linear-78 [-1, 1024] 66,560\n", | |
" Sigmoid-79 [-1, 1024] 0\n", | |
"SqueezeExcitation-80 [-1, 1024, 16, 16] 0\n", | |
" ConvBlock-81 [-1, 1024, 16, 16] 0\n", | |
" UpBlock-82 [-1, 1024, 16, 16] 0\n", | |
" ConvTranspose2d-83 [-1, 512, 32, 32] 2,097,664\n", | |
" Conv2d-84 [-1, 512, 32, 32] 4,719,104\n", | |
" Conv2d-85 [-1, 512, 32, 32] 2,359,808\n", | |
" BatchNorm2d-86 [-1, 512, 32, 32] 1,024\n", | |
"AdaptiveAvgPool2d-87 [-1, 512, 1, 1] 0\n", | |
" Linear-88 [-1, 32] 16,416\n", | |
" ReLU-89 [-1, 32] 0\n", | |
" Linear-90 [-1, 512] 16,896\n", | |
" Sigmoid-91 [-1, 512] 0\n", | |
"SqueezeExcitation-92 [-1, 512, 32, 32] 0\n", | |
" ConvBlock-93 [-1, 512, 32, 32] 0\n", | |
" UpBlock-94 [-1, 512, 32, 32] 0\n", | |
" ConvTranspose2d-95 [-1, 256, 64, 64] 524,544\n", | |
" Conv2d-96 [-1, 256, 64, 64] 1,179,904\n", | |
" Conv2d-97 [-1, 256, 64, 64] 590,080\n", | |
" BatchNorm2d-98 [-1, 256, 64, 64] 512\n", | |
"AdaptiveAvgPool2d-99 [-1, 256, 1, 1] 0\n", | |
" Linear-100 [-1, 16] 4,112\n", | |
" ReLU-101 [-1, 16] 0\n", | |
" Linear-102 [-1, 256] 4,352\n", | |
" Sigmoid-103 [-1, 256] 0\n", | |
"SqueezeExcitation-104 [-1, 256, 64, 64] 0\n", | |
" ConvBlock-105 [-1, 256, 64, 64] 0\n", | |
" UpBlock-106 [-1, 256, 64, 64] 0\n", | |
" ConvTranspose2d-107 [-1, 128, 128, 128] 131,200\n", | |
" Conv2d-108 [-1, 128, 128, 128] 295,040\n", | |
" Conv2d-109 [-1, 128, 128, 128] 147,584\n", | |
" BatchNorm2d-110 [-1, 128, 128, 128] 256\n", | |
"AdaptiveAvgPool2d-111 [-1, 128, 1, 1] 0\n", | |
" Linear-112 [-1, 8] 1,032\n", | |
" ReLU-113 [-1, 8] 0\n", | |
" Linear-114 [-1, 128] 1,152\n", | |
" Sigmoid-115 [-1, 128] 0\n", | |
"SqueezeExcitation-116 [-1, 128, 128, 128] 0\n", | |
" ConvBlock-117 [-1, 128, 128, 128] 0\n", | |
" UpBlock-118 [-1, 128, 128, 128] 0\n", | |
" ConvTranspose2d-119 [-1, 64, 256, 256] 32,832\n", | |
" Conv2d-120 [-1, 64, 256, 256] 73,792\n", | |
" Conv2d-121 [-1, 64, 256, 256] 36,928\n", | |
" BatchNorm2d-122 [-1, 64, 256, 256] 128\n", | |
"AdaptiveAvgPool2d-123 [-1, 64, 1, 1] 0\n", | |
" Linear-124 [-1, 4] 260\n", | |
" ReLU-125 [-1, 4] 0\n", | |
" Linear-126 [-1, 64] 320\n", | |
" Sigmoid-127 [-1, 64] 0\n", | |
"SqueezeExcitation-128 [-1, 64, 256, 256] 0\n", | |
" ConvBlock-129 [-1, 64, 256, 256] 0\n", | |
" UpBlock-130 [-1, 64, 256, 256] 0\n", | |
" Conv2d-131 [-1, 2, 256, 256] 130\n", | |
"================================================================\n", | |
"Total params: 125,252,986\n", | |
"Trainable params: 125,252,986\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"SIZE=256\n", | |
"unet=UNet(in_size=SIZE,network_depth=5,padding=1)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
"summary(unet,input_size=(1,SIZE,SIZE))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(2, 2)\n", | |
"(216, 2)\n", | |
"torch.Size([1, 2, 216, 216])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 254, 254] 640\n", | |
" Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
" ConvBlock-3 [-1, 64, 252, 252] 0\n", | |
" MaxPool2d-4 [-1, 64, 126, 126] 0\n", | |
" Conv2d-5 [-1, 128, 124, 124] 73,856\n", | |
" Conv2d-6 [-1, 128, 122, 122] 147,584\n", | |
" ConvBlock-7 [-1, 128, 122, 122] 0\n", | |
" DownBlock-8 [-1, 128, 122, 122] 0\n", | |
" MaxPool2d-9 [-1, 128, 61, 61] 0\n", | |
" Conv2d-10 [-1, 256, 59, 59] 295,168\n", | |
" Conv2d-11 [-1, 256, 57, 57] 590,080\n", | |
" ConvBlock-12 [-1, 256, 57, 57] 0\n", | |
" DownBlock-13 [-1, 256, 57, 57] 0\n", | |
" ConvTranspose2d-14 [-1, 128, 114, 114] 131,200\n", | |
" Conv2d-15 [-1, 128, 112, 112] 295,040\n", | |
" Conv2d-16 [-1, 128, 110, 110] 147,584\n", | |
" ConvBlock-17 [-1, 128, 110, 110] 0\n", | |
" UpBlock-18 [-1, 128, 110, 110] 0\n", | |
" ConvTranspose2d-19 [-1, 64, 220, 220] 32,832\n", | |
" Conv2d-20 [-1, 64, 218, 218] 73,792\n", | |
" Conv2d-21 [-1, 64, 216, 216] 36,928\n", | |
" ConvBlock-22 [-1, 64, 216, 216] 0\n", | |
" UpBlock-23 [-1, 64, 216, 216] 0\n", | |
" Conv2d-24 [-1, 2, 216, 216] 130\n", | |
"================================================================\n", | |
"Total params: 1,861,762\n", | |
"Trainable params: 1,861,762\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"SIZE=256\n", | |
"unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=False)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
"summary(unet,input_size=(1,SIZE,SIZE))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(2, 2)\n", | |
"(216, 2)\n", | |
"torch.Size([1, 2, 216, 216])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 254, 254] 640\n", | |
" Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
" AdaptiveAvgPool2d-3 [-1, 64, 1, 1] 0\n", | |
" Linear-4 [-1, 4] 260\n", | |
" ReLU-5 [-1, 4] 0\n", | |
" Linear-6 [-1, 64] 320\n", | |
" Sigmoid-7 [-1, 64] 0\n", | |
" SqueezeExcitation-8 [-1, 64, 252, 252] 0\n", | |
" ConvBlock-9 [-1, 64, 252, 252] 0\n", | |
" MaxPool2d-10 [-1, 64, 126, 126] 0\n", | |
" Conv2d-11 [-1, 128, 124, 124] 73,856\n", | |
" Conv2d-12 [-1, 128, 122, 122] 147,584\n", | |
"AdaptiveAvgPool2d-13 [-1, 128, 1, 1] 0\n", | |
" Linear-14 [-1, 8] 1,032\n", | |
" ReLU-15 [-1, 8] 0\n", | |
" Linear-16 [-1, 128] 1,152\n", | |
" Sigmoid-17 [-1, 128] 0\n", | |
"SqueezeExcitation-18 [-1, 128, 122, 122] 0\n", | |
" ConvBlock-19 [-1, 128, 122, 122] 0\n", | |
" DownBlock-20 [-1, 128, 122, 122] 0\n", | |
" MaxPool2d-21 [-1, 128, 61, 61] 0\n", | |
" Conv2d-22 [-1, 256, 59, 59] 295,168\n", | |
" Conv2d-23 [-1, 256, 57, 57] 590,080\n", | |
"AdaptiveAvgPool2d-24 [-1, 256, 1, 1] 0\n", | |
" Linear-25 [-1, 16] 4,112\n", | |
" ReLU-26 [-1, 16] 0\n", | |
" Linear-27 [-1, 256] 4,352\n", | |
" Sigmoid-28 [-1, 256] 0\n", | |
"SqueezeExcitation-29 [-1, 256, 57, 57] 0\n", | |
" ConvBlock-30 [-1, 256, 57, 57] 0\n", | |
" DownBlock-31 [-1, 256, 57, 57] 0\n", | |
" ConvTranspose2d-32 [-1, 128, 114, 114] 131,200\n", | |
" Conv2d-33 [-1, 128, 112, 112] 295,040\n", | |
" Conv2d-34 [-1, 128, 110, 110] 147,584\n", | |
"AdaptiveAvgPool2d-35 [-1, 128, 1, 1] 0\n", | |
" Linear-36 [-1, 8] 1,032\n", | |
" ReLU-37 [-1, 8] 0\n", | |
" Linear-38 [-1, 128] 1,152\n", | |
" Sigmoid-39 [-1, 128] 0\n", | |
"SqueezeExcitation-40 [-1, 128, 110, 110] 0\n", | |
" ConvBlock-41 [-1, 128, 110, 110] 0\n", | |
" UpBlock-42 [-1, 128, 110, 110] 0\n", | |
" ConvTranspose2d-43 [-1, 64, 220, 220] 32,832\n", | |
" Conv2d-44 [-1, 64, 218, 218] 73,792\n", | |
" Conv2d-45 [-1, 64, 216, 216] 36,928\n", | |
"AdaptiveAvgPool2d-46 [-1, 64, 1, 1] 0\n", | |
" Linear-47 [-1, 4] 260\n", | |
" ReLU-48 [-1, 4] 0\n", | |
" Linear-49 [-1, 64] 320\n", | |
" Sigmoid-50 [-1, 64] 0\n", | |
"SqueezeExcitation-51 [-1, 64, 216, 216] 0\n", | |
" ConvBlock-52 [-1, 64, 216, 216] 0\n", | |
" UpBlock-53 [-1, 64, 216, 216] 0\n", | |
" Conv2d-54 [-1, 2, 216, 216] 130\n", | |
"================================================================\n", | |
"Total params: 1,875,754\n", | |
"Trainable params: 1,875,754\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"SIZE=256\n", | |
"unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=True)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
"summary(unet,input_size=(1,SIZE,SIZE))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(2, 2)\n", | |
"(216, 2)\n", | |
"torch.Size([1, 2, 216, 216])\n", | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" Conv2d-1 [-1, 64, 254, 254] 640\n", | |
" Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
" BatchNorm2d-3 [-1, 64, 252, 252] 128\n", | |
" ConvBlock-4 [-1, 64, 252, 252] 0\n", | |
" MaxPool2d-5 [-1, 64, 126, 126] 0\n", | |
" Conv2d-6 [-1, 128, 124, 124] 73,856\n", | |
" Conv2d-7 [-1, 128, 122, 122] 147,584\n", | |
" BatchNorm2d-8 [-1, 128, 122, 122] 256\n", | |
" ConvBlock-9 [-1, 128, 122, 122] 0\n", | |
" DownBlock-10 [-1, 128, 122, 122] 0\n", | |
" MaxPool2d-11 [-1, 128, 61, 61] 0\n", | |
" Conv2d-12 [-1, 256, 59, 59] 295,168\n", | |
" Conv2d-13 [-1, 256, 57, 57] 590,080\n", | |
" BatchNorm2d-14 [-1, 256, 57, 57] 512\n", | |
" ConvBlock-15 [-1, 256, 57, 57] 0\n", | |
" DownBlock-16 [-1, 256, 57, 57] 0\n", | |
" ConvTranspose2d-17 [-1, 128, 114, 114] 131,200\n", | |
" Conv2d-18 [-1, 128, 112, 112] 295,040\n", | |
" Conv2d-19 [-1, 128, 110, 110] 147,584\n", | |
" BatchNorm2d-20 [-1, 128, 110, 110] 256\n", | |
" ConvBlock-21 [-1, 128, 110, 110] 0\n", | |
" UpBlock-22 [-1, 128, 110, 110] 0\n", | |
" ConvTranspose2d-23 [-1, 64, 220, 220] 32,832\n", | |
" Conv2d-24 [-1, 64, 218, 218] 73,792\n", | |
" Conv2d-25 [-1, 64, 216, 216] 36,928\n", | |
" BatchNorm2d-26 [-1, 64, 216, 216] 128\n", | |
" ConvBlock-27 [-1, 64, 216, 216] 0\n", | |
" UpBlock-28 [-1, 64, 216, 216] 0\n", | |
" Conv2d-29 [-1, 2, 216, 216] 130\n", | |
"================================================================\n", | |
"Total params: 1,863,042\n", | |
"Trainable params: 1,863,042\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"SIZE=256\n", | |
"unet=UNet(in_size=SIZE,network_depth=2,act='elu',se=False)\n", | |
"print(unet.network_depth,unet.conv_depth)\n", | |
"print(unet.out_size,unet.out_ch)\n", | |
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
"summary(unet,input_size=(1,SIZE,SIZE))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Author
brookisme
commented
Jul 20, 2018
•
- build unet with paper parameters
- optional depth of unet
- optional depth of conv_blocks
- valid/same padding
- other conv-block kwargs: se, bn, act
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment