The self-attention mechanism in Self-Attention GAN:
Self-attention can be viewed as multiplying the feature map by its own transpose, which lets any two spatial positions interact directly. The layer can therefore learn the dependency between any pair of pixels and capture global features.
import torch
import torch.nn as nn

class Self_Attn(nn.Module):
    """ Self attention Layer """
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim
        self.activation = activation
        # 1x1 convolutions project the input into query, key, and value spaces.
        # Query and key are reduced to in_dim // 8 channels to save memory.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # gamma starts at 0, so the attention branch is blended in gradually during training.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X W X H)
        returns :
            out : self attention value + input feature
            attention : B X N X N (N is Width * Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X N X C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B X C' X N
        energy = torch.bmm(proj_query, proj_key)                                                # B X N X N
        attention = self.softmax(energy)                                                        # B X N X N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B X C X N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))                                 # B X C X N
        out = out.view(m_batchsize, C, width, height)
        # Residual connection scaled by the learnable gamma.
        out = self.gamma * out + x
        return out, attention
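A minimal usage sketch of the layer above. The 64-channel, 32x32 input and the 'relu' activation string are arbitrary example values for illustration, not taken from the original gist:

if __name__ == "__main__":
    # Hypothetical example: a batch of 4 feature maps, 64 channels, 32x32 spatial size.
    layer = Self_Attn(in_dim=64, activation='relu')
    x = torch.randn(4, 64, 32, 32)          # B X C X W X H
    out, attention = layer(x)
    print(out.shape)        # torch.Size([4, 64, 32, 32]) -- same shape as the input
    print(attention.shape)  # torch.Size([4, 1024, 1024]) -- N x N with N = 32 * 32

Because gamma is initialized to zero, the layer initially behaves as an identity mapping and the attention contribution grows only as gamma is learned.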