Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
DeepLabv3+ in PyTorch
"""
DeepLabv3+ model (https://arxiv.org/abs/1802.02611)
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet101
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import ASPP
class ConvLayer(nn.Sequential):
    """3x3 convolution, then batch normalisation, then an in-place ReLU.

    The convolution uses stride 1 with padding 1, so the spatial size of the
    input is preserved. Its bias is disabled; the BatchNorm2d that follows
    carries its own learnable shift. Sub-modules are registered under the
    names ``conv``, ``norm`` and ``relu`` (these names appear in state_dict
    keys).
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # (name, module) pairs in execution order; registering them in this
        # order is what makes nn.Sequential run them in this order.
        stages = (
            ('conv', nn.Conv2d(in_channels, out_channels, 3,
                               padding=1, bias=False)),
            ('norm', nn.BatchNorm2d(out_channels)),
            ('relu', nn.ReLU(inplace=True)),
        )
        for label, module in stages:
            self.add_module(label, module)
class DeepLabv3Plus(nn.Module):
    """DeepLabv3+-style encoder/decoder assembled from torchvision parts.

    A ResNet-101 backbone (layer3/layer4 dilated instead of strided, i.e.
    output stride 8) feeds an ASPP head. The ASPP output is upsampled and
    fused with channel-reduced ``layer1`` features, refined by two 3x3
    ConvLayers, then upsampled to the input resolution and fused with the
    raw input before a final 3x3 ConvLayer and a 1x1 classifier.

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of channels in the output map (no activation is
            applied, so these are raw scores/logits).
        pretrained_backbone: load torchvision's pretrained ResNet-101 weights.
            If ``in_channels != 3`` the first (stem) convolution is replaced
            by a freshly initialised one, so its pretrained weights are not
            used.
        atrous_rates: dilation rates for the three dilated ASPP branches.
            The default ``(6, 12, 18)`` keeps the original behaviour. The
            DeepLabv3 paper pairs these rates with output stride 16 and
            doubles them to ``(12, 24, 36)`` at output stride 8, which is
            what this backbone uses.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 pretrained_backbone: bool = True,
                 *, atrous_rates=(6, 12, 18)):
        super().__init__()
        backbone = resnet101(
            pretrained=pretrained_backbone,
            replace_stride_with_dilation=[False, True, True])
        # The ResNet stem is built for 3-channel input. Swap it only when the
        # channel count differs, so RGB models keep the pretrained stem
        # instead of having it silently re-initialised.
        if in_channels != 3:
            backbone.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7,
                stride=2, padding=3, bias=False)
        return_layers = {
            'layer1': 'low_level_features',
            'layer4': 'high_level_features'}
        self.backbone = IntermediateLayerGetter(
            backbone, return_layers=return_layers)
        inplanes = 2048  # channels produced by ResNet-101 layer4
        self.aspp = ASPP(inplanes, list(atrous_rates))
        # replace standard dropout with spatial (whole-channel) dropout
        self.aspp.project[3] = nn.Dropout2d(0.2)
        # NOTE(review): torchvision's ASPP pooling branch applies BatchNorm to
        # a 1x1 map, so training with batch size 1 is expected to fail there.
        # reduce the 256-channel layer1 features to 48 channels before fusion
        self.low_level_feat_conv = nn.Conv2d(256, 48, 1)
        # fuse upsampled ASPP output (256 ch) with reduced low-level features (48 ch)
        self.up_conv = nn.Sequential(
            ConvLayer(256 + 48, 256),
            ConvLayer(256, 256))
        # fuse decoder output (256 ch) with the raw input, then classify
        self.final_conv = nn.Sequential(
            ConvLayer(256 + in_channels, 256),
            nn.Conv2d(256, out_channels, 1))

    @staticmethod
    def interp_cat(x, skip):
        """Resize ``x`` bilinearly to ``skip``'s spatial size, then concat on dim 1.

        Returns a tensor with ``x.shape[1] + skip.shape[1]`` channels (``x``'s
        channels first) and ``skip``'s spatial size.
        """
        x = F.interpolate(x, skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: input batch with ``in_channels`` channels, assumed (N, C, H, W).

        Returns:
            Raw output scores with ``out_channels`` channels and the same
            spatial size as ``x``. The last fusion resizes to ``x``'s size and
            the final convolutions preserve it.
        """
        features = self.backbone(x)
        llf = self.low_level_feat_conv(features['low_level_features'])
        hlf = self.aspp(features['high_level_features'])
        out = self.up_conv(self.interp_cat(hlf, llf))
        out = self.interp_cat(out, x)
        return self.final_conv(out)
if __name__ == "__main__":
    # Smoke test: a randomly initialised single-channel model should map a
    # 2x1x128x128 batch to an output of exactly the same shape.
    net = DeepLabv3Plus(1, 1, pretrained_backbone=False)
    print(net)
    batch = torch.randn(2, 1, 128, 128)
    prediction = net(batch)
    assert batch.shape == prediction.shape
@jcreinhold
Copy link
Author

jcreinhold commented Jul 11, 2020

Not exactly the DeepLabv3+ model as described in the paper, but pretty close. I tried to maximize the use of layers from the torchvision package, since it implements the DeepLabv3 model. An important change is that the input image is concatenated with the decoder features right before the final convolutional layer. This implementation also uses normal convolutions instead of the depthwise-separable convolutions used in the paper. Not tested extensively.

Loading

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment