Skip to content

Instantly share code, notes, and snippets.

@chunchet-ng
Created October 27, 2022 05:05
Show Gist options
  • Save chunchet-ng/2292b6f9cb0954b503b957f82bbf4a04 to your computer and use it in GitHub Desktop.
Save chunchet-ng/2292b6f9cb0954b503b957f82bbf4a04 to your computer and use it in GitHub Desktop.
Polarized Self-Attention - Parallel Variant
import torch
import torch.nn as nn
import os
import random
import numpy as np
from tqdm import tqdm
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    The seed is also exported (as a string) to ``os.environ["PYTHONHASHSEED"]``.
    Python reads that variable only at interpreter start-up, so setting it here
    affects child processes, not ``hash()`` in the running process.

    Args:
        seed: integer seed passed unchanged to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = f"{seed}"
    np.random.seed(seed)
    torch.manual_seed(seed)
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialise ``module.weight`` and constant-fill its bias.

    The global RNGs are re-seeded with 0 (through ``set_seed``) immediately
    before the weight initialiser runs. Two modules with equally-shaped
    weights therefore receive identical values.

    Args:
        module: object exposing a ``weight`` tensor and, optionally, a
            ``bias`` attribute (e.g. an ``nn.Conv2d``).
        a, mode, nonlinearity: forwarded to ``nn.init.kaiming_uniform_`` or
            ``nn.init.kaiming_normal_``.
        bias: constant written into ``module.bias`` when that attribute
            exists and is not ``None``.
        distribution: ``'uniform'`` or ``'normal'``. Any other value fails
            the assertion.
    """
    assert distribution in ['uniform', 'normal']
    # Pick the initialiser first; only the call itself touches the RNG.
    init_weight = (nn.init.kaiming_uniform_ if distribution == 'uniform'
                   else nn.init.kaiming_normal_)
    set_seed(0)
    init_weight(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)

    # getattr-with-default covers both "no bias attribute" and "bias=None".
    if getattr(module, 'bias', None) is not None:
        set_seed(0)
        nn.init.constant_(module.bias, bias)
class PSA_p(nn.Module):
    """Polarized Self-Attention, parallel layout (channel branch + spatial branch).

    Each branch produces a sigmoid mask that re-weights the input ``x``.
    ``forward`` returns both re-weighted tensors and their sum.

    Args:
        reset: if true, re-initialise the four query/value convolutions with
            ``kaiming_init(..., mode='fan_in')`` after construction.
        batch_size: stored on the module for backward compatibility. It no
            longer sizes the LayerNorm (see the comment on ``self.ln``).
        channel: number of input/output channels ``C``; the branches use
            ``C // 2`` intermediate channels.
        kernel_size: stored only; every convolution here uses a 1x1 kernel.
        stride: stride of the query/value convolutions.
            NOTE(review): with ``stride != 1`` the spatial mask is smaller
            than ``x`` in H/W and will likely fail to broadcast in
            ``spatial_pool``. Confirm before using a stride other than 1.
    """

    def __init__(self, reset, batch_size, channel, kernel_size=1, stride=1):
        super(PSA_p, self).__init__()
        self.batch_size = batch_size
        self.channel = channel
        self.inter_channel = channel // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.reset = reset
        # Every sub-module is built right after set_seed(0), so equally-shaped
        # layers get identical initial weights. The calls before parameter-free
        # modules draw no random numbers; they are kept so that the global RNG
        # state left behind by the constructor is unchanged.
        set_seed(0)
        self.conv_v_left = nn.Conv2d(self.channel, self.inter_channel, kernel_size=1,
                                     stride=stride, padding=0, bias=False)
        set_seed(0)
        self.conv_q_left = nn.Conv2d(self.channel, 1, kernel_size=1,
                                     stride=stride, padding=0, bias=False)
        set_seed(0)
        self.conv_up = nn.Conv2d(self.inter_channel, self.channel, kernel_size=1,
                                 stride=1, padding=0, bias=False)
        set_seed(0)
        # Normalise each sample's [C, 1, 1] channel context independently.
        # The previous normalized_shape [batch_size, C, 1, 1] pooled statistics
        # across the batch and rejected any other batch size. For
        # batch_size == 1 both forms produce identical outputs. The ln.weight and
        # ln.bias shapes change from [1, C, 1, 1] to [C, 1, 1], so old checkpoints
        # need those two tensors reshaped.
        self.ln = nn.LayerNorm([self.channel, 1, 1])
        set_seed(0)
        self.softmax_left = nn.Softmax(dim=2)
        set_seed(0)
        self.sigmoid = nn.Sigmoid()
        set_seed(0)
        self.conv_q_right = nn.Conv2d(self.channel, self.inter_channel, kernel_size=1,
                                      stride=stride, padding=0, bias=False)  # g
        set_seed(0)
        self.conv_v_right = nn.Conv2d(self.channel, self.inter_channel, kernel_size=1,
                                      stride=stride, padding=0, bias=False)  # theta
        set_seed(0)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        set_seed(0)
        self.softmax_right = nn.Softmax(dim=2)
        if self.reset:
            self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the query/value convs with fan-in Kaiming init.

        Each conv is flagged with ``.inited = True``. Nothing in this file reads
        that flag; presumably external init code does (confirm).
        """
        for conv in (self.conv_q_left, self.conv_v_left,
                     self.conv_q_right, self.conv_v_right):
            set_seed(0)
            kaiming_init(conv, mode='fan_in')
            conv.inited = True

    def channel_pool(self, x):
        """Channel-only attention: re-weight ``x`` with a per-channel mask."""
        input_x = self.conv_v_left(x)
        batch, channel, height, width = input_x.size()
        input_x = input_x.view(batch, channel, height * width)         # [N, IC, H*W]
        context_mask = self.conv_q_left(x)                              # [N, 1, H, W]
        context_mask = context_mask.view(batch, 1, height * width)      # [N, 1, H*W]
        context_mask = self.softmax_left(context_mask)                  # softmax over H*W
        context = torch.matmul(input_x, context_mask.transpose(1, 2))   # [N, IC, 1]
        context = context.unsqueeze(-1)                                 # [N, IC, 1, 1]
        context = self.conv_up(context)                                 # [N, C, 1, 1]
        context = self.ln(context)                                      # per-sample norm over C
        mask_ch = self.sigmoid(context)                                 # [N, C, 1, 1]
        return x * mask_ch

    def spatial_pool(self, x):
        """Spatial-only attention: re-weight ``x`` with a per-pixel mask."""
        g_x = self.conv_q_right(x)                                      # [N, IC, H, W]
        batch, channel, height, width = g_x.size()
        avg_x = self.avg_pool(g_x)                                      # [N, IC, 1, 1]
        batch, channel, avg_x_h, avg_x_w = avg_x.size()
        avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1)  # [N, 1, IC]
        avg_x = self.softmax_right(avg_x)                               # softmax over IC
        theta_x = self.conv_v_right(x).view(batch, self.inter_channel,
                                            height * width)             # [N, IC, H*W]
        context = torch.matmul(avg_x, theta_x)                          # [N, 1, H*W]
        context = context.view(batch, 1, height, width)                 # [N, 1, H, W]
        mask_sp = self.sigmoid(context)                                 # [N, 1, H, W]
        return x * mask_sp

    def forward(self, x):
        """Return ``(channel_out, spatial_out, channel_out + spatial_out)``."""
        context_channel = self.channel_pool(x)
        context_spatial = self.spatial_pool(x)
        out = context_spatial + context_channel
        return context_channel, context_spatial, out
class ParallelPolarizedSelfAttention(nn.Module):
    """Polarized Self-Attention, parallel layout (channel + spatial branches).

    Each branch produces a sigmoid mask that re-weights the input ``x``.
    ``forward`` returns both re-weighted tensors and their sum.

    Args:
        reset: if true, re-initialise the four query/value convolutions with
            ``kaiming_init(..., mode='fan_in')`` after construction.
        channel: number of input/output channels ``C``; the branches use
            ``C // 2`` intermediate channels.
    """

    def __init__(self, reset, channel=512):
        super(ParallelPolarizedSelfAttention, self).__init__()
        self.reset = reset
        # Every sub-module is built right after set_seed(0), so equally-shaped
        # layers get identical initial weights. The calls before parameter-free
        # modules draw no random numbers but keep the final RNG state unchanged.
        set_seed(0)
        self.ch_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1), bias=False)
        set_seed(0)
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1), bias=False)
        set_seed(0)
        self.softmax_channel = nn.Softmax(1)
        set_seed(0)
        self.ch_wz = nn.Conv2d(channel // 2, channel, kernel_size=(1, 1), bias=False)
        set_seed(0)
        self.ln = nn.LayerNorm(channel)
        set_seed(0)
        self.sigmoid = nn.Sigmoid()
        set_seed(0)
        self.sp_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1), bias=False)
        set_seed(0)
        self.sp_wq = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1), bias=False)
        set_seed(0)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        set_seed(0)
        self.softmax_spatial = nn.Softmax(-1)
        if self.reset:
            self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the query/value convs with fan-in Kaiming init.

        Covers ch_wq, ch_wv, sp_wq and sp_wv. ``sp_wq`` was previously skipped
        because ``sp_wv`` was listed twice. Each conv is flagged with
        ``.inited = True``; nothing in this file reads that flag.
        """
        # sp_wv stays last so the global RNG state afterwards is unchanged.
        for conv in (self.ch_wq, self.ch_wv, self.sp_wq, self.sp_wv):
            set_seed(0)
            kaiming_init(conv, mode='fan_in')
            conv.inited = True

    def channel_pool(self, x):
        """Channel-only attention: re-weight ``x`` with a per-channel mask."""
        b, c, h, w = x.size()
        channel_wv = self.ch_wv(x)                                      # bs,c//2,h,w
        channel_wq = self.ch_wq(x)                                      # bs,1,h,w
        channel_wv = channel_wv.reshape(b, c // 2, -1)                  # bs,c//2,h*w
        channel_wq = channel_wq.reshape(b, -1, 1)                       # bs,h*w,1
        channel_wq = self.softmax_channel(channel_wq)                   # softmax over h*w
        channel_wz = torch.matmul(channel_wv, channel_wq).unsqueeze(-1)  # bs,c//2,1,1
        # LayerNorm(c) normalises the last axis, so move channels there first.
        channel_wz = self.ch_wz(channel_wz).reshape(b, c, 1).permute(0, 2, 1)  # bs,1,c
        channel_weight = self.sigmoid(self.ln(channel_wz)).permute(0, 2, 1).reshape(b, c, 1, 1)
        return channel_weight * x

    def spatial_pool(self, x):
        """Spatial-only attention: re-weight ``x`` with a per-pixel mask."""
        b, c, h, w = x.size()
        spatial_wq = self.sp_wq(x)                                      # bs,c//2,h,w
        spatial_wq = self.agp(spatial_wq)                               # bs,c//2,1,1
        spatial_wq = spatial_wq.permute(0, 2, 3, 1).reshape(b, 1, c // 2)  # bs,1,c//2
        spatial_wq = self.softmax_spatial(spatial_wq)                   # softmax over c//2
        spatial_wv = self.sp_wv(x)                                      # bs,c//2,h,w
        spatial_wv = spatial_wv.reshape(b, c // 2, -1)                  # bs,c//2,h*w
        spatial_wz = torch.matmul(spatial_wq, spatial_wv)               # bs,1,h*w
        spatial_weight = self.sigmoid(spatial_wz.reshape(b, 1, h, w))   # bs,1,h,w
        return spatial_weight * x

    def forward(self, x):
        """Return ``(channel_out, spatial_out, channel_out + spatial_out)``."""
        channel_out = self.channel_pool(x)
        spatial_out = self.spatial_pool(x)
        out = spatial_out + channel_out
        return channel_out, spatial_out, out
if __name__ == '__main__':
    import functools

    # Loose tolerances: the two implementations reshape/permute in different
    # orders, so small floating-point differences are expected.
    assert_equal = functools.partial(torch.testing.assert_close, rtol=0.05, atol=0.05)
    channel = 4
    batch_size = 1
    num_trials = 100

    def count_mismatches(reset):
        """Compare the two implementations on ``num_trials`` inputs.

        Returns the number of trials whose outputs disagree beyond tolerance.

        Inputs come from a private generator seeded with the trial index.
        Both constructors call ``set_seed(0)``, which resets the global RNG.
        Drawing inputs from the global RNG would therefore repeat the same
        tensor on every trial after the first.
        """
        mismatches = 0
        for trial in tqdm(range(num_trials)):
            generator = torch.Generator().manual_seed(trial)
            x = torch.randn(batch_size, channel, 7, 7, generator=generator)
            psa = ParallelPolarizedSelfAttention(reset=reset, channel=channel)
            channel_out, spatial_out, out = psa(x)
            psa_2 = PSA_p(reset=reset, batch_size=batch_size, channel=channel)
            context_channel, context_spatial, out2 = psa_2(x)
            # Only a numerical mismatch counts as an error; construction or
            # shape bugs should surface instead of being tallied silently.
            try:
                assert_equal(spatial_out, context_spatial)
                assert_equal(channel_out, context_channel)
                assert_equal(out, out2)
            except AssertionError:
                mismatches += 1
        return mismatches

    for reset in (False, True):
        print(f'Testing with reset = {reset}')
        error_count = count_mismatches(reset)
        # Format the ratio directly; round(x, 2) * 100 can print 56.99999999999999.
        print(f'Error rate: {error_count}/{num_trials} [{error_count / num_trials:.0%}]')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment