Skip to content

Instantly share code, notes, and snippets.

@jcreinhold
Created July 11, 2020 14:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jcreinhold/faff34dc075bdbfe7d85bf2b16249516 to your computer and use it in GitHub Desktop.
Save jcreinhold/faff34dc075bdbfe7d85bf2b16249516 to your computer and use it in GitHub Desktop.
DeepLabv3+ in PyTorch
"""
DeepLabv3+ model (https://arxiv.org/abs/1802.02611)
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet101
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import ASPP
class ConvLayer(nn.Sequential):
    """3x3 convolution -> batch norm -> ReLU, bundled as one sequential unit.

    The convolution uses ``padding=1`` with the default stride, so the
    spatial size of the input is preserved, and it carries no bias term.
    The three submodules are registered as ``conv``, ``norm`` and ``relu``.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Assigning a module to an attribute registers it under that name;
        # nn.Sequential then applies the children in registration order.
        self.conv = nn.Conv2d(
            in_channels, out_channels, 3, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
class DeepLabv3Plus(nn.Module):
    """DeepLabv3+-style segmentation network on a dilated ResNet-101.

    Pipeline (see ``forward``):

    1. A ResNet-101 backbone yields ``layer1`` (low-level) and ``layer4``
       (high-level) feature maps.
    2. The high-level features go through torchvision's ASPP head; the
       low-level features are reduced to 48 channels with a 1x1 conv.
    3. The ASPP output is bilinearly resized to the low-level resolution,
       concatenated with it and refined by two 3x3 conv blocks.
    4. The result is resized to the input resolution, concatenated with
       the raw input and mapped to ``out_channels`` by a final conv block.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map.
        pretrained_backbone: forwarded to ``resnet101`` to request
            pretrained backbone weights.

    Shapes:
        input ``(N, in_channels, H, W)`` -> output ``(N, out_channels, H, W)``.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 pretrained_backbone: bool = True):
        super().__init__()
        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; the original keyword is kept so the
        # module still works with the torchvision version it was written for.
        backbone = resnet101(
            pretrained=pretrained_backbone,
            replace_stride_with_dilation=[False, True, True])

        # Swap the stem only when the channel count really differs.
        # Replacing it unconditionally would throw away the pretrained stem
        # weights for ordinary 3-channel input.
        stem = backbone.conv1
        if in_channels != stem.in_channels:
            backbone.conv1 = nn.Conv2d(
                in_channels, stem.out_channels,
                kernel_size=stem.kernel_size, stride=stem.stride,
                padding=stem.padding, bias=False)

        return_layers = {
            'layer1': 'low_level_features',
            'layer4': 'high_level_features'}
        self.backbone = IntermediateLayerGetter(
            backbone, return_layers=return_layers)

        high_level_channels = 2048  # channels fed to ASPP (from layer4)
        low_level_channels = 256    # channels fed to the 1x1 reduction (layer1)
        reduced_channels = 48       # low-level channels after the 1x1 conv
        decoder_channels = 256      # ASPP output width (torchvision default)

        self.aspp = ASPP(high_level_channels, [6, 12, 18])
        # replace standard dropout with spatial dropout
        # (index 3 is assumed to be the Dropout layer of torchvision's
        # ``ASPP.project`` -- verify if the torchvision version changes)
        self.aspp.project[3] = nn.Dropout2d(0.2)

        self.low_level_feat_conv = nn.Conv2d(
            low_level_channels, reduced_channels, 1)
        self.up_conv = nn.Sequential(
            ConvLayer(decoder_channels + reduced_channels, decoder_channels),
            ConvLayer(decoder_channels, decoder_channels))
        # The raw input is concatenated before the last block, hence the
        # extra ``in_channels`` on the input side.
        self.final_conv = nn.Sequential(
            ConvLayer(decoder_channels + in_channels, decoder_channels),
            nn.Conv2d(decoder_channels, out_channels, 1))

    @staticmethod
    def interp_cat(x, skip):
        """Resize ``x`` to ``skip``'s spatial size and concatenate on dim 1.

        Args:
            x: 4-D tensor to be bilinearly resized (``align_corners=True``).
            skip: 4-D tensor providing the target size ``skip.shape[2:]``.

        Returns:
            ``torch.cat((resized_x, skip), 1)`` -- the channels of ``x``
            come first, followed by those of ``skip``.
        """
        x = F.interpolate(
            x, skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Return a map with ``out_channels`` channels at ``x``'s resolution."""
        features = self.backbone(x)
        low = self.low_level_feat_conv(features['low_level_features'])
        high = self.aspp(features['high_level_features'])
        # Decoder stage 1: fuse ASPP output with reduced low-level features.
        decoded = self.up_conv(self.interp_cat(high, low))
        # Decoder stage 2: back to input resolution, fused with the raw input.
        return self.final_conv(self.interp_cat(decoded, x))
if __name__ == "__main__":
    # Smoke test: a 1-channel-in / 1-channel-out model without pretrained
    # weights must return a tensor shaped exactly like its input.
    net = DeepLabv3Plus(1, 1, pretrained_backbone=False)
    print(net)
    sample = torch.randn(2, 1, 128, 128)
    prediction = net(sample)
    assert sample.shape == prediction.shape
@jcreinhold
Copy link
Author

jcreinhold commented Jul 11, 2020

Not exactly the DeepLabv3+ model as described in the paper, but pretty close. I tried to maximize the use of layers from the torchvision package, since it already implements the DeepLabv3 model. An important change is that the input image is concatenated with the decoder features before the final convolutional block. This implementation also uses standard convolutions instead of depthwise-separable convolutions. Not tested extensively.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment