Skip to content

Instantly share code, notes, and snippets.

@zhanghang1989
Created September 4, 2019 02:46
Show Gist options
  • Save zhanghang1989/68b94e79892d420eeabf6e25edae8133 to your computer and use it in GitHub Desktop.
Save zhanghang1989/68b94e79892d420eeabf6e25edae8133 to your computer and use it in GitHub Desktop.
Co-occurrent Features in Semantic Segmentation
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2018
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from .base import BaseNet
from ..nn import ACFModule, ConcurrentModule, SyncBatchNorm
from .fcn import FCNHead
from .encnet import EncModule
# Public names exported by this module: the segmentation network class and
# its factory function.
__all__ = ['ATTEN', 'get_atten']
class ATTEN(BaseNet):
    """Segmentation network that puts an :class:`ATTENHead` on a backbone.

    Parameters
    ----------
    nclass : int
        Number of output channels produced by the prediction heads.
    backbone : str
        Backbone identifier forwarded to ``BaseNet``. A name starting with
        ``'wideresnet'`` selects 4096 head input channels, anything else 2048.
    nheads : int, default 8
        Forwarded to :class:`ATTENHead` (number of attention heads).
    nmixs : int, default 1
        Forwarded to :class:`ATTENHead` (number of softmax mixtures).
    with_global : bool, default True
        Forwarded to :class:`ATTENHead`; enables the global pooling branch.
    with_enc : bool, default True
        Forwarded to :class:`ATTENHead`; enables the ``EncModule`` stage.
    with_lateral : bool, default False
        Forwarded to :class:`ATTENHead` as ``lateral``.
    aux : bool, default True
        When true, an auxiliary ``FCNHead`` is attached to ``features[2]``.
    se_loss : bool, default False
        Forwarded to ``BaseNet`` and :class:`ATTENHead`.
    norm_layer : callable, default SyncBatchNorm
        Normalisation layer constructor used throughout the network.
    **kwargs
        Extra keyword arguments forwarded to ``BaseNet``.
    """

    def __init__(self, nclass, backbone, nheads=8, nmixs=1, with_global=True,
                 with_enc=True, with_lateral=False, aux=True, se_loss=False,
                 norm_layer=SyncBatchNorm, **kwargs):
        super(ATTEN, self).__init__(nclass, backbone, aux, se_loss,
                                    norm_layer=norm_layer, **kwargs)
        # NOTE(review): relies on BaseNet storing the backbone name as
        # ``self.backbone`` -- confirm against BaseNet.
        in_channels = 4096 if self.backbone.startswith('wideresnet') else 2048
        self.head = ATTENHead(in_channels, nclass, norm_layer, self._up_kwargs,
                              nheads=nheads, nmixs=nmixs, with_global=with_global,
                              with_enc=with_enc, se_loss=se_loss,
                              lateral=with_lateral)
        if aux:
            self.auxlayer = FCNHead(1024, nclass, norm_layer)

    def forward(self, x):
        """Run the network and return a tuple of outputs.

        The first element is the head prediction resized to the spatial size
        of ``x``. Any further elements returned by the head are passed through
        unchanged, and when ``self.aux`` is set the resized auxiliary
        prediction is appended last.
        """
        imsize = x.size()[2:]
        features = self.base_forward(x)
        outputs = list(self.head(*features))
        outputs[0] = interpolate(outputs[0], imsize, **self._up_kwargs)
        if self.aux:
            # The auxiliary head reads the third backbone feature map.
            auxout = self.auxlayer(features[2])
            auxout = interpolate(auxout, imsize, **self._up_kwargs)
            outputs.append(auxout)
        return tuple(outputs)

    def demo(self, x):
        """Return the attention maps produced by the head for input ``x``."""
        features = self.base_forward(x)
        return self.head.demo(*features)
class GlobalPooling(nn.Module):
    """Global context branch.

    The input is average-pooled to 1x1, projected with a bias-free 1x1
    convolution, normalised, passed through ReLU, and finally resized back to
    the spatial size of the input.

    Parameters
    ----------
    in_channels : int
        Channels of the incoming feature map.
    out_channels : int
        Channels produced by the 1x1 projection.
    norm_layer : callable
        Constructor called as ``norm_layer(out_channels)``.
    up_kwargs : dict
        Keyword arguments passed to ``interpolate`` when resizing.
    """

    def __init__(self, in_channels, out_channels, norm_layer, up_kwargs):
        super(GlobalPooling, self).__init__()
        self._up_kwargs = up_kwargs
        stages = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True),
        ]
        self.gap = nn.Sequential(*stages)

    def forward(self, x):
        """Return the pooled descriptor resized to the height/width of ``x``."""
        _batch, _channels, height, width = x.size()
        pooled = self.gap(x)
        return interpolate(pooled, (height, width), **self._up_kwargs)
class ATTENHead(nn.Module):
    """Prediction head: 3x3 reduction, optional lateral fusion, attention,
    optional ``EncModule`` and a final 1x1 classifier.

    Parameters
    ----------
    in_channels : int
        Channels of the last backbone feature map; the head works on
        ``in_channels // 4`` channels internally.
    out_channels : int
        Number of channels of the final prediction.
    norm_layer : callable
        Normalisation layer constructor.
    up_kwargs : dict
        Keyword arguments for ``interpolate`` (used by the global branch).
    nheads : int
        Number of attention heads passed to ``ACFModule``.
    nmixs : int
        Number of softmax mixtures passed to ``ACFModule``.
    with_global : bool
        When true, a :class:`GlobalPooling` branch runs next to the attention
        module inside a ``ConcurrentModule``.
    with_enc : bool
        When true, features go through an ``EncModule`` before the classifier.
    se_loss : bool
        Forwarded to ``EncModule``.
    lateral : bool
        When true, ``inputs[1]`` and ``inputs[2]`` are projected to 512
        channels and fused with the reduced top-level feature.
    """

    def __init__(self, in_channels, out_channels, norm_layer, up_kwargs,
                 nheads, nmixs, with_global,
                 with_enc, se_loss, lateral):
        super(ATTENHead, self).__init__()
        self.with_enc = with_enc
        self.se_loss = se_loss
        self._up_kwargs = up_kwargs
        inter_channels = in_channels // 4
        self.lateral = lateral
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels),
            nn.ReLU())
        if lateral:
            # NOTE(review): the lateral inputs are assumed to carry 512 and
            # 1024 channels respectively -- confirm for non-ResNet backbones.
            self.connect = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(512, 512, kernel_size=1, bias=False),
                    norm_layer(512),
                    nn.ReLU(inplace=True)),
                nn.Sequential(
                    nn.Conv2d(1024, 512, kernel_size=1, bias=False),
                    norm_layer(512),
                    nn.ReLU(inplace=True)),
            ])
            # The fused tensor is [conv5 output, lateral 0, lateral 1] and the
            # attention module expects ``inter_channels`` channels, so size
            # the fusion conv from ``inter_channels`` instead of assuming 512
            # (identical when in_channels == 2048).
            self.fusion = nn.Sequential(
                nn.Conv2d(inter_channels + 2 * 512, inter_channels,
                          kernel_size=3, padding=1, bias=False),
                norm_layer(inter_channels),
                nn.ReLU(inplace=True))
        extended_channels = 0
        self.atten = ACFModule(nheads, nmixs, inter_channels,
                               inter_channels // nheads * nmixs,
                               inter_channels // nheads, norm_layer)
        if with_global:
            extended_channels = inter_channels
            self.atten_layers = ConcurrentModule([
                GlobalPooling(inter_channels, extended_channels, norm_layer,
                              self._up_kwargs),
                self.atten,
            ])
        else:
            # The original referenced an undefined ``atten`` list here, which
            # raised NameError; wrap the single attention module instead.
            self.atten_layers = nn.Sequential(self.atten)
        if with_enc:
            self.encmodule = EncModule(inter_channels + extended_channels,
                                       out_channels, ncodes=32,
                                       se_loss=se_loss, norm_layer=norm_layer)
        self.conv6 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(inter_channels + extended_channels, out_channels, 1))

    def _prepare_features(self, inputs):
        """Reduce the last input and, if enabled, fuse the lateral inputs."""
        feat = self.conv5(inputs[-1])
        if self.lateral:
            c2 = self.connect[0](inputs[1])
            c3 = self.connect[1](inputs[2])
            feat = self.fusion(torch.cat([feat, c2, c3], 1))
        return feat

    def forward(self, *inputs):
        """Return a tuple whose first element is the class prediction.

        When ``with_enc`` is set, the remaining elements are whatever
        ``EncModule`` returns after its first output.
        """
        feat = self.atten_layers(self._prepare_features(inputs))
        if self.with_enc:
            outs = list(self.encmodule(feat))
        else:
            outs = [feat]
        outs[0] = self.conv6(outs[0])
        return tuple(outs)

    def demo(self, *inputs):
        """Return the attention maps of the attention module for ``inputs``."""
        return self.atten.demo(self._prepare_features(inputs))
def get_atten(dataset='pascal_voc', backbone='resnet50', pretrained=False,
              root='~/.encoding/models', **kwargs):
    r"""Build an ATTEN model ("Co-occurrent Features in Semantic Segmentation").

    Parameters
    ----------
    dataset : str, default pascal_voc
        Name of the dataset the model is built for; matched
        case-insensitively against the registered datasets and used to infer
        the number of classes.
    backbone : str, default resnet50
        Backbone identifier passed to :class:`ATTEN`.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.
    **kwargs
        Extra keyword arguments forwarded to :class:`ATTEN`.

    Returns
    -------
    ATTEN
        The constructed (and optionally weight-initialised) model.

    Examples
    --------
    >>> model = get_atten(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    >>> print(model)
    """
    # infer number of classes
    from ..datasets import datasets, acronyms
    # Normalise once so the class lookup and the checkpoint-name lookup agree
    # on the key; previously only the former was lower-cased.
    dataset_key = dataset.lower()
    model = ATTEN(datasets[dataset_key].NUM_CLASS, backbone=backbone, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model_name = 'atten_%s_%s' % (backbone, acronyms[dataset_key])
        # map_location='cpu' lets checkpoints saved from a GPU load on
        # CPU-only hosts; load_state_dict copies values into the model's own
        # parameters, so their device is unaffected.
        state_dict = torch.load(get_model_file(model_name, root=root),
                                map_location='cpu')
        model.load_state_dict(state_dict)
    return model
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2018
###########################################################################
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .syncbn import SyncBatchNorm
# Public names exported by this module: the attention block and the
# mixture-of-softmax attention it uses internally.
__all__ = ['ACFModule', 'MixtureOfSoftMaxACF']
class ACFModule(nn.Module):
    """Multi-head attention block over 2-D feature maps.

    Queries and keys share one projection (``conv_ks`` is the same module as
    ``conv_qs``). The attended values are projected back to ``d_model``
    channels, normalised, and combined with the input either by addition or
    by channel concatenation.

    Parameters
    ----------
    n_head : int
        Number of attention heads.
    n_mix : int
        Number of softmax mixtures used by :class:`MixtureOfSoftMaxACF`.
    d_model : int
        Channels of the input (and of the projected output).
    d_k : int
        Per-head channels of the query/key projection.
    d_v : int
        Per-head channels of the value projection.
    norm_layer : callable, default SyncBatchNorm
        Normalisation layer constructor.
    kq_transform : {'conv', 'ffn', 'dffn'}, default 'conv'
        Form of the shared query/key projection.
    value_transform : {'conv'}, default 'conv'
        Form of the value projection.
    pooling : bool, default True
        When true, keys and values are computed from a stride-2
        average-pooled copy of the input.
    concat : bool, default False
        When true, the normalised output is concatenated with the input
        instead of added to it.
    dropout : float, default 0.1
        Accepted for interface compatibility; not used by this module.

    Raises
    ------
    NotImplementedError
        If ``kq_transform`` or ``value_transform`` is not a supported value.
    """

    def __init__(self, n_head, n_mix, d_model, d_k, d_v, norm_layer=SyncBatchNorm,
                 kq_transform='conv', value_transform='conv',
                 pooling=True, concat=False, dropout=0.1):
        super(ACFModule, self).__init__()
        self.n_head = n_head
        self.n_mix = n_mix
        self.d_k = d_k
        self.d_v = d_v
        self.pooling = pooling
        self.concat = concat
        if self.pooling:
            self.pool = nn.AvgPool2d(3, 2, 1, count_include_pad=False)
        if kq_transform == 'conv':
            self.conv_qs = nn.Conv2d(d_model, n_head * d_k, 1)
            nn.init.normal_(self.conv_qs.weight, mean=0,
                            std=np.sqrt(2.0 / (d_model + d_k)))
        elif kq_transform == 'ffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head * d_k, 3, padding=1, bias=False),
                norm_layer(n_head * d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head * d_k, n_head * d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0,
                            std=np.sqrt(1.0 / d_k))
        elif kq_transform == 'dffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head * d_k, 3, padding=4, dilation=4, bias=False),
                norm_layer(n_head * d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head * d_k, n_head * d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0,
                            std=np.sqrt(1.0 / d_k))
        else:
            # ``NotImplemented`` is a constant, not an exception; raising it
            # produces a TypeError instead of the intended error.
            raise NotImplementedError(
                'unsupported kq_transform: %r' % (kq_transform,))
        # Keys reuse the query projection (shared weights).
        self.conv_ks = self.conv_qs
        if value_transform == 'conv':
            self.conv_vs = nn.Conv2d(d_model, n_head * d_v, 1)
        else:
            raise NotImplementedError(
                'unsupported value_transform: %r' % (value_transform,))
        nn.init.normal_(self.conv_vs.weight, mean=0,
                        std=np.sqrt(2.0 / (d_model + d_v)))
        self.attention = MixtureOfSoftMaxACF(n_mix=n_mix, d_k=d_k)
        self.conv = nn.Conv2d(n_head * d_v, d_model, 1, bias=False)
        self.norm_layer = norm_layer(d_model)

    def _project(self, x):
        """Return flattened per-head ``(qt, kt, vt)`` projections of ``x``."""
        b_, _, h_, w_ = x.size()
        n_rows = b_ * self.n_head
        if self.pooling:
            pooled = self.pool(x)
            qt = self.conv_ks(x).view(n_rows, self.d_k, h_ * w_)
            # The pooled map is ceil(h/2) x ceil(w/2), which differs from
            # h*w//4 for odd sizes, so let view() infer the length.
            kt = self.conv_ks(pooled).view(n_rows, self.d_k, -1)
            vt = self.conv_vs(pooled).view(n_rows, self.d_v, -1)
        else:
            kt = self.conv_ks(x).view(n_rows, self.d_k, h_ * w_)
            qt = kt
            vt = self.conv_vs(x).view(n_rows, self.d_v, h_ * w_)
        return qt, kt, vt

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(b, c, h, w)``.

        Returns the normalised attention output added to ``x``, or
        concatenated with ``x`` along the channel axis when ``concat`` is set.
        """
        residual = x
        b_, _, h_, w_ = x.size()
        qt, kt, vt = self._project(x)
        output, _ = self.attention(qt, kt, vt)
        output = output.transpose(1, 2).contiguous().view(
            b_, self.n_head * self.d_v, h_, w_)
        output = self.norm_layer(self.conv(output))
        if self.concat:
            return torch.cat((output, residual), 1)
        return output + residual

    def demo(self, x):
        """Return the attention weights for ``x``, reshaped to
        ``(b, n_head, h*w, n_keys)``.
        """
        b_, _, h_, w_ = x.size()
        qt, kt, vt = self._project(x)
        _, attn = self.attention(qt, kt, vt)
        # view() is not in-place: the original discarded this result and
        # returned the un-reshaped tensor.
        return attn.view(b_, self.n_head, h_ * w_, -1)

    def extra_repr(self):
        return 'n_head={}, n_mix={}, d_k={}, pooling={}' \
            .format(self.n_head, self.n_mix, self.d_k, self.pooling)
class MixtureOfSoftMaxACF(nn.Module):
    """Dot-product attention whose weights are a mixture of softmaxes.

    With ``n_mix == 1`` this is plain scaled dot-product attention. With
    ``n_mix > 1`` the ``d_k`` channels are split into ``n_mix`` groups, one
    attention map is computed per group, and the maps are blended with
    mixing coefficients predicted from the mean query.

    Parameters
    ----------
    n_mix : int
        Number of mixture components.
    d_k : int
        Channel size of queries and keys.
    attn_dropout : float, default 0.1
        Dropout probability applied to the attention weights.
    """

    def __init__(self, n_mix, d_k, attn_dropout=0.1):
        super(MixtureOfSoftMaxACF, self).__init__()
        self.temperature = np.power(d_k, 0.5)
        self.n_mix = n_mix
        self.att_drop = attn_dropout
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax1 = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)
        self.d_k = d_k
        if n_mix > 1:
            # Mixing weights, drawn uniformly from +-1/sqrt(n_mix).
            self.weight = nn.Parameter(torch.Tensor(n_mix, d_k))
            bound = np.power(n_mix, -0.5)
            self.weight.data.uniform_(-bound, bound)

    def forward(self, qt, kt, vt):
        """Attend with queries ``qt``, keys ``kt`` and values ``vt``.

        ``qt`` is ``(B, d_k, N)``, ``kt`` is ``(B, d_k, N2)`` and ``vt`` is
        ``(B, d_v, N2)``. Returns ``(output, attn)`` with ``output`` of shape
        ``(B, N, d_v)`` and ``attn`` of shape ``(B, N, N2)``.
        """
        batch, channels, n_query = qt.size()
        n_mix = self.n_mix
        assert channels == self.d_k
        group = channels // n_mix
        use_mixture = n_mix > 1

        if use_mixture:
            # Mean query: (B, d_k, 1) -> mixing coefficients: (B, n_mix, 1).
            mean_query = torch.mean(qt, 2, True)
            gate = torch.matmul(self.weight, mean_query)
            pi = self.softmax1(gate).view(batch * n_mix, 1, 1)

        # Fold the mixture groups into the batch axis.
        n_key = kt.size(2)
        queries = qt.view(batch * n_mix, group, n_query).transpose(1, 2)
        keys = kt.view(batch * n_mix, group, n_key)
        values = vt.transpose(1, 2)

        # Attention weights: (B * n_mix, N, N2).
        scores = torch.bmm(queries, keys) / self.temperature
        attn = self.dropout(self.softmax2(scores))

        if use_mixture:
            # Blend the per-group maps: (B * n_mix, N, N2) -> (B, N, N2).
            attn = (attn * pi).view(batch, n_mix, n_query, n_key).sum(1)

        return torch.bmm(attn, values), attn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment