import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim


class UNet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        # Two unpadded 3x3 convolutions, each followed by ReLU and BatchNorm.
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        # Two unpadded convolutions followed by a strided transposed
        # convolution that doubles the spatial resolution.
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels,
                               kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        # Last decoder stage: three convolutions, the final one mapping to
        # the requested number of output channels.
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(in_channels=512, out_channels=256,
                               kernel_size=3, stride=2, padding=1, output_padding=1),
        )
        # Decode
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        # Center-crop the encoder feature map (bypass) to match the upsampled
        # decoder feature map, then concatenate along the channel dimension.
        # F.pad with negative values crops instead of padding.
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode: keep each pre-pool activation for the skip connections.
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode: upsample, crop-and-concatenate the matching skip, convolve.
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer
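
A minimal smoke test, assuming a single-channel input at 572x572 (the size used in the original U-Net paper; this is an assumption, not part of the gist). Because the 3x3 convolutions are unpadded, each feature map shrinks, so the output ends up smaller than the input; any input size that keeps every crop offset even works the same way.

import torch

# Hypothetical usage of the UNet above; input size 572x572 is assumed.
model = UNet(in_channel=1, out_channel=2)
model.eval()
x = torch.randn(1, 1, 572, 572)  # (batch, channels, height, width)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 2, 484, 484])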
Hey man, awesome code, just beautiful.
But one question: why haven't you added a MaxPool to contracting_block?
You use it after each contracting_block anyway.
Did you do it to leave yourself more room for testing?
If yes, what kind of ideas were you trying to test?
Again, such nice code, it makes my heart happy looking at it.
Thanks @jas332211,
I am glad you liked the code. The reason I did not put the MaxPool inside contracting_block is that the pre-pool output of each contracting block is reused by crop_and_concat in the forward pass (encode_block1, encode_block2, and encode_block3 feed the skip connections). Pooling inside the block would halve the feature maps before they reach the decoder, which distorts the features and did not perform as well as the current implementation. I tried a few different architectures and found that this configuration worked best on my dataset. Hope this helps.
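
A minimal sketch of the size bookkeeping behind this, using the sizes that arise from a (hypothetical) 572x572 input: the bottleneck upsamples to 128x128 while the pre-pool encode_block3 is 136x136, so crop_and_concat center-crops the skip tensor before concatenating.

import torch
import torch.nn.functional as F

upsampled = torch.randn(1, 256, 128, 128)  # decoder feature map
bypass = torch.randn(1, 256, 136, 136)     # pre-pool encoder feature map

c = (bypass.size(2) - upsampled.size(2)) // 2  # 4 pixels per side
bypass = F.pad(bypass, (-c, -c, -c, -c))       # negative padding center-crops
merged = torch.cat((upsampled, bypass), dim=1)
print(merged.shape)  # torch.Size([1, 512, 128, 128])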
Which dataset are you using?