Created October 18, 2020 09:50
Transformer for convolution, replacing SE module
import torch
import math
from torch.nn import functional as F
from torch import nn
class PositionalEncoding(nn.Module):
    """Add a fixed 2D sinusoidal positional encoding to a flattened feature map.

    The encoding is computed once for a ``height`` x ``width`` grid and stored
    as the non-trainable buffer ``pe`` of shape
    ``(1, d_model, height, width)``. ``forward`` flattens it to
    ``(1, height * width, d_model)``, adds it to the input and applies dropout.
    """

    def __init__(self, d_model, height, width, dropout=0.1):
        """
        :param d_model: number of encoding channels (must be divisible by 4)
        :param height: height of the position grid
        :param width: width of the position grid
        :param dropout: dropout probability applied after adding the encoding
        """
        # nn.Module must be initialised before any submodule or buffer is
        # assigned, otherwise the assignments below raise.
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = self.positionalencoding2d(d_model, height, width)
        # A buffer (not a Parameter): moves with .to()/.cuda() but is not trained.
        self.register_buffer('pe', pe)

    def positionalencoding2d(self, d_model, height, width):
        """
        :param d_model: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: 1*d_model*height*width position matrix
        :raises ValueError: if d_model is not divisible by 4
        """
        # Each spatial axis gets half of the channels, and each half is split
        # again into interleaved sin/cos pairs, hence the multiple-of-4 rule.
        if d_model % 4 != 0:
            raise ValueError("Cannot use 2D sin/cos positional encoding with "
                             "a dimension not divisible by 4 "
                             "(got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)
        # Each dimension use half of d_model
        half = d_model // 2
        div_term = torch.exp(torch.arange(0., half, 2) *
                             -(math.log(10000.0) / half))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)

        # (width, half/2) -> (half/2, 1, width), repeated along the height axis.
        angles_w = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
        # (height, half/2) -> (half/2, height, 1), repeated along the width axis.
        angles_h = (pos_h * div_term).transpose(0, 1).unsqueeze(2)

        # First half of the channels encodes the column index.
        pe[0:half:2, :, :] = torch.sin(angles_w).repeat(1, height, 1)
        pe[1:half:2, :, :] = torch.cos(angles_w).repeat(1, height, 1)
        # Second half of the channels encodes the row index.
        pe[half::2, :, :] = torch.sin(angles_h).repeat(1, 1, width)
        pe[half + 1::2, :, :] = torch.cos(angles_h).repeat(1, 1, width)
        # Leading singleton axis so the encoding broadcasts over the batch.
        return pe.unsqueeze(0)

    def forward(self, x):
        """Add the positional encoding to ``x`` and apply dropout.

        :param x: tensor that broadcasts against ``(1, height * width, d_model)``
        :return: tensor of the same shape as ``x``
        """
        # (1, d_model, H, W) -> (1, d_model, H*W) -> (1, H*W, d_model)
        pe = torch.flatten(self.pe, start_dim=2).transpose(1, 2)
        x = x + pe
        return self.dropout(x)
class TransLayer(nn.Module):
    """Residual transformer-encoder block over the spatial positions of a feature map.

    A ``(B, C, H, W)`` input is flattened to ``H * W`` tokens, projected to
    ``d_model`` features, given a 2D positional encoding, passed through one
    ``nn.TransformerEncoderLayer``, projected back to ``C`` channels and added
    to the input scaled by the learnable scalar ``alpha``. ``alpha`` starts at
    0, so the layer is an identity mapping at initialisation.
    """

    def __init__(self, H, W, channels, nhead=8, d_model=None):
        """
        :param H: height of the feature maps this layer will receive
        :param W: width of the feature maps this layer will receive
        :param channels: number of input (and output) channels
        :param nhead: number of attention heads
        :param d_model: transformer width; must be divisible by both 4 and
            ``nhead``. When ``None`` a default is derived (see below).
        """
        # nn.Module must be initialised before submodules/parameters are assigned.
        super().__init__()
        if d_model is None:
            # NOTE(review): the original definition of d_model is not visible
            # in this file; this default is a reconstruction -- confirm it
            # against any existing checkpoints. It picks the smallest width
            # >= channels that satisfies both the positional encoding
            # (multiple of 4) and multi-head attention (multiple of nhead).
            step = 4 * nhead // math.gcd(4, nhead)
            d_model = step * max(1, (channels + step - 1) // step)
        self.ffn_in = nn.Linear(channels, d_model)
        self.ffn_out = nn.Linear(d_model, channels)
        # NOTE(review): the original attribute name was lost; it determines the
        # state_dict key of the positional-encoding buffer -- confirm.
        self.pos_encoding = PositionalEncoding(d_model, H, W)
        self.tans = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, activation="gelu")
        # Residual gate, initialised to zero.
        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, x):
        """
        :param x: feature map of shape ``(B, C, H, W)`` with the ``H``/``W``
            given at construction
        :return: tensor of the same shape as ``x``
        """
        # B,C,H,W
        x = x.contiguous()
        shape_x = x.shape
        residual = x
        # B,C,H,W -> B,C,H*W -> B,H*W,C
        x = torch.flatten(x, start_dim=2).transpose(1, 2)
        # transformer
        x = self.ffn_in(x)
        x = self.pos_encoding(x)
        # nn.TransformerEncoderLayer defaults to sequence-first input
        # (S, N, E). Swap to (H*W, B, d_model) so attention runs over spatial
        # positions rather than across the batch, then swap back.
        x = self.tans(x.transpose(0, 1)).transpose(0, 1)
        x = self.ffn_out(x)
        # B,H*W,C -> B,C,H*W -> B,C,H,W
        # reshape (not view): the transposed tensor is not contiguous.
        x = x.transpose(1, 2).reshape(shape_x)
        x = residual + self.alpha * x
        return x
if __name__ == "__main__":
    # Smoke test: build a layer for a single-channel 3 x 4 feature map.
    H = 3
    W = H + 1
    a = TransLayer(H=H, W=W, channels=1)
