Tiramisu 2D/3D in PyTorch
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
PyTorch implementation of the Tiramisu network architecture [1]
(2D and 3D); the implementation is based on [2].

Changes from [2] include:
  1) removal of bias from the conv layers,
  2) replication padding instead of zero padding,
  3) GELU as the default activation,
  4) cosmetic changes for brevity, clarity, and consistency.

References:
  [1] Jégou, Simon, et al. "The one hundred layers tiramisu:
      Fully convolutional DenseNets for semantic segmentation."
      CVPR. 2017.
  [2] https://github.com/bfortuner/pytorch_tiramisu

Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""

__all__ = ['Tiramisu2d',
           'Tiramisu3d']

from functools import partial
from typing import List

import torch
from torch import Tensor
from torch import nn

ACTIVATION = nn.GELU

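# NOTE: change 3) above swaps in GELU; the reference implementations [1, 2]
# use ReLU, so setting ACTIVATION = nn.ReLU should recover that behavior.
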
class ConvLayer(nn.Sequential):
    _conv = None
    _dropout = None
    _kernel_size = None
    _maxpool = None
    _norm = None
    _pad = None

    def __init__(self, in_channels:int, growth_rate:int, dropout_rate:float=0.2):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.add_module('norm', self._norm(in_channels))
        self.add_module('act', ACTIVATION())
        if self._use_padding():
            self.add_module('pad', self._pad(self._kernel_size // 2))
        self.add_module('conv', self._conv(in_channels, growth_rate,
                                           self._kernel_size,
                                           bias=False))
        if self._use_dropout():
            self.add_module('drop', self._dropout(dropout_rate))
        if self._use_maxpool():
            self.add_module('maxpool', self._maxpool(2))

    def _use_dropout(self) -> bool:
        return self.dropout_rate > 0.

    def _use_padding(self) -> bool:
        return self._kernel_size > 2

    def _use_maxpool(self) -> bool:
        return self._maxpool is not None

class ConvLayer2d(ConvLayer):
    _conv = nn.Conv2d
    _dropout = partial(nn.Dropout2d, inplace=True)
    _kernel_size = 3
    _maxpool = None
    _norm = nn.BatchNorm2d
    _pad = nn.ReplicationPad2d


class ConvLayer3d(ConvLayer):
    _conv = nn.Conv3d
    _dropout = partial(nn.Dropout3d, inplace=True)
    _kernel_size = 3
    _maxpool = None
    _norm = nn.BatchNorm3d
    _pad = nn.ReplicationPad3d

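# For example, with the settings above, ConvLayer2d(48, 16) builds
#   BatchNorm2d(48) -> GELU -> ReplicationPad2d(1)
#   -> Conv2d(48, 16, kernel_size=3, bias=False) -> Dropout2d(0.2),
# i.e., each layer emits `growth_rate` new feature maps at the input's spatial size.
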
class DenseBlock(nn.Module):
    _layer = None

    def __init__(self, in_channels:int, growth_rate:int, n_layers:int,
                 upsample:bool=False, dropout_rate:float=0.2):
        super().__init__()
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.n_layers = n_layers
        self.upsample = upsample
        self.dropout_rate = dropout_rate
        self.layers = nn.ModuleList([
            self._layer(ic, self.growth_rate, self.dropout_rate)
            for ic in self.in_channels_range])

    def forward(self, x:Tensor) -> Tensor:
        if self.upsample:
            new_features = []
            # We pass all previous activations into each dense layer normally,
            # but we only store each dense layer's output in new_features.
            # Note that all concatenation is done on the channel axis (i.e., 1).
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)
                new_features.append(out)
            return torch.cat(new_features, 1)
        else:
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)
            return x

    @property
    def in_channels_range(self) -> List[int]:
        return [self.in_channels + i * self.growth_rate for i in range(self.n_layers)]

class DenseBlock2d(DenseBlock):
    _layer = ConvLayer2d


class DenseBlock3d(DenseBlock):
    _layer = ConvLayer3d

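# Worked example of the channel bookkeeping: DenseBlock2d(48, 16, 5) with
# upsample=False returns 48 + 5*16 = 128 channels (the input concatenated
# with every layer's output), while upsample=True returns only the
# 5*16 = 80 newly computed channels.
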
class TransitionDown2d(ConvLayer):
    _conv = nn.Conv2d
    _dropout = partial(nn.Dropout2d, inplace=True)
    _kernel_size = 1
    _maxpool = nn.MaxPool2d
    _norm = nn.BatchNorm2d
    _pad = nn.ReplicationPad2d


class TransitionDown3d(ConvLayer):
    _conv = nn.Conv3d
    _dropout = partial(nn.Dropout3d, inplace=True)
    _kernel_size = 1
    _maxpool = nn.MaxPool3d
    _norm = nn.BatchNorm3d
    _pad = nn.ReplicationPad3d

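# The transition-down blocks reuse ConvLayer with a 1x1 convolution (no
# padding, since _use_padding() is False for kernel size 1) followed by 2x2
# max pooling, halving each spatial dimension; Tiramisu below instantiates
# them with equal in/out channels, so the channel count is unchanged.
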
class TransitionUp(nn.Module):
    _conv_trans = None

    def __init__(self, in_channels:int, out_channels:int):
        super().__init__()
        kernel_size = 3
        self.convTrans = self._conv_trans(
            in_channels, out_channels, kernel_size,
            stride=2, bias=False)

    def forward(self, x:Tensor, skip:Tensor) -> Tensor:
        out = self.convTrans(x)
        out = self._crop_to_y(out, skip)
        out = torch.cat([out, skip], 1)
        return out

    @staticmethod
    def _crop_to_y(x:Tensor, y:Tensor) -> Tensor:
        raise NotImplementedError

class TransitionUp2d(TransitionUp):
    _conv_trans = nn.ConvTranspose2d

    @staticmethod
    def _crop_to_y(x:Tensor, y:Tensor) -> Tensor:
        _, _, max_height, max_width = y.shape
        _, _, h, w = x.shape
        h = (h - max_height) // 2
        w = (w - max_width) // 2
        return x[:, :, h:(h + max_height), w:(w + max_width)]


class TransitionUp3d(TransitionUp):
    _conv_trans = nn.ConvTranspose3d

    @staticmethod
    def _crop_to_y(x:Tensor, y:Tensor) -> Tensor:
        _, _, max_height, max_width, max_depth = y.shape
        _, _, h, w, d = x.shape
        h = (h - max_height) // 2
        w = (w - max_width) // 2
        d = (d - max_depth) // 2
        return x[:, :, h:(h + max_height), w:(w + max_width), d:(d + max_depth)]

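# Why the crop is needed: a transposed convolution with kernel_size=3 and
# stride=2 maps a size-n input to size (n - 1)*2 + 3 = 2n + 1, slightly larger
# than the skip connection it will be concatenated with, so _crop_to_y
# center-crops the upsampled map back to the skip's spatial size.
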
class Bottleneck(nn.Sequential):
    _layer = None

    def __init__(self, in_channels:int, growth_rate:int, n_layers:int, dropout_rate:float=0.2):
        super().__init__()
        self.add_module('bottleneck', self._layer(
            in_channels, growth_rate, n_layers,
            upsample=True, dropout_rate=dropout_rate))


class Bottleneck2d(Bottleneck):
    _layer = DenseBlock2d


class Bottleneck3d(Bottleneck):
    _layer = DenseBlock3d

class Tiramisu(nn.Module):
    _bottleneck = None
    _conv = None
    _denseblock = None
    _pad = None
    _trans_down = None
    _trans_up = None

    def __init__(self,
                 in_channels:int=3,
                 out_channels:int=1,
                 down_blocks:List[int]=(5,5,5,5,5),
                 up_blocks:List[int]=(5,5,5,5,5),
                 bottleneck_layers:int=5,
                 growth_rate:int=16,
                 out_chans_first_conv:int=48,
                 dropout_rate:float=0.2):
        super().__init__()
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        first_kernel_size = 3
        final_kernel_size = 1
        skip_connection_channel_counts = []

        self.firstConv = nn.Sequential(
            self._pad(first_kernel_size // 2),
            self._conv(in_channels, out_chans_first_conv,
                       first_kernel_size, bias=False))
        cur_channels_count = out_chans_first_conv

        ## Downsampling path ##
        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for n_layers in down_blocks:
            self.denseBlocksDown.append(self._denseblock(
                cur_channels_count, growth_rate, n_layers,
                upsample=False, dropout_rate=dropout_rate))
            cur_channels_count += (growth_rate * n_layers)
            skip_connection_channel_counts.insert(0, cur_channels_count)
            self.transDownBlocks.append(self._trans_down(
                cur_channels_count, cur_channels_count,
                dropout_rate=dropout_rate))

        self.bottleneck = self._bottleneck(
            cur_channels_count, growth_rate, bottleneck_layers,
            dropout_rate=dropout_rate)
        prev_block_channels = growth_rate * bottleneck_layers
        cur_channels_count += prev_block_channels

        ## Upsampling path ##
        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        up_info = zip(up_blocks, skip_connection_channel_counts)
        for i, (n_layers, sccc) in enumerate(up_info, 1):
            self.transUpBlocks.append(self._trans_up(
                prev_block_channels, prev_block_channels))
            cur_channels_count = prev_block_channels + sccc
            upsample = i < len(up_blocks)  # do not upsample on the last block
            self.denseBlocksUp.append(self._denseblock(
                cur_channels_count, growth_rate, n_layers,
                upsample=upsample, dropout_rate=dropout_rate))
            prev_block_channels = growth_rate * n_layers
            cur_channels_count += prev_block_channels

        self.finalConv = self._conv(cur_channels_count, out_channels,
                                    final_kernel_size, bias=True)

    def forward(self, x:Tensor) -> Tensor:
        out = self.firstConv(x)
        skip_connections = []
        for dbd, tdb in zip(self.denseBlocksDown, self.transDownBlocks):
            out = dbd(out)
            skip_connections.append(out)
            out = tdb(out)
        out = self.bottleneck(out)
        for ubd, tub in zip(self.denseBlocksUp, self.transUpBlocks):
            skip = skip_connections.pop()
            out = tub(out, skip)
            out = ubd(out)
        out = self.finalConv(out)
        return out

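# forward returns the raw finalConv output (no sigmoid/softmax is applied),
# so the usual pairing is a loss that expects logits, e.g.
# nn.BCEWithLogitsLoss or nn.CrossEntropyLoss.
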
class Tiramisu2d(Tiramisu):
    _bottleneck = Bottleneck2d
    _conv = nn.Conv2d
    _denseblock = DenseBlock2d
    _pad = nn.ReplicationPad2d
    _trans_down = TransitionDown2d
    _trans_up = TransitionUp2d


class Tiramisu3d(Tiramisu):
    _bottleneck = Bottleneck3d
    _conv = nn.Conv3d
    _denseblock = DenseBlock3d
    _pad = nn.ReplicationPad3d
    _trans_down = TransitionDown3d
    _trans_up = TransitionUp3d

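# A sketch of how the FC-DenseNet103 configuration from [1] might be set up
# (the block sizes below come from the paper, not from this file; n_classes
# is a placeholder for the desired number of output channels):
#
#   net = Tiramisu2d(in_channels=3, out_channels=n_classes,
#                    down_blocks=(4, 5, 7, 10, 12),
#                    up_blocks=(12, 10, 7, 5, 4),
#                    bottleneck_layers=15,
#                    growth_rate=16,
#                    out_chans_first_conv=48)
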
if __name__ == "__main__":
    net_kwargs = dict(in_channels=1, out_channels=1,
                      down_blocks=[2, 2], up_blocks=[2, 2],
                      bottleneck_layers=2)
    x = torch.randn(1, 1, 32, 32)
    net2d = Tiramisu2d(**net_kwargs)
    y = net2d(x)
    assert y.shape == x.shape
    x = torch.randn(1, 1, 32, 32, 32)
    net3d = Tiramisu3d(**net_kwargs)
    y = net3d(x)
    assert y.shape == x.shape

Tiramisu model with minor changes in 2D and 3D. See docstring at the top of the file for more details.