Skip to content

Instantly share code, notes, and snippets.

@brookisme
Last active October 24, 2022 16:01
Show Gist options
  • Save brookisme/52079e106255f75c996d8595cd3988b0 to your computer and use it in GitHub Desktop.
Save brookisme/52079e106255f75c996d8595cd3988b0 to your computer and use it in GitHub Desktop.
U-Net with Squeeze-and-Excitation Blocks
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchsummary import summary"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 64, 568, 568])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" AdaptiveAvgPool2d-1 [-1, 64, 1, 1] 0\n",
" Linear-2 [-1, 4] 260\n",
" ReLU-3 [-1, 4] 0\n",
" Linear-4 [-1, 64] 320\n",
" Sigmoid-5 [-1, 64] 0\n",
"================================================================\n",
"Total params: 580\n",
"Trainable params: 580\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"class SqueezeExcitation(nn.Module):\n",
"    \"\"\"Squeeze-and-Excitation gate that rescales each input channel.\n",
"\n",
"    Global-average-pools every channel, feeds the pooled vector through a\n",
"    bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid) and multiplies the\n",
"    input by the resulting per-channel weights in (0, 1).\n",
"\n",
"    Args:\n",
"        nb_channels: number of input (and output) channels.\n",
"        reduction: bottleneck reduction ratio; the hidden width is\n",
"            max(1, nb_channels // reduction).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nb_channels, reduction=16):\n",
"        super(SqueezeExcitation, self).__init__()\n",
"        self.nb_channels = nb_channels\n",
"        # Never let the bottleneck collapse to width 0 (nb_channels < reduction):\n",
"        # a 0-wide layer would make the gate ignore its input entirely.\n",
"        hidden = max(1, nb_channels // reduction)\n",
"        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n",
"        self.fc = nn.Sequential(\n",
"            nn.Linear(nb_channels, hidden),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Linear(hidden, nb_channels),\n",
"            nn.Sigmoid())\n",
"\n",
"    def forward(self, x):\n",
"        # Pool to one value per channel, compute the gates, and reshape them\n",
"        # to (-1, C, 1, 1) so they broadcast over the spatial dimensions.\n",
"        y = self.avg_pool(x).view(-1, self.nb_channels)\n",
"        y = self.fc(y).view(-1, self.nb_channels, 1, 1)\n",
"        return x * y\n",
"\n",
"\n",
"print(SqueezeExcitation(64)(torch.rand(64,568,568)).shape)\n",
"summary(SqueezeExcitation(64),input_size=(64,568,568))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(568, 64)\n",
"torch.Size([1, 64, 568, 568])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 570, 570] 640\n",
" Conv2d-2 [-1, 64, 568, 568] 36,928\n",
" BatchNorm2d-3 [-1, 64, 568, 568] 128\n",
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n",
" Linear-5 [-1, 4] 260\n",
" ReLU-6 [-1, 4] 0\n",
" Linear-7 [-1, 64] 320\n",
" Sigmoid-8 [-1, 64] 0\n",
" SqueezeExcitation-9 [-1, 64, 568, 568] 0\n",
"================================================================\n",
"Total params: 38,276\n",
"Trainable params: 38,276\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"class ConvBlock(nn.Module):\n",
"    \"\"\"Stack of `depth` convolutions followed by optional BatchNorm, activation and SE.\n",
"\n",
"    Every convolution except the last is followed by the activation; after the\n",
"    last convolution come BatchNorm (if `bn`), the activation (if `act`) and a\n",
"    SqueezeExcitation gate (if `se`).\n",
"\n",
"    Args:\n",
"        in_ch: input channels.\n",
"        in_size: input spatial size (square); only used to compute `out_size`.\n",
"        depth: number of convolutions.\n",
"        kernel_size, stride, padding: passed to every nn.Conv2d.\n",
"        out_ch: output channels; defaults to 2 * in_ch.\n",
"        bn: add an nn.BatchNorm2d after the last convolution.\n",
"        se: add a SqueezeExcitation gate at the end.\n",
"        act: name of a torch.nn.functional function (e.g. 'relu'), or a\n",
"            falsy value for no activation.\n",
"        act_kwargs: extra keyword arguments for the activation (default none).\n",
"\n",
"    Attributes:\n",
"        out_ch: output channels.\n",
"        out_size: output spatial size as an int (formula assumes stride 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 in_ch,\n",
"                 in_size,\n",
"                 depth=2,\n",
"                 kernel_size=3,\n",
"                 stride=1,\n",
"                 padding=0,\n",
"                 out_ch=None,\n",
"                 bn=True,\n",
"                 se=True,\n",
"                 act='relu',\n",
"                 act_kwargs=None):\n",
"        super(ConvBlock, self).__init__()\n",
"        self.out_ch = out_ch or 2 * in_ch\n",
"        self._set_post_processes(self.out_ch, bn, se, act, act_kwargs)\n",
"        self._set_conv_layers(depth, in_ch, kernel_size, stride, padding)\n",
"        # Each stride-1 conv changes the size by kernel_size - 1 - 2*padding.\n",
"        # Pure integer arithmetic keeps out_size an int under Python 3\n",
"        # (a `/ 2` here would yield a float, unusable for slicing/cropping).\n",
"        self.out_size = in_size - depth * (kernel_size - 1 - 2 * padding)\n",
"\n",
"    def forward(self, x):\n",
"        last = len(self.conv_layers) - 1\n",
"        for index, conv in enumerate(self.conv_layers):\n",
"            x = conv(x)\n",
"            # Stacked convs need a non-linearity between them; without it\n",
"            # they compose into a single linear convolution.\n",
"            if index < last and self.act:\n",
"                x = self._activation(x)\n",
"        if self.bn:\n",
"            x = self.bn(x)\n",
"        if self.act:\n",
"            x = self._activation(x)\n",
"        if self.se:\n",
"            x = self.se(x)\n",
"        return x\n",
"\n",
"    def _set_post_processes(self, out_channels, bn, se, act, act_kwargs):\n",
"        # Disabled stages are stored as False so forward() can test them.\n",
"        self.bn = nn.BatchNorm2d(out_channels) if bn else False\n",
"        self.se = SqueezeExcitation(out_channels) if se else False\n",
"        self.act = act\n",
"        # Copy so the caller's dict is never aliased or shared between blocks.\n",
"        self.act_kwargs = dict(act_kwargs or {})\n",
"\n",
"    def _set_conv_layers(self, depth, in_ch, kernel_size, stride, padding):\n",
"        # The first conv maps in_ch -> out_ch; the rest keep out_ch channels.\n",
"        layers = [\n",
"            nn.Conv2d(\n",
"                in_channels=in_ch if index == 0 else self.out_ch,\n",
"                out_channels=self.out_ch,\n",
"                kernel_size=kernel_size,\n",
"                stride=stride,\n",
"                padding=padding)\n",
"            for index in range(depth)]\n",
"        self.conv_layers = nn.Sequential(*layers)\n",
"\n",
"    def _activation(self, x):\n",
"        # Look the function up by name, then pass act_kwargs to the call;\n",
"        # getattr() itself accepts no keyword arguments.\n",
"        return getattr(F, self.act)(x, **self.act_kwargs)\n",
"\n",
"\n",
"conv_block=ConvBlock(1,572,out_ch=64)\n",
"print(conv_block.out_size,conv_block.out_ch)\n",
"print(conv_block(torch.rand(1,1,572,572)).shape)\n",
"summary(ConvBlock(1,572,out_ch=64),input_size=(1,572,572))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(276, 128)\n",
"torch.Size([1, 128, 276, 276])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" MaxPool2d-1 [-1, 64, 284, 284] 0\n",
" Conv2d-2 [-1, 128, 282, 282] 73,856\n",
" Conv2d-3 [-1, 128, 280, 280] 147,584\n",
" Conv2d-4 [-1, 128, 278, 278] 147,584\n",
" Conv2d-5 [-1, 128, 276, 276] 147,584\n",
" BatchNorm2d-6 [-1, 128, 276, 276] 256\n",
" AdaptiveAvgPool2d-7 [-1, 128, 1, 1] 0\n",
" Linear-8 [-1, 8] 1,032\n",
" ReLU-9 [-1, 8] 0\n",
" Linear-10 [-1, 128] 1,152\n",
" Sigmoid-11 [-1, 128] 0\n",
"SqueezeExcitation-12 [-1, 128, 276, 276] 0\n",
" ConvBlock-13 [-1, 128, 276, 276] 0\n",
"================================================================\n",
"Total params: 519,048\n",
"Trainable params: 519,048\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"class DownBlock(nn.Module):\n",
"    \"\"\"U-Net contracting step: 2x2 max-pool, then a ConvBlock of 3x3 convs.\n",
"\n",
"    Args:\n",
"        in_ch: input channels.\n",
"        in_size: input spatial size (square), before pooling.\n",
"        out_ch: output channels; defaults to 2 * in_ch.\n",
"        depth, padding, bn, se, act, act_kwargs: forwarded to ConvBlock.\n",
"\n",
"    Attributes:\n",
"        out_ch: output channels.\n",
"        out_size: output spatial size as an int.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 in_ch,\n",
"                 in_size,\n",
"                 out_ch=None,\n",
"                 depth=2,\n",
"                 padding=0,\n",
"                 bn=True,\n",
"                 se=True,\n",
"                 act='relu',\n",
"                 act_kwargs=None):\n",
"        super(DownBlock, self).__init__()\n",
"        # int(): upstream sizes may arrive as floats under Python 3.\n",
"        pooled_size = int(in_size) // 2\n",
"        # Each 3x3 conv trims 2 * (1 - padding) pixels.\n",
"        self.out_size = pooled_size - depth * (1 - padding) * 2\n",
"        self.out_ch = out_ch or in_ch * 2\n",
"        self.down = nn.MaxPool2d(kernel_size=2)\n",
"        self.conv_block = ConvBlock(\n",
"            in_ch=in_ch,\n",
"            out_ch=self.out_ch,\n",
"            in_size=pooled_size,\n",
"            depth=depth,\n",
"            padding=padding,\n",
"            bn=bn,\n",
"            se=se,\n",
"            act=act,\n",
"            act_kwargs=dict(act_kwargs or {}))\n",
"\n",
"    def forward(self, x):\n",
"        x = self.down(x)\n",
"        return self.conv_block(x)\n",
"\n",
"\n",
"db_out=DownBlock(64,568,depth=4)\n",
"print(db_out.out_size,db_out.out_ch)\n",
"print(db_out(torch.rand(1,64,568,568)).shape)\n",
"summary(db_out,input_size=(64,568,568))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(196, 128)\n",
"torch.Size([1, 128, 196, 196])\n"
]
}
],
"source": [
"class UpBlock(nn.Module):\n",
"    \"\"\"U-Net expanding step: 2x upsample, concatenate the skip, then a ConvBlock.\n",
"\n",
"    The upsampling produces in_ch // 2 channels; the skip tensor is assumed to\n",
"    carry the other in_ch // 2 channels so the concatenation matches the\n",
"    ConvBlock's in_ch input channels.\n",
"\n",
"    Args:\n",
"        in_ch: channels of the incoming lower-resolution feature map.\n",
"        in_size: spatial size (square) of that feature map.\n",
"        out_ch: output channels; defaults to in_ch // 2.\n",
"        bilinear: upsample by bilinear interpolation plus a 1x1 conv instead\n",
"            of a transposed convolution.\n",
"        crop: pixels trimmed from each border of the skip when padding == 0;\n",
"            inferred from the tensor sizes on the first call if None.\n",
"        depth, padding, bn, se, act, act_kwargs: forwarded to ConvBlock.\n",
"\n",
"    Attributes:\n",
"        out_ch: output channels.\n",
"        out_size: output spatial size (assumes 3x3 convs).\n",
"    \"\"\"\n",
"\n",
"    @staticmethod\n",
"    def cropping(skip_size, size):\n",
"        \"\"\"Return the per-side crop that centres `size` pixels inside `skip_size`.\"\"\"\n",
"        return (skip_size - size) // 2\n",
"\n",
"    def __init__(self,\n",
"                 in_ch,\n",
"                 in_size,\n",
"                 out_ch=None,\n",
"                 bilinear=False,\n",
"                 crop=None,\n",
"                 depth=2,\n",
"                 padding=0,\n",
"                 bn=True,\n",
"                 se=True,\n",
"                 act='relu',\n",
"                 act_kwargs=None):\n",
"        super(UpBlock, self).__init__()\n",
"        self.crop = crop\n",
"        self.padding = padding\n",
"        # int(): upstream sizes may arrive as floats under Python 3.\n",
"        upsampled_size = 2 * int(in_size)\n",
"        # Each 3x3 conv trims 2 * (1 - padding) pixels.\n",
"        self.out_size = upsampled_size - depth * (1 - padding) * 2\n",
"        self.out_ch = out_ch or in_ch // 2\n",
"        if bilinear:\n",
"            # Interpolation alone keeps in_ch channels, which would not match\n",
"            # the ConvBlock after concatenation; a 1x1 conv halves them.\n",
"            self.up = nn.Sequential(\n",
"                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),\n",
"                nn.Conv2d(in_ch, in_ch // 2, kernel_size=1))\n",
"        else:\n",
"            self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)\n",
"        self.conv_block = ConvBlock(\n",
"            in_ch,\n",
"            upsampled_size,\n",
"            out_ch=self.out_ch,\n",
"            depth=depth,\n",
"            padding=padding,\n",
"            bn=bn,\n",
"            se=se,\n",
"            act=act,\n",
"            act_kwargs=dict(act_kwargs or {}))\n",
"\n",
"    def forward(self, x, skip):\n",
"        x = self.up(x)\n",
"        skip = self._crop(skip, x)\n",
"        x = torch.cat([skip, x], dim=1)\n",
"        return self.conv_block(x)\n",
"\n",
"    def _crop(self, skip, x):\n",
"        # Only unpadded convs shrink the maps, so only then is cropping needed.\n",
"        if self.padding == 0:\n",
"            if self.crop is None:\n",
"                self.crop = self.cropping(skip.size(-1), x.size(-1))\n",
"            c = int(self.crop)\n",
"            if c > 0:\n",
"                # Slice to x's exact size: a `c:-c` slice is empty when c == 0\n",
"                # and off by one when the size difference is odd.\n",
"                skip = skip[:, :, c:c + x.size(-2), c:c + x.size(-1)]\n",
"        return skip\n",
"\n",
"\n",
"db_out=UpBlock(256,100)\n",
"print(db_out.out_size,db_out.out_ch)\n",
"print(db_out(torch.rand(1,256,100,100),torch.rand(1,128,280,280)).shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(4, 2)\n",
"(388, 2)\n",
"torch.Size([1, 2, 388, 388])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 570, 570] 640\n",
" Conv2d-2 [-1, 64, 568, 568] 36,928\n",
" BatchNorm2d-3 [-1, 64, 568, 568] 128\n",
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n",
" Linear-5 [-1, 4] 260\n",
" ReLU-6 [-1, 4] 0\n",
" Linear-7 [-1, 64] 320\n",
" Sigmoid-8 [-1, 64] 0\n",
" SqueezeExcitation-9 [-1, 64, 568, 568] 0\n",
" ConvBlock-10 [-1, 64, 568, 568] 0\n",
" MaxPool2d-11 [-1, 64, 284, 284] 0\n",
" Conv2d-12 [-1, 128, 282, 282] 73,856\n",
" Conv2d-13 [-1, 128, 280, 280] 147,584\n",
" BatchNorm2d-14 [-1, 128, 280, 280] 256\n",
"AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n",
" Linear-16 [-1, 8] 1,032\n",
" ReLU-17 [-1, 8] 0\n",
" Linear-18 [-1, 128] 1,152\n",
" Sigmoid-19 [-1, 128] 0\n",
"SqueezeExcitation-20 [-1, 128, 280, 280] 0\n",
" ConvBlock-21 [-1, 128, 280, 280] 0\n",
" DownBlock-22 [-1, 128, 280, 280] 0\n",
" MaxPool2d-23 [-1, 128, 140, 140] 0\n",
" Conv2d-24 [-1, 256, 138, 138] 295,168\n",
" Conv2d-25 [-1, 256, 136, 136] 590,080\n",
" BatchNorm2d-26 [-1, 256, 136, 136] 512\n",
"AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n",
" Linear-28 [-1, 16] 4,112\n",
" ReLU-29 [-1, 16] 0\n",
" Linear-30 [-1, 256] 4,352\n",
" Sigmoid-31 [-1, 256] 0\n",
"SqueezeExcitation-32 [-1, 256, 136, 136] 0\n",
" ConvBlock-33 [-1, 256, 136, 136] 0\n",
" DownBlock-34 [-1, 256, 136, 136] 0\n",
" MaxPool2d-35 [-1, 256, 68, 68] 0\n",
" Conv2d-36 [-1, 512, 66, 66] 1,180,160\n",
" Conv2d-37 [-1, 512, 64, 64] 2,359,808\n",
" BatchNorm2d-38 [-1, 512, 64, 64] 1,024\n",
"AdaptiveAvgPool2d-39 [-1, 512, 1, 1] 0\n",
" Linear-40 [-1, 32] 16,416\n",
" ReLU-41 [-1, 32] 0\n",
" Linear-42 [-1, 512] 16,896\n",
" Sigmoid-43 [-1, 512] 0\n",
"SqueezeExcitation-44 [-1, 512, 64, 64] 0\n",
" ConvBlock-45 [-1, 512, 64, 64] 0\n",
" DownBlock-46 [-1, 512, 64, 64] 0\n",
" MaxPool2d-47 [-1, 512, 32, 32] 0\n",
" Conv2d-48 [-1, 1024, 30, 30] 4,719,616\n",
" Conv2d-49 [-1, 1024, 28, 28] 9,438,208\n",
" BatchNorm2d-50 [-1, 1024, 28, 28] 2,048\n",
"AdaptiveAvgPool2d-51 [-1, 1024, 1, 1] 0\n",
" Linear-52 [-1, 64] 65,600\n",
" ReLU-53 [-1, 64] 0\n",
" Linear-54 [-1, 1024] 66,560\n",
" Sigmoid-55 [-1, 1024] 0\n",
"SqueezeExcitation-56 [-1, 1024, 28, 28] 0\n",
" ConvBlock-57 [-1, 1024, 28, 28] 0\n",
" DownBlock-58 [-1, 1024, 28, 28] 0\n",
" ConvTranspose2d-59 [-1, 512, 56, 56] 2,097,664\n",
" Conv2d-60 [-1, 512, 54, 54] 4,719,104\n",
" Conv2d-61 [-1, 512, 52, 52] 2,359,808\n",
" BatchNorm2d-62 [-1, 512, 52, 52] 1,024\n",
"AdaptiveAvgPool2d-63 [-1, 512, 1, 1] 0\n",
" Linear-64 [-1, 32] 16,416\n",
" ReLU-65 [-1, 32] 0\n",
" Linear-66 [-1, 512] 16,896\n",
" Sigmoid-67 [-1, 512] 0\n",
"SqueezeExcitation-68 [-1, 512, 52, 52] 0\n",
" ConvBlock-69 [-1, 512, 52, 52] 0\n",
" UpBlock-70 [-1, 512, 52, 52] 0\n",
" ConvTranspose2d-71 [-1, 256, 104, 104] 524,544\n",
" Conv2d-72 [-1, 256, 102, 102] 1,179,904\n",
" Conv2d-73 [-1, 256, 100, 100] 590,080\n",
" BatchNorm2d-74 [-1, 256, 100, 100] 512\n",
"AdaptiveAvgPool2d-75 [-1, 256, 1, 1] 0\n",
" Linear-76 [-1, 16] 4,112\n",
" ReLU-77 [-1, 16] 0\n",
" Linear-78 [-1, 256] 4,352\n",
" Sigmoid-79 [-1, 256] 0\n",
"SqueezeExcitation-80 [-1, 256, 100, 100] 0\n",
" ConvBlock-81 [-1, 256, 100, 100] 0\n",
" UpBlock-82 [-1, 256, 100, 100] 0\n",
" ConvTranspose2d-83 [-1, 128, 200, 200] 131,200\n",
" Conv2d-84 [-1, 128, 198, 198] 295,040\n",
" Conv2d-85 [-1, 128, 196, 196] 147,584\n",
" BatchNorm2d-86 [-1, 128, 196, 196] 256\n",
"AdaptiveAvgPool2d-87 [-1, 128, 1, 1] 0\n",
" Linear-88 [-1, 8] 1,032\n",
" ReLU-89 [-1, 8] 0\n",
" Linear-90 [-1, 128] 1,152\n",
" Sigmoid-91 [-1, 128] 0\n",
"SqueezeExcitation-92 [-1, 128, 196, 196] 0\n",
" ConvBlock-93 [-1, 128, 196, 196] 0\n",
" UpBlock-94 [-1, 128, 196, 196] 0\n",
" ConvTranspose2d-95 [-1, 64, 392, 392] 32,832\n",
" Conv2d-96 [-1, 64, 390, 390] 73,792\n",
" Conv2d-97 [-1, 64, 388, 388] 36,928\n",
" BatchNorm2d-98 [-1, 64, 388, 388] 128\n",
"AdaptiveAvgPool2d-99 [-1, 64, 1, 1] 0\n",
" Linear-100 [-1, 4] 260\n",
" ReLU-101 [-1, 4] 0\n",
" Linear-102 [-1, 64] 320\n",
" Sigmoid-103 [-1, 64] 0\n",
"SqueezeExcitation-104 [-1, 64, 388, 388] 0\n",
" ConvBlock-105 [-1, 64, 388, 388] 0\n",
" UpBlock-106 [-1, 64, 388, 388] 0\n",
" Conv2d-107 [-1, 2, 388, 388] 130\n",
"================================================================\n",
"Total params: 31,257,786\n",
"Trainable params: 31,257,786\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"class UNet(nn.Module):\n",
"    \"\"\"U-Net built from ConvBlock, DownBlock and UpBlock modules.\n",
"\n",
"    Args:\n",
"        network_depth: number of down-sampling (and up-sampling) steps.\n",
"        conv_depth: convolutions per ConvBlock.\n",
"        in_size: input spatial size (square); used to precompute skip crops\n",
"            and `out_size`.\n",
"        in_ch: input channels.\n",
"        out_ch: output channels of the final 1x1 convolution.\n",
"        init_ch: channels of the input ConvBlock; doubled at every down step.\n",
"        padding: convolution padding (0 = unpadded convs with skip cropping).\n",
"        bn, se, act, act_kwargs: forwarded to every block.\n",
"\n",
"    Attributes:\n",
"        out_size: output spatial size as an int.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 network_depth=4,\n",
"                 conv_depth=2,\n",
"                 in_size=572,\n",
"                 in_ch=1,\n",
"                 out_ch=2,\n",
"                 init_ch=64,\n",
"                 padding=0,\n",
"                 bn=True,\n",
"                 se=True,\n",
"                 act='relu',\n",
"                 act_kwargs=None):\n",
"        super(UNet, self).__init__()\n",
"        act_kwargs = dict(act_kwargs or {})\n",
"        self.network_depth = network_depth\n",
"        self.conv_depth = conv_depth\n",
"        self.out_ch = out_ch\n",
"        self.padding = padding\n",
"        self.input_conv = ConvBlock(\n",
"            in_ch=in_ch,\n",
"            in_size=in_size,\n",
"            out_ch=init_ch,\n",
"            depth=self.conv_depth,\n",
"            padding=padding,\n",
"            bn=bn,\n",
"            se=se,\n",
"            act=act,\n",
"            act_kwargs=act_kwargs)\n",
"        down_layers = self._down_layers(\n",
"            self.input_conv.out_ch,\n",
"            # int(): a float size would break the crop slicing in UpBlock.\n",
"            int(self.input_conv.out_size),\n",
"            bn=bn,\n",
"            se=se,\n",
"            act=act,\n",
"            act_kwargs=act_kwargs)\n",
"        self.down_blocks = nn.ModuleList(down_layers)\n",
"        up_layers = self._up_layers(\n",
"            down_layers,\n",
"            bn=bn,\n",
"            se=se,\n",
"            act=act,\n",
"            act_kwargs=act_kwargs)\n",
"        self.up_blocks = nn.ModuleList(up_layers)\n",
"        # With network_depth == 0 there are no up blocks; the head then sits\n",
"        # directly on the input ConvBlock.\n",
"        last = up_layers[-1] if up_layers else self.input_conv\n",
"        self.out_size = int(last.out_size)\n",
"        self.output_conv = self._output_layer(out_ch, in_ch=last.out_ch)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.input_conv(x)\n",
"        # Keep every encoder output except the bottleneck as skip tensors.\n",
"        skips = [x]\n",
"        for block in self.down_blocks:\n",
"            x = block(x)\n",
"            skips.append(x)\n",
"        skips.pop()\n",
"        for skip, block in zip(reversed(skips), self.up_blocks):\n",
"            x = block(x, skip)\n",
"        return self.output_conv(x)\n",
"\n",
"    def _down_layers(self, in_ch, in_size, bn, se, act, act_kwargs):\n",
"        layers = []\n",
"        for _ in range(self.network_depth):\n",
"            layer = DownBlock(\n",
"                in_ch,\n",
"                in_size,\n",
"                depth=self.conv_depth,\n",
"                padding=self.padding,\n",
"                bn=bn,\n",
"                se=se,\n",
"                act=act,\n",
"                act_kwargs=act_kwargs)\n",
"            # Each block feeds the next one.\n",
"            in_ch = layer.out_ch\n",
"            in_size = int(layer.out_size)\n",
"            layers.append(layer)\n",
"        return layers\n",
"\n",
"    def _up_layers(self, down_layers, bn, se, act, act_kwargs):\n",
"        # Mirror the encoder: start from the bottleneck block and walk back\n",
"        # up to input_conv, whose outputs are the skip tensors.\n",
"        down_layers = down_layers[::-1] + [self.input_conv]\n",
"        first = down_layers.pop(0)\n",
"        in_ch = first.out_ch\n",
"        in_size = int(first.out_size)\n",
"        layers = []\n",
"        for down_layer in down_layers:\n",
"            # Per-side crop that fits the skip to the 2x-upsampled map.\n",
"            crop = UpBlock.cropping(int(down_layer.out_size), 2 * in_size)\n",
"            layer = UpBlock(\n",
"                in_ch,\n",
"                in_size,\n",
"                depth=self.conv_depth,\n",
"                crop=crop,\n",
"                padding=self.padding,\n",
"                bn=bn,\n",
"                se=se,\n",
"                act=act,\n",
"                act_kwargs=act_kwargs)\n",
"            in_ch = layer.out_ch\n",
"            in_size = int(layer.out_size)\n",
"            layers.append(layer)\n",
"        return layers\n",
"\n",
"    def _output_layer(self, out_ch, in_ch=64):\n",
"        # 1x1 conv head. in_ch must equal the last decoder block's channels;\n",
"        # the default 64 only matches init_ch=64.\n",
"        return nn.Conv2d(\n",
"            in_channels=in_ch,\n",
"            out_channels=out_ch,\n",
"            kernel_size=1,\n",
"            stride=1,\n",
"            padding=0)\n",
"\n",
"\n",
"unet=UNet(in_size=572,network_depth=4,conv_depth=2)\n",
"print(unet.network_depth,unet.conv_depth)\n",
"print(unet.out_size,unet.out_ch)\n",
"print(unet(torch.rand(1,1,572,572)).shape)\n",
"summary(unet,input_size=(1,572,572))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 4)\n",
"(492, 2)\n",
"torch.Size([1, 2, 492, 492])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 570, 570] 640\n",
" Conv2d-2 [-1, 64, 568, 568] 36,928\n",
" Conv2d-3 [-1, 64, 566, 566] 36,928\n",
" Conv2d-4 [-1, 64, 564, 564] 36,928\n",
" BatchNorm2d-5 [-1, 64, 564, 564] 128\n",
" AdaptiveAvgPool2d-6 [-1, 64, 1, 1] 0\n",
" Linear-7 [-1, 4] 260\n",
" ReLU-8 [-1, 4] 0\n",
" Linear-9 [-1, 64] 320\n",
" Sigmoid-10 [-1, 64] 0\n",
"SqueezeExcitation-11 [-1, 64, 564, 564] 0\n",
" ConvBlock-12 [-1, 64, 564, 564] 0\n",
" MaxPool2d-13 [-1, 64, 282, 282] 0\n",
" Conv2d-14 [-1, 128, 280, 280] 73,856\n",
" Conv2d-15 [-1, 128, 278, 278] 147,584\n",
" Conv2d-16 [-1, 128, 276, 276] 147,584\n",
" Conv2d-17 [-1, 128, 274, 274] 147,584\n",
" BatchNorm2d-18 [-1, 128, 274, 274] 256\n",
"AdaptiveAvgPool2d-19 [-1, 128, 1, 1] 0\n",
" Linear-20 [-1, 8] 1,032\n",
" ReLU-21 [-1, 8] 0\n",
" Linear-22 [-1, 128] 1,152\n",
" Sigmoid-23 [-1, 128] 0\n",
"SqueezeExcitation-24 [-1, 128, 274, 274] 0\n",
" ConvBlock-25 [-1, 128, 274, 274] 0\n",
" DownBlock-26 [-1, 128, 274, 274] 0\n",
" MaxPool2d-27 [-1, 128, 137, 137] 0\n",
" Conv2d-28 [-1, 256, 135, 135] 295,168\n",
" Conv2d-29 [-1, 256, 133, 133] 590,080\n",
" Conv2d-30 [-1, 256, 131, 131] 590,080\n",
" Conv2d-31 [-1, 256, 129, 129] 590,080\n",
" BatchNorm2d-32 [-1, 256, 129, 129] 512\n",
"AdaptiveAvgPool2d-33 [-1, 256, 1, 1] 0\n",
" Linear-34 [-1, 16] 4,112\n",
" ReLU-35 [-1, 16] 0\n",
" Linear-36 [-1, 256] 4,352\n",
" Sigmoid-37 [-1, 256] 0\n",
"SqueezeExcitation-38 [-1, 256, 129, 129] 0\n",
" ConvBlock-39 [-1, 256, 129, 129] 0\n",
" DownBlock-40 [-1, 256, 129, 129] 0\n",
" ConvTranspose2d-41 [-1, 128, 258, 258] 131,200\n",
" Conv2d-42 [-1, 128, 256, 256] 295,040\n",
" Conv2d-43 [-1, 128, 254, 254] 147,584\n",
" Conv2d-44 [-1, 128, 252, 252] 147,584\n",
" Conv2d-45 [-1, 128, 250, 250] 147,584\n",
" BatchNorm2d-46 [-1, 128, 250, 250] 256\n",
"AdaptiveAvgPool2d-47 [-1, 128, 1, 1] 0\n",
" Linear-48 [-1, 8] 1,032\n",
" ReLU-49 [-1, 8] 0\n",
" Linear-50 [-1, 128] 1,152\n",
" Sigmoid-51 [-1, 128] 0\n",
"SqueezeExcitation-52 [-1, 128, 250, 250] 0\n",
" ConvBlock-53 [-1, 128, 250, 250] 0\n",
" UpBlock-54 [-1, 128, 250, 250] 0\n",
" ConvTranspose2d-55 [-1, 64, 500, 500] 32,832\n",
" Conv2d-56 [-1, 64, 498, 498] 73,792\n",
" Conv2d-57 [-1, 64, 496, 496] 36,928\n",
" Conv2d-58 [-1, 64, 494, 494] 36,928\n",
" Conv2d-59 [-1, 64, 492, 492] 36,928\n",
" BatchNorm2d-60 [-1, 64, 492, 492] 128\n",
"AdaptiveAvgPool2d-61 [-1, 64, 1, 1] 0\n",
" Linear-62 [-1, 4] 260\n",
" ReLU-63 [-1, 4] 0\n",
" Linear-64 [-1, 64] 320\n",
" Sigmoid-65 [-1, 64] 0\n",
"SqueezeExcitation-66 [-1, 64, 492, 492] 0\n",
" ConvBlock-67 [-1, 64, 492, 492] 0\n",
" UpBlock-68 [-1, 64, 492, 492] 0\n",
" Conv2d-69 [-1, 2, 492, 492] 130\n",
"================================================================\n",
"Total params: 3,795,242\n",
"Trainable params: 3,795,242\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# Shallow variant: 2 down/up steps with 4 convolutions per block.\n",
"unet = UNet(in_size=572, network_depth=2, conv_depth=4)\n",
"print(unet.network_depth, unet.conv_depth)\n",
"print(unet.out_size, unet.out_ch)\n",
"print(unet(torch.rand(1, 1, 572, 572)).shape)\n",
"summary(unet, input_size=(1, 572, 572))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 2)\n",
"(216, 2)\n",
"torch.Size([1, 2, 216, 216])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 254, 254] 640\n",
" Conv2d-2 [-1, 64, 252, 252] 36,928\n",
" BatchNorm2d-3 [-1, 64, 252, 252] 128\n",
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n",
" Linear-5 [-1, 4] 260\n",
" ReLU-6 [-1, 4] 0\n",
" Linear-7 [-1, 64] 320\n",
" Sigmoid-8 [-1, 64] 0\n",
" SqueezeExcitation-9 [-1, 64, 252, 252] 0\n",
" ConvBlock-10 [-1, 64, 252, 252] 0\n",
" MaxPool2d-11 [-1, 64, 126, 126] 0\n",
" Conv2d-12 [-1, 128, 124, 124] 73,856\n",
" Conv2d-13 [-1, 128, 122, 122] 147,584\n",
" BatchNorm2d-14 [-1, 128, 122, 122] 256\n",
"AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n",
" Linear-16 [-1, 8] 1,032\n",
" ReLU-17 [-1, 8] 0\n",
" Linear-18 [-1, 128] 1,152\n",
" Sigmoid-19 [-1, 128] 0\n",
"SqueezeExcitation-20 [-1, 128, 122, 122] 0\n",
" ConvBlock-21 [-1, 128, 122, 122] 0\n",
" DownBlock-22 [-1, 128, 122, 122] 0\n",
" MaxPool2d-23 [-1, 128, 61, 61] 0\n",
" Conv2d-24 [-1, 256, 59, 59] 295,168\n",
" Conv2d-25 [-1, 256, 57, 57] 590,080\n",
" BatchNorm2d-26 [-1, 256, 57, 57] 512\n",
"AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n",
" Linear-28 [-1, 16] 4,112\n",
" ReLU-29 [-1, 16] 0\n",
" Linear-30 [-1, 256] 4,352\n",
" Sigmoid-31 [-1, 256] 0\n",
"SqueezeExcitation-32 [-1, 256, 57, 57] 0\n",
" ConvBlock-33 [-1, 256, 57, 57] 0\n",
" DownBlock-34 [-1, 256, 57, 57] 0\n",
" ConvTranspose2d-35 [-1, 128, 114, 114] 131,200\n",
" Conv2d-36 [-1, 128, 112, 112] 295,040\n",
" Conv2d-37 [-1, 128, 110, 110] 147,584\n",
" BatchNorm2d-38 [-1, 128, 110, 110] 256\n",
"AdaptiveAvgPool2d-39 [-1, 128, 1, 1] 0\n",
" Linear-40 [-1, 8] 1,032\n",
" ReLU-41 [-1, 8] 0\n",
" Linear-42 [-1, 128] 1,152\n",
" Sigmoid-43 [-1, 128] 0\n",
"SqueezeExcitation-44 [-1, 128, 110, 110] 0\n",
" ConvBlock-45 [-1, 128, 110, 110] 0\n",
" UpBlock-46 [-1, 128, 110, 110] 0\n",
" ConvTranspose2d-47 [-1, 64, 220, 220] 32,832\n",
" Conv2d-48 [-1, 64, 218, 218] 73,792\n",
" Conv2d-49 [-1, 64, 216, 216] 36,928\n",
" BatchNorm2d-50 [-1, 64, 216, 216] 128\n",
"AdaptiveAvgPool2d-51 [-1, 64, 1, 1] 0\n",
" Linear-52 [-1, 4] 260\n",
" ReLU-53 [-1, 4] 0\n",
" Linear-54 [-1, 64] 320\n",
" Sigmoid-55 [-1, 64] 0\n",
"SqueezeExcitation-56 [-1, 64, 216, 216] 0\n",
" ConvBlock-57 [-1, 64, 216, 216] 0\n",
" UpBlock-58 [-1, 64, 216, 216] 0\n",
" Conv2d-59 [-1, 2, 216, 216] 130\n",
"================================================================\n",
"Total params: 1,877,034\n",
"Trainable params: 1,877,034\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# 2-level network on a 256x256 input with the default 2 convs per block.\n",
"SIZE = 256\n",
"unet = UNet(in_size=SIZE, network_depth=2)\n",
"print(unet.network_depth, unet.conv_depth)\n",
"print(unet.out_size, unet.out_ch)\n",
"print(unet(torch.rand(1, 1, SIZE, SIZE)).shape)\n",
"summary(unet, input_size=(1, SIZE, SIZE))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(5, 2)\n",
"(256, 2)\n",
"torch.Size([1, 2, 256, 256])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 256, 256] 640\n",
" Conv2d-2 [-1, 64, 256, 256] 36,928\n",
" BatchNorm2d-3 [-1, 64, 256, 256] 128\n",
" AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n",
" Linear-5 [-1, 4] 260\n",
" ReLU-6 [-1, 4] 0\n",
" Linear-7 [-1, 64] 320\n",
" Sigmoid-8 [-1, 64] 0\n",
" SqueezeExcitation-9 [-1, 64, 256, 256] 0\n",
" ConvBlock-10 [-1, 64, 256, 256] 0\n",
" MaxPool2d-11 [-1, 64, 128, 128] 0\n",
" Conv2d-12 [-1, 128, 128, 128] 73,856\n",
" Conv2d-13 [-1, 128, 128, 128] 147,584\n",
" BatchNorm2d-14 [-1, 128, 128, 128] 256\n",
"AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n",
" Linear-16 [-1, 8] 1,032\n",
" ReLU-17 [-1, 8] 0\n",
" Linear-18 [-1, 128] 1,152\n",
" Sigmoid-19 [-1, 128] 0\n",
"SqueezeExcitation-20 [-1, 128, 128, 128] 0\n",
" ConvBlock-21 [-1, 128, 128, 128] 0\n",
" DownBlock-22 [-1, 128, 128, 128] 0\n",
" MaxPool2d-23 [-1, 128, 64, 64] 0\n",
" Conv2d-24 [-1, 256, 64, 64] 295,168\n",
" Conv2d-25 [-1, 256, 64, 64] 590,080\n",
" BatchNorm2d-26 [-1, 256, 64, 64] 512\n",
"AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n",
" Linear-28 [-1, 16] 4,112\n",
" ReLU-29 [-1, 16] 0\n",
" Linear-30 [-1, 256] 4,352\n",
" Sigmoid-31 [-1, 256] 0\n",
"SqueezeExcitation-32 [-1, 256, 64, 64] 0\n",
" ConvBlock-33 [-1, 256, 64, 64] 0\n",
" DownBlock-34 [-1, 256, 64, 64] 0\n",
" MaxPool2d-35 [-1, 256, 32, 32] 0\n",
" Conv2d-36 [-1, 512, 32, 32] 1,180,160\n",
" Conv2d-37 [-1, 512, 32, 32] 2,359,808\n",
" BatchNorm2d-38 [-1, 512, 32, 32] 1,024\n",
"AdaptiveAvgPool2d-39 [-1, 512, 1, 1] 0\n",
" Linear-40 [-1, 32] 16,416\n",
" ReLU-41 [-1, 32] 0\n",
" Linear-42 [-1, 512] 16,896\n",
" Sigmoid-43 [-1, 512] 0\n",
"SqueezeExcitation-44 [-1, 512, 32, 32] 0\n",
" ConvBlock-45 [-1, 512, 32, 32] 0\n",
" DownBlock-46 [-1, 512, 32, 32] 0\n",
" MaxPool2d-47 [-1, 512, 16, 16] 0\n",
" Conv2d-48 [-1, 1024, 16, 16] 4,719,616\n",
" Conv2d-49 [-1, 1024, 16, 16] 9,438,208\n",
" BatchNorm2d-50 [-1, 1024, 16, 16] 2,048\n",
"AdaptiveAvgPool2d-51 [-1, 1024, 1, 1] 0\n",
" Linear-52 [-1, 64] 65,600\n",
" ReLU-53 [-1, 64] 0\n",
" Linear-54 [-1, 1024] 66,560\n",
" Sigmoid-55 [-1, 1024] 0\n",
"SqueezeExcitation-56 [-1, 1024, 16, 16] 0\n",
" ConvBlock-57 [-1, 1024, 16, 16] 0\n",
" DownBlock-58 [-1, 1024, 16, 16] 0\n",
" MaxPool2d-59 [-1, 1024, 8, 8] 0\n",
" Conv2d-60 [-1, 2048, 8, 8] 18,876,416\n",
" Conv2d-61 [-1, 2048, 8, 8] 37,750,784\n",
" BatchNorm2d-62 [-1, 2048, 8, 8] 4,096\n",
"AdaptiveAvgPool2d-63 [-1, 2048, 1, 1] 0\n",
" Linear-64 [-1, 128] 262,272\n",
" ReLU-65 [-1, 128] 0\n",
" Linear-66 [-1, 2048] 264,192\n",
" Sigmoid-67 [-1, 2048] 0\n",
"SqueezeExcitation-68 [-1, 2048, 8, 8] 0\n",
" ConvBlock-69 [-1, 2048, 8, 8] 0\n",
" DownBlock-70 [-1, 2048, 8, 8] 0\n",
" ConvTranspose2d-71 [-1, 1024, 16, 16] 8,389,632\n",
" Conv2d-72 [-1, 1024, 16, 16] 18,875,392\n",
" Conv2d-73 [-1, 1024, 16, 16] 9,438,208\n",
" BatchNorm2d-74 [-1, 1024, 16, 16] 2,048\n",
"AdaptiveAvgPool2d-75 [-1, 1024, 1, 1] 0\n",
" Linear-76 [-1, 64] 65,600\n",
" ReLU-77 [-1, 64] 0\n",
" Linear-78 [-1, 1024] 66,560\n",
" Sigmoid-79 [-1, 1024] 0\n",
"SqueezeExcitation-80 [-1, 1024, 16, 16] 0\n",
" ConvBlock-81 [-1, 1024, 16, 16] 0\n",
" UpBlock-82 [-1, 1024, 16, 16] 0\n",
" ConvTranspose2d-83 [-1, 512, 32, 32] 2,097,664\n",
" Conv2d-84 [-1, 512, 32, 32] 4,719,104\n",
" Conv2d-85 [-1, 512, 32, 32] 2,359,808\n",
" BatchNorm2d-86 [-1, 512, 32, 32] 1,024\n",
"AdaptiveAvgPool2d-87 [-1, 512, 1, 1] 0\n",
" Linear-88 [-1, 32] 16,416\n",
" ReLU-89 [-1, 32] 0\n",
" Linear-90 [-1, 512] 16,896\n",
" Sigmoid-91 [-1, 512] 0\n",
"SqueezeExcitation-92 [-1, 512, 32, 32] 0\n",
" ConvBlock-93 [-1, 512, 32, 32] 0\n",
" UpBlock-94 [-1, 512, 32, 32] 0\n",
" ConvTranspose2d-95 [-1, 256, 64, 64] 524,544\n",
" Conv2d-96 [-1, 256, 64, 64] 1,179,904\n",
" Conv2d-97 [-1, 256, 64, 64] 590,080\n",
" BatchNorm2d-98 [-1, 256, 64, 64] 512\n",
"AdaptiveAvgPool2d-99 [-1, 256, 1, 1] 0\n",
" Linear-100 [-1, 16] 4,112\n",
" ReLU-101 [-1, 16] 0\n",
" Linear-102 [-1, 256] 4,352\n",
" Sigmoid-103 [-1, 256] 0\n",
"SqueezeExcitation-104 [-1, 256, 64, 64] 0\n",
" ConvBlock-105 [-1, 256, 64, 64] 0\n",
" UpBlock-106 [-1, 256, 64, 64] 0\n",
" ConvTranspose2d-107 [-1, 128, 128, 128] 131,200\n",
" Conv2d-108 [-1, 128, 128, 128] 295,040\n",
" Conv2d-109 [-1, 128, 128, 128] 147,584\n",
" BatchNorm2d-110 [-1, 128, 128, 128] 256\n",
"AdaptiveAvgPool2d-111 [-1, 128, 1, 1] 0\n",
" Linear-112 [-1, 8] 1,032\n",
" ReLU-113 [-1, 8] 0\n",
" Linear-114 [-1, 128] 1,152\n",
" Sigmoid-115 [-1, 128] 0\n",
"SqueezeExcitation-116 [-1, 128, 128, 128] 0\n",
" ConvBlock-117 [-1, 128, 128, 128] 0\n",
" UpBlock-118 [-1, 128, 128, 128] 0\n",
" ConvTranspose2d-119 [-1, 64, 256, 256] 32,832\n",
" Conv2d-120 [-1, 64, 256, 256] 73,792\n",
" Conv2d-121 [-1, 64, 256, 256] 36,928\n",
" BatchNorm2d-122 [-1, 64, 256, 256] 128\n",
"AdaptiveAvgPool2d-123 [-1, 64, 1, 1] 0\n",
" Linear-124 [-1, 4] 260\n",
" ReLU-125 [-1, 4] 0\n",
" Linear-126 [-1, 64] 320\n",
" Sigmoid-127 [-1, 64] 0\n",
"SqueezeExcitation-128 [-1, 64, 256, 256] 0\n",
" ConvBlock-129 [-1, 64, 256, 256] 0\n",
" UpBlock-130 [-1, 64, 256, 256] 0\n",
" Conv2d-131 [-1, 2, 256, 256] 130\n",
"================================================================\n",
"Total params: 125,252,986\n",
"Trainable params: 125,252,986\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# 5-level network with padding=1, so the output keeps the 256x256 input size.\n",
"SIZE = 256\n",
"unet = UNet(in_size=SIZE, network_depth=5, padding=1)\n",
"print(unet.network_depth, unet.conv_depth)\n",
"print(unet.out_size, unet.out_ch)\n",
"print(unet(torch.rand(1, 1, SIZE, SIZE)).shape)\n",
"summary(unet, input_size=(1, SIZE, SIZE))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 2)\n",
"(216, 2)\n",
"torch.Size([1, 2, 216, 216])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 254, 254] 640\n",
" Conv2d-2 [-1, 64, 252, 252] 36,928\n",
" ConvBlock-3 [-1, 64, 252, 252] 0\n",
" MaxPool2d-4 [-1, 64, 126, 126] 0\n",
" Conv2d-5 [-1, 128, 124, 124] 73,856\n",
" Conv2d-6 [-1, 128, 122, 122] 147,584\n",
" ConvBlock-7 [-1, 128, 122, 122] 0\n",
" DownBlock-8 [-1, 128, 122, 122] 0\n",
" MaxPool2d-9 [-1, 128, 61, 61] 0\n",
" Conv2d-10 [-1, 256, 59, 59] 295,168\n",
" Conv2d-11 [-1, 256, 57, 57] 590,080\n",
" ConvBlock-12 [-1, 256, 57, 57] 0\n",
" DownBlock-13 [-1, 256, 57, 57] 0\n",
" ConvTranspose2d-14 [-1, 128, 114, 114] 131,200\n",
" Conv2d-15 [-1, 128, 112, 112] 295,040\n",
" Conv2d-16 [-1, 128, 110, 110] 147,584\n",
" ConvBlock-17 [-1, 128, 110, 110] 0\n",
" UpBlock-18 [-1, 128, 110, 110] 0\n",
" ConvTranspose2d-19 [-1, 64, 220, 220] 32,832\n",
" Conv2d-20 [-1, 64, 218, 218] 73,792\n",
" Conv2d-21 [-1, 64, 216, 216] 36,928\n",
" ConvBlock-22 [-1, 64, 216, 216] 0\n",
" UpBlock-23 [-1, 64, 216, 216] 0\n",
" Conv2d-24 [-1, 2, 216, 216] 130\n",
"================================================================\n",
"Total params: 1,861,762\n",
"Trainable params: 1,861,762\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# Plain variant: no BatchNorm and no Squeeze-and-Excitation gates.\n",
"SIZE = 256\n",
"unet = UNet(in_size=SIZE, network_depth=2, bn=False, se=False)\n",
"print(unet.network_depth, unet.conv_depth)\n",
"print(unet.out_size, unet.out_ch)\n",
"print(unet(torch.rand(1, 1, SIZE, SIZE)).shape)\n",
"summary(unet, input_size=(1, SIZE, SIZE))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 2)\n",
"(216, 2)\n",
"torch.Size([1, 2, 216, 216])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 254, 254] 640\n",
" Conv2d-2 [-1, 64, 252, 252] 36,928\n",
" AdaptiveAvgPool2d-3 [-1, 64, 1, 1] 0\n",
" Linear-4 [-1, 4] 260\n",
" ReLU-5 [-1, 4] 0\n",
" Linear-6 [-1, 64] 320\n",
" Sigmoid-7 [-1, 64] 0\n",
" SqueezeExcitation-8 [-1, 64, 252, 252] 0\n",
" ConvBlock-9 [-1, 64, 252, 252] 0\n",
" MaxPool2d-10 [-1, 64, 126, 126] 0\n",
" Conv2d-11 [-1, 128, 124, 124] 73,856\n",
" Conv2d-12 [-1, 128, 122, 122] 147,584\n",
"AdaptiveAvgPool2d-13 [-1, 128, 1, 1] 0\n",
" Linear-14 [-1, 8] 1,032\n",
" ReLU-15 [-1, 8] 0\n",
" Linear-16 [-1, 128] 1,152\n",
" Sigmoid-17 [-1, 128] 0\n",
"SqueezeExcitation-18 [-1, 128, 122, 122] 0\n",
" ConvBlock-19 [-1, 128, 122, 122] 0\n",
" DownBlock-20 [-1, 128, 122, 122] 0\n",
" MaxPool2d-21 [-1, 128, 61, 61] 0\n",
" Conv2d-22 [-1, 256, 59, 59] 295,168\n",
" Conv2d-23 [-1, 256, 57, 57] 590,080\n",
"AdaptiveAvgPool2d-24 [-1, 256, 1, 1] 0\n",
" Linear-25 [-1, 16] 4,112\n",
" ReLU-26 [-1, 16] 0\n",
" Linear-27 [-1, 256] 4,352\n",
" Sigmoid-28 [-1, 256] 0\n",
"SqueezeExcitation-29 [-1, 256, 57, 57] 0\n",
" ConvBlock-30 [-1, 256, 57, 57] 0\n",
" DownBlock-31 [-1, 256, 57, 57] 0\n",
" ConvTranspose2d-32 [-1, 128, 114, 114] 131,200\n",
" Conv2d-33 [-1, 128, 112, 112] 295,040\n",
" Conv2d-34 [-1, 128, 110, 110] 147,584\n",
"AdaptiveAvgPool2d-35 [-1, 128, 1, 1] 0\n",
" Linear-36 [-1, 8] 1,032\n",
" ReLU-37 [-1, 8] 0\n",
" Linear-38 [-1, 128] 1,152\n",
" Sigmoid-39 [-1, 128] 0\n",
"SqueezeExcitation-40 [-1, 128, 110, 110] 0\n",
" ConvBlock-41 [-1, 128, 110, 110] 0\n",
" UpBlock-42 [-1, 128, 110, 110] 0\n",
" ConvTranspose2d-43 [-1, 64, 220, 220] 32,832\n",
" Conv2d-44 [-1, 64, 218, 218] 73,792\n",
" Conv2d-45 [-1, 64, 216, 216] 36,928\n",
"AdaptiveAvgPool2d-46 [-1, 64, 1, 1] 0\n",
" Linear-47 [-1, 4] 260\n",
" ReLU-48 [-1, 4] 0\n",
" Linear-49 [-1, 64] 320\n",
" Sigmoid-50 [-1, 64] 0\n",
"SqueezeExcitation-51 [-1, 64, 216, 216] 0\n",
" ConvBlock-52 [-1, 64, 216, 216] 0\n",
" UpBlock-53 [-1, 64, 216, 216] 0\n",
" Conv2d-54 [-1, 2, 216, 216] 130\n",
"================================================================\n",
"Total params: 1,875,754\n",
"Trainable params: 1,875,754\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"SIZE=256\n",
"unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=True)\n",
"print(unet.network_depth,unet.conv_depth)\n",
"print(unet.out_size,unet.out_ch)\n",
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n",
"summary(unet,input_size=(1,SIZE,SIZE))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2, 2)\n",
"(216, 2)\n",
"torch.Size([1, 2, 216, 216])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 254, 254] 640\n",
" Conv2d-2 [-1, 64, 252, 252] 36,928\n",
" BatchNorm2d-3 [-1, 64, 252, 252] 128\n",
" ConvBlock-4 [-1, 64, 252, 252] 0\n",
" MaxPool2d-5 [-1, 64, 126, 126] 0\n",
" Conv2d-6 [-1, 128, 124, 124] 73,856\n",
" Conv2d-7 [-1, 128, 122, 122] 147,584\n",
" BatchNorm2d-8 [-1, 128, 122, 122] 256\n",
" ConvBlock-9 [-1, 128, 122, 122] 0\n",
" DownBlock-10 [-1, 128, 122, 122] 0\n",
" MaxPool2d-11 [-1, 128, 61, 61] 0\n",
" Conv2d-12 [-1, 256, 59, 59] 295,168\n",
" Conv2d-13 [-1, 256, 57, 57] 590,080\n",
" BatchNorm2d-14 [-1, 256, 57, 57] 512\n",
" ConvBlock-15 [-1, 256, 57, 57] 0\n",
" DownBlock-16 [-1, 256, 57, 57] 0\n",
" ConvTranspose2d-17 [-1, 128, 114, 114] 131,200\n",
" Conv2d-18 [-1, 128, 112, 112] 295,040\n",
" Conv2d-19 [-1, 128, 110, 110] 147,584\n",
" BatchNorm2d-20 [-1, 128, 110, 110] 256\n",
" ConvBlock-21 [-1, 128, 110, 110] 0\n",
" UpBlock-22 [-1, 128, 110, 110] 0\n",
" ConvTranspose2d-23 [-1, 64, 220, 220] 32,832\n",
" Conv2d-24 [-1, 64, 218, 218] 73,792\n",
" Conv2d-25 [-1, 64, 216, 216] 36,928\n",
" BatchNorm2d-26 [-1, 64, 216, 216] 128\n",
" ConvBlock-27 [-1, 64, 216, 216] 0\n",
" UpBlock-28 [-1, 64, 216, 216] 0\n",
" Conv2d-29 [-1, 2, 216, 216] 130\n",
"================================================================\n",
"Total params: 1,863,042\n",
"Trainable params: 1,863,042\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"SIZE=256\n",
"unet=UNet(in_size=SIZE,network_depth=2,act='elu',se=False)\n",
"print(unet.network_depth,unet.conv_depth)\n",
"print(unet.out_size,unet.out_ch)\n",
"print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n",
"summary(unet,input_size=(1,SIZE,SIZE))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@brookisme
Copy link
Author

brookisme commented Jul 20, 2018

  • build unet with paper parameters
  • optional depth of unet
  • optional depth of conv_blocks
  • valid/same padding
  • other conv-block kwargs: se, bn, act

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment