Created
July 11, 2020 14:03
-
-
Save jcreinhold/faff34dc075bdbfe7d85bf2b16249516 to your computer and use it in GitHub Desktop.
DeepLabv3+ in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
DeepLabv3+ model (https://arxiv.org/abs/1802.02611) | |
Author: Jacob Reinhold (jacob.reinhold@jhu.edu) | |
""" | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision.models import resnet101 | |
from torchvision.models._utils import IntermediateLayerGetter | |
from torchvision.models.segmentation.deeplabv3 import ASPP | |
class ConvLayer(nn.Sequential):
    """3x3 convolution, then batch normalization, then ReLU.

    The convolution uses ``padding=1`` (spatial size is preserved) and has
    no bias term. The sub-modules are registered under the names ``conv``,
    ``norm`` and ``relu``, in that order.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Assigning a module as an attribute registers it under that name;
        # nn.Sequential then runs the registered modules in insertion order.
        self.conv = nn.Conv2d(
            in_channels, out_channels, 3, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
class DeepLabv3Plus(nn.Module):
    """DeepLabv3+-style segmentation network on a dilated ResNet-101.

    Data flow (see ``forward``):

    1. A torchvision ResNet-101 yields ``layer1`` (low-level) and
       ``layer4`` (high-level) feature maps.
    2. The high-level features pass through torchvision's ASPP module
       (atrous rates 6, 12, 18) whose dropout is swapped for spatial
       dropout with p=0.2.
    3. The low-level features are reduced to 48 channels with a 1x1
       convolution, concatenated with the upsampled ASPP output and refined
       by two ``ConvLayer`` blocks.
    4. The result is upsampled to the input resolution, concatenated with
       the raw input tensor and mapped to ``out_channels`` by a
       ``ConvLayer`` followed by a 1x1 convolution.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the returned tensor.
        pretrained_backbone: forwarded to ``resnet101(pretrained=...)``.
            NOTE(review): newer torchvision releases prefer ``weights=``
            over ``pretrained=`` -- confirm against the installed version.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 pretrained_backbone: bool = True):
        super().__init__()
        backbone = resnet101(
            pretrained=pretrained_backbone,
            replace_stride_with_dilation=[False, True, True])
        # Replace the first conv only when the input channel count differs
        # from what the stock stem expects; otherwise keep it, so that
        # pretrained stem weights are not thrown away for matching inputs.
        if in_channels != backbone.conv1.in_channels:
            backbone.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7,
                stride=2, padding=3, bias=False)
        # Map backbone stage names to the keys read back in forward().
        return_layers = {
            'layer1': 'low_level_features',
            'layer4': 'high_level_features'}
        self.backbone = IntermediateLayerGetter(
            backbone, return_layers=return_layers)
        # Channel count fed to ASPP (output of the backbone's layer4).
        inplanes = 2048
        self.aspp = ASPP(inplanes, [6, 12, 18])
        # replace standard dropout with spatial dropout
        # NOTE(review): relies on index 3 of ASPP.project being its dropout
        # layer -- verify against the installed torchvision version.
        self.aspp.project[3] = nn.Dropout2d(0.2)
        # 1x1 conv shrinking the 256-channel low-level features to 48.
        self.low_level_feat_conv = nn.Conv2d(256, 48, 1)
        # Input: ASPP output (256 channels) + reduced low-level features (48).
        self.up_conv = nn.Sequential(
            ConvLayer(256 + 48, 256),
            ConvLayer(256, 256))
        # Input: decoder features (256 channels) + the raw network input.
        self.final_conv = nn.Sequential(
            ConvLayer(256 + in_channels, 256),
            nn.Conv2d(256, out_channels, 1))

    @staticmethod
    def interp_cat(x, skip):
        """Resize ``x`` to the spatial size of ``skip`` and concatenate.

        ``x`` is bilinearly interpolated (``align_corners=True``) to
        ``skip.shape[2:]`` and the two tensors are concatenated along
        dimension 1, with the resized ``x`` first.
        """
        x = F.interpolate(
            x, skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Run the network on ``x`` and return the final conv output.

        The returned tensor has ``out_channels`` channels and the same
        spatial size as ``x``; ``x`` must carry ``in_channels`` channels
        because it is concatenated into the last decoder stage.
        """
        out = self.backbone(x)
        llf = out.pop('low_level_features')
        llf = self.low_level_feat_conv(llf)
        hlf = out.pop('high_level_features')
        hlf = self.aspp(hlf)
        # Decoder stage 1: fuse ASPP output with the low-level features.
        out = self.interp_cat(hlf, llf)
        out = self.up_conv(out)
        # Decoder stage 2: fuse with the raw input at full resolution.
        out = self.interp_cat(out, x)
        out = self.final_conv(out)
        return out
if __name__ == "__main__":
    # Smoke test: build a 1-channel-in / 1-channel-out model with a randomly
    # initialised backbone, run one random batch through it and check that
    # the output has the same shape as the input.
    net = DeepLabv3Plus(1, 1, pretrained_backbone=False)
    print(net)
    inputs = torch.randn(2, 1, 128, 128)
    outputs = net(inputs)
    assert inputs.shape == outputs.shape
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is not exactly the DeepLabv3+ model as described in the paper, but it is pretty close. I tried to maximize the use of layers from the torchvision package, since it already implements the DeepLabv3 model. An important change is that the network input is concatenated with the upsampled decoder features just before the final convolutional layers. This implementation also uses normal convolutions instead of separable convolutions. It has not been tested extensively.