Created
January 22, 2019 17:59
-
-
Save Hsankesara/e3b064ff47d538052e059084b8d4df9f to your computer and use it in GitHub Desktop.
UNet main - https://github.com/Hsankesara/DeepResearch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn.functional as F
import torch.optim as optim

from torch import nn
class UNet(nn.Module):
    """3-level U-Net (Ronneberger et al., 2015) built from unpadded 3x3 convs.

    Because the convolutions use no padding, the feature maps shrink at every
    stage; the skip connections therefore centre-crop the encoder outputs
    before concatenating them with the upsampled decoder features, as in the
    original paper.

    Args:
        in_channel: number of channels of the input image batch.
        out_channel: number of channels of the produced segmentation map.
    """

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        """Encoder stage: two unpadded (Conv2d -> ReLU -> BatchNorm2d) layers."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Decoder stage: two unpadded (Conv2d -> ReLU -> BatchNorm2d) layers,
        then a stride-2 ConvTranspose2d that doubles the spatial resolution."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            # Output size: (H - 1)*2 - 2*1 + 3 + 1 = 2*H, i.e. exact doubling.
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3,
                                     stride=2, padding=1, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Output head: two unpadded conv layers, then a padding=1 conv that
        projects to ``out_channels`` while preserving the spatial size."""
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        # Encoder: three contracting stages, each followed by 2x2 max-pooling.
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck: deepest stage; ends with an upsampling transposed conv.
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3,
                                     stride=2, padding=1, output_padding=1),
        )
        # Decoder: each stage consumes the concatenated skip + upsampled maps.
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Channel-concatenate ``bypass`` (encoder skip) onto ``upsampled``.

        When ``crop`` is True, ``bypass`` — the larger encoder map — is first
        centre-cropped to the spatial size of ``upsampled``. Assumes square
        feature maps (only dim 2 is inspected). An odd size difference is
        handled by trimming one extra pixel from the right/bottom; the
        previous code cropped ``diff // 2`` from every side, which left the
        two tensors mismatched (and ``torch.cat`` failing) when the
        difference was odd.
        """
        if crop:
            diff = bypass.size()[2] - upsampled.size()[2]
            c = diff // 2
            # Negative values in F.pad crop: order is (left, right, top, bottom).
            bypass = F.pad(bypass, (-c, -(diff - c), -c, -(diff - c)))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        """Run the full encoder / bottleneck / decoder pass.

        Args:
            x: input batch of shape (N, in_channel, H, W); H and W must be
               large enough to survive the unpadded convolutions and pools.

        Returns:
            Segmentation logits of shape (N, out_channel, H', W'), with
            H' < H because the network uses unpadded convolutions.
        """
        # Encoder path, keeping the pre-pool activations for the skip links.
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decoder path: crop-and-concat each skip, then upsample and refine.
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer
Thanks @jas332211,
I am glad you liked the code. The reason I did not add MaxPool inside contracting_block is that the output of each contracting block is reused by the crop_and_concat function, as can be seen on lines 85, 87 and 89. Applying MaxPool there would distort the skip features, since it halves the size of the input, and it did not perform as well as the current implementation. I tried a few different architectures and found that this configuration worked best on my dataset. Hope this helps.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hey man, awesome code, just beautiful.
But one question: why haven't you added a MaxPool to contracting_block?
You used one after each contracting_block anyway.
Did you do it to leave yourself more room for testing?
If yes, what kinds of ideas were you trying to test?
Again, such nice code — it makes my heart happy looking at it.