Skip to content

Instantly share code, notes, and snippets.

@jcreinhold
Last active Jul 29, 2020
Embed
What would you like to do?
Tiramisu 2D/3D in PyTorch
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
PyTorch implementation of the Tiramisu network architecture [1] in 2D and 3D.
The (2D) implementation is based on [2].
Changes from [2] include:
1) removal of bias from conv layers,
2) change zero padding to replication padding,
3) use of GELU for default activation,
4) cosmetic changes for brevity, clarity, consistency
References:
[1] Jégou, Simon, et al. "The one hundred layers tiramisu:
Fully convolutional densenets for semantic segmentation."
CVPR. 2017.
[2] https://github.com/bfortuner/pytorch_tiramisu
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""
# Public API of this module: the concrete 2D and 3D networks.
__all__ = ['Tiramisu2d', 'Tiramisu3d']
from typing import *
from functools import partial
import torch
from torch import Tensor
from torch import nn
# Activation class used by ConvLayer; it is instantiated with no arguments.
ACTIVATION = nn.GELU


class ConvLayer(nn.Sequential):
    """Sequential stack: norm -> activation -> [pad] -> conv -> [dropout] -> [maxpool].

    Subclasses select the concrete building blocks through the class
    attributes below. The bracketed stages are only registered when the
    corresponding ``_use_*`` predicate returns True.

    Args:
        in_channels: channel count given to the norm and the convolution.
        growth_rate: number of output channels of the convolution.
        dropout_rate: value passed to ``_dropout``; dropout is skipped
            unless it is greater than zero.
    """

    # Factories filled in by subclasses.
    _conv = None
    _dropout = None
    _kernel_size = None
    _maxpool = None
    _norm = None
    _pad = None

    def __init__(self, in_channels: int, growth_rate: int, dropout_rate: float = 0.2):
        super().__init__()
        self.dropout_rate = dropout_rate

        # Collect the (name, module) pairs first, then register them in order.
        stages = [
            ('norm', self._norm(in_channels)),
            ('act', ACTIVATION()),
        ]
        if self._use_padding():
            pad_width = self._kernel_size // 2
            stages.append(('pad', self._pad(pad_width)))
        convolution = self._conv(
            in_channels, growth_rate, self._kernel_size, bias=False)
        stages.append(('conv', convolution))
        if self._use_dropout():
            stages.append(('drop', self._dropout(dropout_rate)))
        if self._use_maxpool():
            stages.append(('maxpool', self._maxpool(2)))

        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    def _use_dropout(self) -> bool:
        """Return True when the dropout rate is strictly positive."""
        return 0. < self.dropout_rate

    def _use_padding(self) -> bool:
        """Return True when the kernel size is larger than 2."""
        return 2 < self._kernel_size

    def _use_maxpool(self) -> bool:
        """Return True when a max-pooling factory has been configured."""
        return not (self._maxpool is None)
class ConvLayer2d(ConvLayer):
    """ConvLayer built from 2D modules, with kernel size 3 and no max pooling."""

    # Layer geometry.
    _kernel_size = 3
    _pad = nn.ReplicationPad2d
    _maxpool = None

    # Module factories.
    _norm = nn.BatchNorm2d
    _conv = nn.Conv2d
    _dropout = partial(nn.Dropout2d, inplace=True)
class ConvLayer3d(ConvLayer):
    """ConvLayer built from 3D modules, with kernel size 3 and no max pooling."""

    # Layer geometry.
    _kernel_size = 3
    _pad = nn.ReplicationPad3d
    _maxpool = None

    # Module factories.
    _norm = nn.BatchNorm3d
    _conv = nn.Conv3d
    _dropout = partial(nn.Dropout3d, inplace=True)
class DenseBlock(nn.Module):
    """Chain of layers where each layer sees the input plus all earlier outputs.

    Every layer is built by ``_layer`` (set by subclasses) and its output is
    concatenated onto the running feature tensor along the channel axis (1).

    Args:
        in_channels: channel count of the block input.
        growth_rate: channels each layer is built to produce.
        n_layers: number of layers in the block.
        upsample: when True, ``forward`` returns only the concatenated layer
            outputs; when False it returns the input together with them.
        dropout_rate: forwarded to every layer.
    """

    _layer = None

    def __init__(self, in_channels: int, growth_rate: int, n_layers: int,
                 upsample: bool = False, dropout_rate: float = 0.2):
        super().__init__()
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.n_layers = n_layers
        self.upsample = upsample
        self.dropout_rate = dropout_rate

        built = []
        for n_in in self.in_channels_range:
            built.append(self._layer(n_in, self.growth_rate, self.dropout_rate))
        self.layers = nn.ModuleList(built)

    def forward(self, x: Tensor) -> Tensor:
        """Run all layers; see the class docstring for what is returned."""
        features = x
        produced = []
        for layer in self.layers:
            fresh = layer(features)
            produced.append(fresh)
            # All concatenation happens on the channel axis (i.e., 1).
            features = torch.cat((features, fresh), 1)
        if self.upsample:
            # Only the newly produced feature maps are passed on.
            return torch.cat(produced, 1)
        return features

    @property
    def in_channels_range(self) -> List[int]:
        """Input channel count of each layer, in order."""
        counts = []
        for index in range(self.n_layers):
            counts.append(self.in_channels + index * self.growth_rate)
        return counts
class DenseBlock2d(DenseBlock):
    """DenseBlock whose layers are ``ConvLayer2d`` instances."""

    _layer = ConvLayer2d
class DenseBlock3d(DenseBlock):
    """DenseBlock whose layers are ``ConvLayer3d`` instances."""

    _layer = ConvLayer3d
class TransitionDown2d(ConvLayer):
    """2D down-sampling ConvLayer: kernel size 1 with ``nn.MaxPool2d`` enabled."""

    # Layer geometry.
    _kernel_size = 1
    _pad = nn.ReplicationPad2d
    _maxpool = nn.MaxPool2d

    # Module factories.
    _norm = nn.BatchNorm2d
    _conv = nn.Conv2d
    _dropout = partial(nn.Dropout2d, inplace=True)
class TransitionDown3d(ConvLayer):
    """3D down-sampling ConvLayer: kernel size 1 with ``nn.MaxPool3d`` enabled."""

    # Layer geometry.
    _kernel_size = 1
    _pad = nn.ReplicationPad3d
    _maxpool = nn.MaxPool3d

    # Module factories.
    _norm = nn.BatchNorm3d
    _conv = nn.Conv3d
    _dropout = partial(nn.Dropout3d, inplace=True)
class TransitionUp(nn.Module):
    """Transposed-convolution up-sampling followed by a skip concatenation.

    Subclasses provide the transposed-convolution class via ``_conv_trans``
    and implement ``_crop_to_y``.

    Args:
        in_channels: channel count of the tensor to up-sample.
        out_channels: channel count produced by the transposed convolution.
    """

    _conv_trans = None

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        kernel_size = 3
        self.convTrans = self._conv_trans(
            in_channels, out_channels, kernel_size,
            stride=2, bias=False)

    def forward(self, x: Tensor, skip: Tensor) -> Tensor:
        """Up-sample ``x``, crop it to ``skip`` and concatenate on axis 1.

        Args:
            x: tensor to up-sample.
            skip: skip-connection tensor that the result is cropped to and
                concatenated with.

        Returns:
            The cropped up-sampled tensor followed by ``skip`` along the
            channel axis.
        """
        out = self.convTrans(x)
        out = self._crop_to_y(out, skip)
        return torch.cat([out, skip], 1)

    @staticmethod
    def _crop_to_y(x: Tensor, y: Tensor) -> Tensor:
        """Crop ``x`` to the size of ``y``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('subclasses must implement _crop_to_y')
class TransitionUp2d(TransitionUp):
    """TransitionUp using ``nn.ConvTranspose2d`` and a centred 2D crop."""

    _conv_trans = nn.ConvTranspose2d

    @staticmethod
    def _crop_to_y(x: Tensor, y: Tensor) -> Tensor:
        """Return the centred window of ``x`` whose last two dims match ``y``.

        Both tensors must have exactly four dimensions.
        """
        _, _, target_rows, target_cols = y.shape
        _, _, src_rows, src_cols = x.size()
        # Half of the surplus (floored) is skipped at the start of each axis.
        row_start = (src_rows - target_rows) // 2
        col_start = (src_cols - target_cols) // 2
        rows = slice(row_start, row_start + target_rows)
        cols = slice(col_start, col_start + target_cols)
        return x[:, :, rows, cols]
class TransitionUp3d(TransitionUp):
    """TransitionUp using ``nn.ConvTranspose3d`` and a centred 3D crop."""

    _conv_trans = nn.ConvTranspose3d

    @staticmethod
    def _crop_to_y(x: Tensor, y: Tensor) -> Tensor:
        """Return the centred window of ``x`` whose last three dims match ``y``.

        Both tensors must have exactly five dimensions.
        """
        _, _, target_a, target_b, target_c = y.shape
        _, _, src_a, src_b, src_c = x.size()
        # Half of the surplus (floored) is skipped at the start of each axis.
        start_a = (src_a - target_a) // 2
        start_b = (src_b - target_b) // 2
        start_c = (src_c - target_c) // 2
        window_a = slice(start_a, start_a + target_a)
        window_b = slice(start_b, start_b + target_b)
        window_c = slice(start_c, start_c + target_c)
        return x[:, :, window_a, window_b, window_c]
class Bottleneck(nn.Sequential):
    """Sequential wrapper holding one ``_layer`` block built with ``upsample=True``.

    Args:
        in_channels: forwarded to ``_layer``.
        growth_rate: forwarded to ``_layer``.
        n_layers: forwarded to ``_layer``.
        dropout_rate: forwarded to ``_layer`` as a keyword argument.
    """

    _layer = None

    def __init__(self, in_channels: int, growth_rate: int, n_layers: int,
                 dropout_rate: float = 0.2):
        super().__init__()
        block = self._layer(
            in_channels, growth_rate, n_layers,
            upsample=True, dropout_rate=dropout_rate)
        # Registered under the name 'bottleneck' (also the state-dict prefix).
        self.add_module('bottleneck', block)
class Bottleneck2d(Bottleneck):
    """Bottleneck wrapping a ``DenseBlock2d``."""

    _layer = DenseBlock2d
class Bottleneck3d(Bottleneck):
    """Bottleneck wrapping a ``DenseBlock3d``."""

    _layer = DenseBlock3d
class Tiramisu(nn.Module):
    """Dimension-agnostic Tiramisu network; subclasses supply the building blocks.

    The network is: first convolution -> (dense block, transition down) per
    entry of ``down_blocks`` -> bottleneck -> (transition up, dense block)
    per entry of ``up_blocks`` -> final convolution with kernel size 1.

    Args:
        in_channels: channel count of the network input.
        out_channels: channel count of the network output.
        down_blocks: number of layers in each down-path dense block.
        up_blocks: number of layers in each up-path dense block; must not
            have more entries than ``down_blocks``.
        bottleneck_layers: number of layers in the bottleneck.
        growth_rate: channels added by every dense layer.
        out_chans_first_conv: output channels of the first convolution.
        dropout_rate: forwarded to the dense, transition-down and
            bottleneck blocks.

    Raises:
        ValueError: if ``up_blocks`` is longer than ``down_blocks``.
    """

    # Building blocks filled in by subclasses.
    _bottleneck = None
    _conv = None
    _denseblock = None
    _pad = None
    _trans_down = None
    _trans_up = None

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 1,
                 down_blocks: Sequence[int] = (5, 5, 5, 5, 5),
                 up_blocks: Sequence[int] = (5, 5, 5, 5, 5),
                 bottleneck_layers: int = 5,
                 growth_rate: int = 16,
                 out_chans_first_conv: int = 48,
                 dropout_rate: float = 0.2):
        super().__init__()
        # Each up block consumes one skip connection. With more up blocks
        # than down blocks the extra entries would be dropped silently and
        # the final convolution sized for channels the last dense block
        # never emits, so fail early with a clear message instead.
        if len(up_blocks) > len(down_blocks):
            raise ValueError(
                'up_blocks has {} entries but only {} down_blocks provide '
                'skip connections'.format(len(up_blocks), len(down_blocks)))
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        first_kernel_size = 3
        final_kernel_size = 1
        skip_connection_channel_counts = []

        self.firstConv = nn.Sequential(
            self._pad(first_kernel_size // 2),
            self._conv(in_channels, out_chans_first_conv,
                       first_kernel_size, bias=False))
        cur_channels_count = out_chans_first_conv

        ## Downsampling path ##
        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for n_layers in down_blocks:
            self.denseBlocksDown.append(self._denseblock(
                cur_channels_count, growth_rate, n_layers,
                upsample=False, dropout_rate=dropout_rate))
            cur_channels_count += (growth_rate * n_layers)
            # Most recent count goes first, matching the order in which the
            # up path consumes the skip connections.
            skip_connection_channel_counts.insert(0, cur_channels_count)
            self.transDownBlocks.append(self._trans_down(
                cur_channels_count, cur_channels_count,
                dropout_rate=dropout_rate))

        self.bottleneck = self._bottleneck(
            cur_channels_count, growth_rate, bottleneck_layers,
            dropout_rate=dropout_rate)
        prev_block_channels = growth_rate * bottleneck_layers
        cur_channels_count += prev_block_channels

        ## Upsampling path ##
        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        up_info = zip(up_blocks, skip_connection_channel_counts)
        for i, (n_layers, sccc) in enumerate(up_info, 1):
            self.transUpBlocks.append(self._trans_up(
                prev_block_channels, prev_block_channels))
            cur_channels_count = prev_block_channels + sccc
            upsample = i < len(up_blocks)  # do not upsample on last block
            self.denseBlocksUp.append(self._denseblock(
                cur_channels_count, growth_rate, n_layers,
                upsample=upsample, dropout_rate=dropout_rate))
            prev_block_channels = growth_rate * n_layers
            cur_channels_count += prev_block_channels

        self.finalConv = self._conv(cur_channels_count, out_channels,
                                    final_kernel_size, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Run the full network on ``x`` and return the final convolution output."""
        out = self.firstConv(x)
        skip_connections = []
        for dbd, tdb in zip(self.denseBlocksDown, self.transDownBlocks):
            out = dbd(out)
            skip_connections.append(out)
            out = tdb(out)
        out = self.bottleneck(out)
        for ubd, tub in zip(self.denseBlocksUp, self.transUpBlocks):
            # Skips are consumed last-in, first-out.
            skip = skip_connections.pop()
            out = tub(out, skip)
            out = ubd(out)
        out = self.finalConv(out)
        return out
class Tiramisu2d(Tiramisu):
    """Tiramisu assembled from the 2D building blocks."""

    # Plain torch modules.
    _conv = nn.Conv2d
    _pad = nn.ReplicationPad2d

    # Blocks defined in this module.
    _denseblock = DenseBlock2d
    _bottleneck = Bottleneck2d
    _trans_down = TransitionDown2d
    _trans_up = TransitionUp2d
class Tiramisu3d(Tiramisu):
    """Tiramisu assembled from the 3D building blocks."""

    # Plain torch modules.
    _conv = nn.Conv3d
    _pad = nn.ReplicationPad3d

    # Blocks defined in this module.
    _denseblock = DenseBlock3d
    _bottleneck = Bottleneck3d
    _trans_down = TransitionDown3d
    _trans_up = TransitionUp3d
if __name__ == "__main__":
    # Smoke test: a small network of each dimensionality must return an
    # output with the same shape as its input.
    net_kwargs = {
        'in_channels': 1,
        'out_channels': 1,
        'down_blocks': [2, 2],
        'up_blocks': [2, 2],
        'bottleneck_layers': 2,
    }
    cases = (
        (Tiramisu2d, (1, 1, 32, 32)),
        (Tiramisu3d, (1, 1, 32, 32, 32)),
    )
    for net_cls, shape in cases:
        x = torch.randn(*shape)
        net = net_cls(**net_kwargs)
        y = net(x)
        assert y.shape == x.shape
@jcreinhold

This comment has been minimized.

Copy link
Owner Author

@jcreinhold jcreinhold commented Jul 11, 2020

Tiramisu model with minor changes in 2D and 3D. See docstring at the top of the file for more details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment