import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
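
# Generic N-D non-local (self-attention) block in the style of "Non-local Neural Networks"
# (Wang et al.), plus a 1D "mutual" variant that uses the first sample of the batch as the
# key and splits the aggregated values into foreground / background attention streams.
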
class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # Default the embedding width to half the input channels.
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # Pick the convolution / pooling / batch-norm flavour matching the input dimensionality.
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            bn = nn.BatchNorm1d

        # g: value embedding (1x1 convolution).
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # W: output projection back to in_channels, zero-initialised so the block starts as identity.
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        # theta / phi: query and key embeddings.
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        # Optionally sub-sample the key/value paths to reduce the attention cost.
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w) for dimension=3, (b, c, h, w) for 2, (b, c, n) for 1
        :return: tensor of the same shape as x
        '''
        batch_size = x.size(0)

        # Flatten the spatial/temporal dimensions: g_x, theta_x -> (b, N, inter_c), phi_x -> (b, inter_c, N').
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # Pairwise affinity between all positions, normalised with softmax: (b, N, N').
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # Aggregate the values and restore the original spatial layout.
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        # Residual connection; W is zero-initialised, so the block starts as an identity mapping.
        W_y = self.W(y)
        z = W_y + x

        return z
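
# The 1D "mutual" variant below reuses theta / phi / g from the base block, but builds the
# attention map only from the first sample in the batch (used as the key). The attention each
# position receives, averaged over the key sample's query positions, is treated as a per-position
# foreground weight; its complement as a background weight. Both weighted value streams are
# concatenated channel-wise and projected back with a (2 * inter_channels -> in_channels) conv.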
class NONLocalBlock1D_mutual(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=False, bn_layer=True):
        super(NONLocalBlock1D_mutual, self).__init__(in_channels,
                                                     inter_channels=inter_channels,
                                                     dimension=1, sub_sample=sub_sample,
                                                     bn_layer=bn_layer)

        # Override W: the concatenated foreground/background features carry 2 * inter_channels
        # channels, so the output projection takes twice as many input channels as the base block.
        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv1d(in_channels=self.inter_channels * 2, out_channels=self.in_channels,
                          kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv1d(in_channels=self.inter_channels * 2, out_channels=self.in_channels,
                               kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        '''
        :param x: (b, c, n) batch of 1D features; the first sample in the batch serves as the key
        :return: tensor of the same shape as x
        '''
        batch_size = x.size(0)
        n_point = x.size(2)
        data_1 = x[0:1]  # first sample serves as the key

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)                                      # (b, n, inter_c)

        theta_x = self.theta(data_1).view(1, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)                              # (1, n, inter_c)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)   # (b, inter_c, n)

        # Affinity between the key sample's queries and every sample's keys, broadcast over the batch: (b, n, n).
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # Average attention each position receives -> per-position foreground weight; its complement
        # is treated as the background weight. Both are broadcast over the channel dimension.
        fg_attention = torch.mean(f_div_C, dim=1)[:, :, None].repeat(1, 1, self.inter_channels)
        bg_attention = 1 - fg_attention

        fg_attention_features = torch.mul(fg_attention, g_x)
        bg_attention_features = torch.mul(bg_attention, g_x)

        # Concatenate the two weighted streams channel-wise and project back to in_channels.
        y = torch.cat([fg_attention_features, bg_attention_features], dim=2)
        y = y.permute(0, 2, 1).contiguous()

        W_y = self.W(y)
        z = W_y + x

        return z

if __name__ == '__main__':
    img = torch.zeros(2, 3, 20)
    net = NONLocalBlock1D_mutual(3, sub_sample=False, bn_layer=True)
    out = net(img)
    print(out.size())  # expected: torch.Size([2, 3, 20])
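    # Shape check for the generic N-D block as well: with dimension=2 it should preserve a
    # (b, c, h, w) feature map. A minimal sketch; `feat` and `base_net` are illustrative names.
    feat = torch.zeros(2, 16, 8, 8)
    base_net = _NonLocalBlockND(16, dimension=2, sub_sample=True, bn_layer=True)
    print(base_net(feat).size())  # expected: torch.Size([2, 16, 8, 8])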