@BloodAxe
Created January 14, 2021 14:02
from collections import OrderedDict
from functools import partial
from typing import Union, List, Dict, Tuple, Type
from pytorch_toolbelt.modules import (
    conv1x1,
    UnetBlock,
    ACT_RELU,
    ABN,
    ACT_SWISH,
    ResidualDeconvolutionUpsample2d,
    DeconvolutionUpsample2d,
)
from pytorch_toolbelt.modules import encoders as E
from pytorch_toolbelt.modules import decoders as D
from torch import nn, Tensor
from torch.nn import functional as F
from ..dataset import OUTPUT_MASK_KEY, name_for_stride
from catalyst.registry import Model
class UnetSegmentationModel(nn.Module):
    def __init__(
        self,
        encoder: E.EncoderModule,
        unet_channels: Union[int, List[int]],
        num_classes: int = 1,
        dropout=0.25,
        activation=ACT_RELU,
        upsample_block: Union[
            Type[nn.UpsamplingBilinear2d], Type[nn.UpsamplingNearest2d], Type[ResidualDeconvolutionUpsample2d]
        ] = nn.UpsamplingNearest2d,
        need_supervision_masks=False,
        last_upsample_block=None,
    ):
        super().__init__()
        self.encoder = encoder

        abn_block = partial(ABN, activation=activation)
        self.decoder = D.UNetDecoder(
            feature_maps=encoder.channels,
            decoder_features=unet_channels,
            unet_block=partial(UnetBlock, abn_block=abn_block),
            upsample_block=upsample_block,
        )

        if last_upsample_block is not None:
            # Extra learned upsampling stage before the segmentation head
            self.last_upsample_block = last_upsample_block(unet_channels[0])
            self.mask = nn.Sequential(
                OrderedDict(
                    [
                        ("drop", nn.Dropout2d(dropout)),
                        (
                            "conv",
                            nn.Conv2d(self.last_upsample_block.out_channels, num_classes, kernel_size=3, padding=1),
                        ),
                    ]
                )
            )
        else:
            self.last_upsample_block = None
            self.mask = nn.Sequential(
                OrderedDict(
                    [
                        ("drop", nn.Dropout2d(dropout)),
                        ("conv", nn.Conv2d(unet_channels[0], num_classes, kernel_size=3, padding=1)),
                    ]
                )
            )

        if need_supervision_masks:
            # One 1x1 projection per decoder stage produces deep-supervision masks
            num_blocks = len(self.decoder.channels)
            self.supervision = nn.ModuleList([conv1x1(channels, num_classes) for channels in self.decoder.channels])
            self.supervision_names = [
                name_for_stride(OUTPUT_MASK_KEY, stride) for stride in self.encoder.strides[:num_blocks]
            ]
        else:
            self.supervision = None
            self.supervision_names = None

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        image_size = x.size()
        x = self.encoder(x)
        x = self.decoder(x)

        # Decode mask from the highest-resolution decoder feature map
        if self.last_upsample_block is not None:
            mask = self.mask(self.last_upsample_block(x[0]))
        else:
            mask = self.mask(x[0])

        if mask.size()[2:] != image_size[2:]:
            mask = F.interpolate(mask, size=image_size[2:], mode="bilinear", align_corners=False)

        output = {OUTPUT_MASK_KEY: mask}

        if self.supervision is not None:
            for feature_map, supervision, name in zip(x, self.supervision, self.supervision_names):
                output[name] = supervision(feature_map)

        return output
@Model
def b6_unet32_s2_rdtc(input_channels=3, num_classes=1, dropout=0.2, need_supervision_masks=False, pretrained=True):
    encoder = E.B6Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4])
    if input_channels != 3:
        encoder.change_input_channels(input_channels)

    return UnetSegmentationModel(
        encoder,
        num_classes=num_classes,
        unet_channels=[32, 64, 128, 256],
        activation=ACT_SWISH,
        dropout=dropout,
        need_supervision_masks=need_supervision_masks,
        upsample_block=ResidualDeconvolutionUpsample2d,
        last_upsample_block=ResidualDeconvolutionUpsample2d,
    )
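
Below is a minimal usage sketch, not part of the original gist. It assumes the file sits inside its original project so the relative `..dataset` import (OUTPUT_MASK_KEY, name_for_stride) resolves, and that `pytorch_toolbelt` and `catalyst` are installed; the batch size and input resolution are illustrative only.

import torch

# Build the EfficientNet-B6 U-Net with residual deconvolution upsampling;
# pretrained=False skips downloading encoder weights for a quick smoke test.
model = b6_unet32_s2_rdtc(num_classes=1, pretrained=False, need_supervision_masks=True)
model.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 512, 512)  # any spatial size divisible by 32 should work
    outputs = model(dummy)

# The output dict holds the full-resolution mask under OUTPUT_MASK_KEY, plus one
# lower-resolution mask per decoder stride when need_supervision_masks=True.
print(outputs[OUTPUT_MASK_KEY].shape)  # expected: (1, 1, 512, 512)
print(sorted(outputs.keys()))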