# GitHub gist by @ptrblck, created November 17, 2018
# https://gist.github.com/ptrblck/27b4de4e291ffc0d85b33858d0bc8779
import torch
import torch.nn as nn
class UNet_down_block(torch.nn.Module):
    """Encoder stage of the U-Net.

    Optionally halves the spatial size with a 2x2 max-pool, then applies three
    3x3 convolution -> batch-norm -> ReLU layers (padding=1, so the convs keep
    the spatial size).

    Args:
        input_channel: number of channels of the incoming tensor.
        output_channel: number of channels produced by each of the three convs.
        down_size: when truthy, max-pool the input before the convolutions.
    """

    def __init__(self, input_channel, output_channel, down_size):
        super().__init__()
        # Registered as conv1, bn1, conv2, bn2, conv3, bn3 in exactly that
        # order: this fixes the state_dict keys, the parameter order seen by
        # optimizers, and the order in which weight init consumes the RNG.
        in_widths = (input_channel, output_channel, output_channel)
        for idx, in_width in enumerate(in_widths, start=1):
            setattr(self, "conv{}".format(idx),
                    torch.nn.Conv2d(in_width, output_channel, 3, padding=1))
            setattr(self, "bn{}".format(idx),
                    torch.nn.BatchNorm2d(output_channel))
        self.max_pool = torch.nn.MaxPool2d(2, 2)
        self.relu = torch.nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        """Return the stage output for ``x`` (pooled first if ``down_size``)."""
        out = self.max_pool(x) if self.down_size else x
        for conv, norm in ((self.conv1, self.bn1),
                           (self.conv2, self.bn2),
                           (self.conv3, self.bn3)):
            out = self.relu(norm(conv(out)))
        return out
class UNet_up_block(torch.nn.Module):
    """Decoder stage of the U-Net.

    Optionally upsamples ``x`` (bilinear), concatenates it with the skip tensor
    ``prev_feature_map`` along dim 1, then applies three 3x3 convolution ->
    batch-norm -> ReLU layers.

    Args:
        prev_channel: accepted but not used. ``conv1`` is built for
            ``input_channel + input_channel`` input channels, so ``x`` and
            ``prev_feature_map`` together must carry exactly that many.
        input_channel: half of ``conv1``'s input width; presumably the channel
            count of ``x`` and of the skip tensor alike (verify at call sites).
        output_channel: number of channels produced by each of the three convs.
    """

    def __init__(self, prev_channel, input_channel, output_channel):
        super().__init__()
        # align_corners=False is PyTorch's default for bilinear; spelled out so
        # the behavior is explicit (older releases warn when it is left unset).
        self.up_sampling = torch.nn.Upsample(scale_factor=2, mode='bilinear',
                                             align_corners=False)
        self.conv1 = torch.nn.Conv2d(input_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(output_channel)
        self.conv2 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(output_channel)
        self.conv3 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(output_channel)
        self.relu = torch.nn.ReLU()

    def forward(self, prev_feature_map, x, k):
        """Fuse ``x`` with the skip tensor ``prev_feature_map``.

        Args:
            prev_feature_map: skip tensor; appended after ``x`` along dim 1.
            x: tensor coming from the coarser stage.
            k: if non-zero, ``x`` is first resized to ``prev_feature_map``'s
                height/width; if 0, ``x`` is used as-is and must already match.

        Returns:
            Output of the third conv -> BN -> ReLU layer, at the spatial size
            of ``prev_feature_map``.
        """
        if k != 0:
            # Resize to the skip tensor's exact height/width instead of a fixed
            # 2x. A downsampling path built from MaxPool2d(2, 2) floors odd
            # sizes (e.g. 101 -> 50), so a plain 2x upsample would come up one
            # pixel short and torch.cat would fail. For exact 2x ratios this is
            # identical to Upsample(scale_factor=2).
            x = torch.nn.functional.interpolate(
                x,
                size=prev_feature_map.shape[-2:],
                mode=self.up_sampling.mode,
                align_corners=self.up_sampling.align_corners,
            )
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x
class UNet(torch.nn.Module):
    """Small three-level U-Net mapping a 3-channel image to a 1-channel map.

    Encoder widths are 16 (input resolution), 32 (half) and 64 (quarter;
    halvings floor odd sizes). Three 64->64 conv/BN/ReLU layers follow at the
    quarter resolution. The decoder merges each level with the matching encoder
    output. The final 1x1 conv has no activation, so the output is raw scores.
    """

    def __init__(self):
        super().__init__()
        # Construction order is unchanged so parameter order / init RNG draws
        # stay the same.
        self.down_block1 = UNet_down_block(3, 16, False)
        self.down_block2 = UNet_down_block(16, 32, True)
        self.down_block3 = UNet_down_block(32, 64, True)
        self.mid_conv1 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.mid_conv2 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.mid_conv3 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(64)
        # First argument = channel count of the skip tensor each block receives
        # in forward (64, 32, 16).
        self.up_block5 = UNet_up_block(64, 64, 32)
        self.up_block6 = UNet_up_block(32, 32, 16)
        self.up_block7 = UNet_up_block(16, 16, 16)
        self.last_conv1 = torch.nn.Conv2d(16, 3, 3, padding=1)
        self.last_bn = torch.nn.BatchNorm2d(3)
        self.last_conv2 = torch.nn.Conv2d(3, 1, 1, padding=0)
        self.relu = torch.nn.ReLU()
        # Not used in forward; kept so existing references to it still work.
        self.max_pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return the 1-channel raw score map for the 3-channel input ``x``."""
        # Encoder. Skip tensors are locals rather than attributes on self, so
        # they are released after the call instead of lingering until the
        # next batch overwrites them.
        skip1 = self.down_block1(x)      # 16 ch, input resolution
        skip2 = self.down_block2(skip1)  # 32 ch, half resolution
        skip3 = self.down_block3(skip2)  # 64 ch, quarter resolution

        # Bottleneck at skip3's resolution (no extra pooling).
        mid = self.relu(self.bn1(self.mid_conv1(skip3)))
        mid = self.relu(self.bn2(self.mid_conv2(mid)))
        mid = self.relu(self.bn3(self.mid_conv3(mid)))

        # Decoder. k=0 skips upsampling because mid and skip3 share a size.
        out = self.up_block5(skip3, mid, k=0)
        out = self.up_block6(skip2, out, k=1)
        out = self.up_block7(skip1, out, k=1)

        out = self.relu(self.last_bn(self.last_conv1(out)))
        return self.last_conv2(out)
def dice(input, taget, smooth=.001):
    """Return the soft Dice loss ``1 - Dice`` between two tensors.

    Both tensors are flattened before comparison. ``input`` is presumably
    meant to hold probabilities in [0, 1] (e.g. after a sigmoid); this is not
    checked here.

    Args:
        input: predicted values.
        taget: target values (the misspelled name is kept so keyword callers
            keep working).
        smooth: small constant added to numerator and denominator; keeps the
            result defined when both tensors are all zero.

    Returns:
        0-dim tensor: 0 for a perfect match (including all-zero vs all-zero),
        close to 1 when there is no overlap.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. slices
    # or transposes.
    pred = input.reshape(-1)
    target = taget.reshape(-1)
    intersection = (pred * target).sum()
    # Smoothing the numerator too makes an all-zero prediction of an all-zero
    # target score 0 instead of the maximal loss of 1.
    return 1 - (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
# Toy sanity check: fit one random image / random binary mask pair for 20 steps
# and print the combined, BCE and Dice losses per step.
net = UNet()
x = torch.randn(1, 3, 100, 100)
target = torch.randint(0, 2, (1, 1, 100, 100), dtype=torch.float32)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw net output.
criterion = nn.BCEWithLogitsLoss()
for epoch in range(20):
    optimizer.zero_grad()
    output = net(x)
    bce_loss = criterion(output, target)
    # Soft Dice needs probabilities, not raw scores: unbounded (possibly
    # negative) outputs can push its denominator towards zero.
    dice_loss = dice(torch.sigmoid(output), target)
    loss = bce_loss + dice_loss
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}, bce {}, dice {}'.format(
        epoch, loss.item(), bce_loss.item(), dice_loss.item()))
# (End of gist; GitHub comment-section prompt omitted.)