@Hsankesara
Created January 22, 2019 17:59
import torch
from torch import nn
import torch.nn.functional as F


class UNet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        # Two unpadded 3x3 convolutions, each followed by ReLU and BatchNorm.
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        # Two unpadded convolutions, then a transposed convolution that doubles
        # the spatial resolution.
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        # Final stage: two unpadded convolutions, then a padded convolution that
        # maps to the requested number of output channels at the same resolution.
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
        )
        # Decode
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        # Crop the (larger) skip-connection tensor down to the size of the
        # upsampled tensor, then concatenate along the channel dimension.
        # Negative padding crops; this assumes the size difference is even.
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)
        return final_layer
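
A minimal usage sketch (not part of the original gist; the single input channel, two output classes, and 572x572 input size are illustrative assumptions). Because the convolutions inside the blocks are unpadded, the output map comes out smaller than the input, and the skip-connection crops only line up cleanly for suitable even input sizes such as the classic 572x572:

import torch

# Hypothetical configuration: 1 input channel, 2 output classes.
net = UNet(in_channel=1, out_channel=2)
x = torch.randn(1, 1, 572, 572)  # batch of one 572x572 single-channel image
out = net(x)
print(out.shape)  # expected: torch.Size([1, 2, 484, 484]) -- smaller than the input due to unpadded convs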
@manvirvirk

Which dataset are you using?

@szczekulskij

Hey man, awesome code, just beautiful.

But one question: why didn't you add a MaxPool to contracting_block? You use one after each contracting_block anyway.

Did you do it to leave yourself more room for experimenting? If so, what kinds of ideas were you trying to test?

Again, such nice code; it makes my heart happy looking at it.

@Hsankesara
Author

Hsankesara commented Jul 17, 2020

Thanks @jas332211,
I am glad you liked the code. The reason I did not add a MaxPool inside contracting_block is that the output of each contracting block is reused as a skip connection by the crop_and_concat calls in forward. Pooling inside the block would halve the feature maps before they are saved, distorting the features, and that did not perform as well as the current implementation. I tried a few different architectures and found that this configuration worked best on my dataset. Hope this helps.
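
To make that concrete, here is a standalone sketch of what crop_and_concat does at the deepest skip connection (the shapes below are the ones that arise for an assumed 572x572 input). F.pad with negative amounts crops, trimming the larger pre-pooling skip tensor down to the size of the upsampled tensor before the channel-wise concatenation:

import torch
import torch.nn.functional as F

upsampled = torch.randn(1, 256, 128, 128)  # bottleneck output after the transposed conv
bypass = torch.randn(1, 256, 136, 136)     # encode_block3, saved before pooling

c = (bypass.size()[2] - upsampled.size()[2]) // 2  # 4 pixels to trim from each border
bypass = F.pad(bypass, (-c, -c, -c, -c))           # negative padding crops: 136 -> 128
merged = torch.cat((upsampled, bypass), 1)         # concatenate along channels
print(merged.shape)  # torch.Size([1, 512, 128, 128])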
