Skip to content

Instantly share code, notes, and snippets.

@KeremTurgutlu
Created April 19, 2018 01:45
Show Gist options
  • Save KeremTurgutlu/089d0ba0c02a850eef91ded26ee4722f to your computer and use it in GitHub Desktop.
Save KeremTurgutlu/089d0ba0c02a850eef91ded26ee4722f to your computer and use it in GitHub Desktop.
U-Net up block
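# A sample U-Net up (decoder) block: a conv-bn-relu helper, the layer definitions, and the matching forward-pass excerpt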
def make_conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build the layers of one Conv2d -> BatchNorm2d -> ReLU unit.

    The layers are returned as a plain list (not wrapped in a module) so that
    several units can be unpacked into a single ``nn.Sequential(*...)``.

    Args:
        in_channels: Number of channels entering the convolution.
        out_channels: Number of channels produced by the convolution; also the
            number of features normalised by the BatchNorm2d layer.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding added on each side of the input.

    Returns:
        A list ``[conv, batch_norm, relu]``.
    """
    # The conv has no bias: the BatchNorm2d right after it (affine by default)
    # carries its own learnable shift, which would make a conv bias redundant.
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    batch_norm = nn.BatchNorm2d(out_channels)
    relu = nn.ReLU(inplace=True)
    return [conv, batch_norm, relu]
# Decoder stage "up4": two conv-bn-relu units that reduce 128 input channels
# to 64. The 128 input channels are the concatenation of down1 with the
# upsampled deeper feature map; the per-source split (presumably 64 + 64) is
# not visible here -- TODO confirm against the encoder.
self.up4 = nn.Sequential(
    *make_conv_bn_relu(128, 64, kernel_size=3, stride=1, padding=1),
    *make_conv_bn_relu(64, 64, kernel_size=3, stride=1, padding=1),
)
# Final 1x1 conv mapping up4's features to num_classes output channels (the
# predictions). Its in_channels must equal up4's output width (64): the forward
# pass feeds self.up4's output directly into this layer, and a value of 32
# makes that call fail with a channel-mismatch error.
self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1, stride=1, padding=0)
# Upsample out_last to down1's spatial size, concatenate with down1 along
# dim 1 (the channel dim expected by the Conv2d layers in self.up4), and apply
# the up4 conv operations.
# Resizing to down1.shape[-2:] instead of a fixed scale_factor=2 keeps the
# concatenation valid when down1 is not exactly twice out_last's size (e.g. odd
# heights/widths); when it is exactly 2x, the result equals a 2x upsample.
# F.interpolate replaces the deprecated F.upsample; align_corners=False is the
# value F.upsample used implicitly, stated explicitly to avoid the warning.
out = F.interpolate(out_last, size=down1.shape[-2:], mode='bilinear', align_corners=False)
out = torch.cat([down1, out], 1)
out = self.up4(out)
# Final 1x1 conv for predictions.
final_out = self.final_conv(out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment