Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
DeepLabv3+ in PyTorch
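"""
DeepLabv3+ model (Chen et al., "Encoder-Decoder with Atrous Separable
Convolution for Semantic Image Segmentation", ECCV 2018, arXiv:1802.02611)
Author: Jacob Reinhold
"""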
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet101
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import ASPP
class ConvLayer(nn.Sequential):
    """3x3 convolution -> batch norm -> ReLU block that preserves spatial size.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int):
        # nn.Module.__init__ must run before add_module(); without it the
        # internal module registry does not exist and construction fails.
        super().__init__()
        # padding=1 with a 3x3 kernel keeps height/width unchanged; the conv
        # bias is omitted because BatchNorm immediately re-centres the output.
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,
                                          3, padding=1, bias=False))
        self.add_module('norm', nn.BatchNorm2d(out_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
class DeepLabv3Plus(nn.Module):
    """DeepLabv3+-style encoder-decoder for dense (per-pixel) prediction.

    Encoder: a torchvision ResNet-101 whose last two stages use dilation
    instead of striding, followed by torchvision's ASPP module. Decoder:
    the ASPP output is upsampled to the resolution of the ``layer1``
    features, concatenated with a 48-channel projection of them, refined
    by two 3x3 conv blocks, then upsampled to the input resolution and
    concatenated with the raw network input before the final prediction.

    Differences from the paper: standard (non-separable) convolutions are
    used, and the input image is concatenated ahead of the final block.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map.
        pretrained_backbone: if True, initialise the ResNet-101 backbone
            with torchvision's pretrained weights (may trigger a download).
    """

    def __init__(self, in_channels: int, out_channels: int,
                 pretrained_backbone: bool = False):
        super().__init__()
        # Dilate layer3/layer4 instead of striding them so the high-level
        # features keep more spatial resolution.
        # NOTE(review): recent torchvision releases deprecate `pretrained=`
        # in favour of `weights=`; switch if targeting those versions.
        backbone = resnet101(
            pretrained=pretrained_backbone,
            replace_stride_with_dilation=[False, True, True])
        # The stock stem expects 3 input channels. Replace it only when the
        # channel count differs, so pretrained RGB stem weights are kept.
        if in_channels != 3:
            backbone.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7,
                stride=2, padding=3, bias=False)
        return_layers = {
            'layer1': 'low_level_features',
            'layer4': 'high_level_features'}
        self.backbone = IntermediateLayerGetter(
            backbone, return_layers=return_layers)
        # ResNet-101's layer4 emits 2048 channels.
        inplanes = 2048
        # torchvision's ASPP projects to 256 channels by default.
        self.aspp = ASPP(inplanes, [6, 12, 18])
        # Replace the standard dropout at the end of ASPP's projection with
        # spatial (channel-wise) dropout. Index 3 is assumed to be the
        # Dropout layer of torchvision's ASPP.project — confirm on upgrade.
        self.aspp.project[3] = nn.Dropout2d(0.2)
        # 1x1 conv reducing the 256-channel layer1 features to 48 channels.
        self.low_level_feat_conv = nn.Conv2d(256, 48, 1)
        self.up_conv = nn.Sequential(
            ConvLayer(256 + 48, 256),
            ConvLayer(256, 256))
        self.final_conv = nn.Sequential(
            ConvLayer(256 + in_channels, 256),
            nn.Conv2d(256, out_channels, 1))

    @staticmethod
    def interp_cat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size, then concatenate.

        Returns the channel-wise (dim 1) concatenation ``[x_resized, skip]``.
        """
        x = F.interpolate(x, size=skip.shape[2:], mode='bilinear',
                          align_corners=True)
        return torch.cat((x, skip), 1)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch with ``in_channels`` channels (assumed NCHW).

        Returns:
            A tensor with ``out_channels`` channels at the same spatial size
            as ``x``.
        """
        out = self.backbone(x)
        llf = out.pop('low_level_features')
        llf = self.low_level_feat_conv(llf)
        hlf = out.pop('high_level_features')
        hlf = self.aspp(hlf)
        # Decoder stage 1: fuse upsampled ASPP output with low-level features.
        out = self.interp_cat(hlf, llf)
        out = self.up_conv(out)
        # Decoder stage 2: upsample to input resolution and append the input.
        out = self.interp_cat(out, x)
        out = self.final_conv(out)
        return out
if __name__ == "__main__":
    # Smoke test: with one input and one output channel, the prediction must
    # have exactly the same shape as the input batch.
    net = DeepLabv3Plus(1, 1, pretrained_backbone=False)
    inputs = torch.randn(2, 1, 128, 128)
    outputs = net(inputs)
    assert inputs.shape == outputs.shape
Copy link

jcreinhold commented Jul 11, 2020

This is not exactly the DeepLabv3+ model as described in the paper, but it is close. I reused layers from the torchvision package wherever possible, since torchvision already implements the DeepLabv3 model. One important difference is that the network input is concatenated with the upsampled decoder features just before the final convolutional block. This implementation also uses standard convolutions instead of the depthwise-separable convolutions used in the paper. It has not been tested extensively.


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment