#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
grid attention blocks for gated attention networks
Based on: https://github.com/ozan-oktay/Attention-Gated-Networks
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""

__all__ = ['GridAttentionBlock2d',
           'GridAttentionBlock3d']

from typing import List, Optional

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F

ACTIVATION = nn.GELU
class GridAttentionBlock(nn.Module):
    _conv = None
    _norm = None
    _upsample = None

    def __init__(self, in_channels: int, gating_channels: int,
                 inter_channels: Optional[int] = None):
        super().__init__()
        if inter_channels is None:
            inter_channels = in_channels
        # output transform applied to the attention-weighted features
        self.W = nn.Sequential(
            self._conv(in_channels, in_channels, 1),
            self._norm(in_channels),
            ACTIVATION()
        )
        # theta downsamples x with a stride-2 convolution; phi projects the gating signal
        self.theta = self._conv(in_channels, inter_channels, 2, stride=2, bias=False)
        self.phi = self._conv(gating_channels, inter_channels, 1)
        # psi collapses the joint features to a single-channel attention map
        self.psi = self._conv(inter_channels, 1, 1)

    def _interp(self, x: Tensor, size: List[int]) -> Tensor:
        return F.interpolate(x, size=size, mode=self._upsample, align_corners=True)
    def forward(self, x: Tensor, g: Tensor) -> Tensor:
        input_size = x.shape[2:]
        theta_x = self.theta(x)
        theta_x_size = theta_x.shape[2:]
        phi_g = self.phi(g)
        # resize phi_g to theta_x's spatial size so the two can be added
        phi_g = self._interp(phi_g, theta_x_size)
        theta_phi_sum = theta_x + phi_g
        f = F.relu(theta_phi_sum, inplace=True)
        psi_f = self.psi(f)
        psi_f = torch.sigmoid(psi_f)
        # upsample the attention map back to x's size and use it to gate x
        psi_f = self._interp(psi_f, input_size)
        y = psi_f * x
        W_y = self.W(y)
        return W_y
class GridAttentionBlock3d(GridAttentionBlock):
    _conv = nn.Conv3d
    _norm = nn.BatchNorm3d
    _upsample = "trilinear"


class GridAttentionBlock2d(GridAttentionBlock):
    _conv = nn.Conv2d
    _norm = nn.BatchNorm2d
    _upsample = "bilinear"
if __name__ == "__main__":
    attention_block = GridAttentionBlock3d(1, 1)
    x = torch.randn(2, 1, 32, 32, 32)
    g = torch.randn(2, 1, 16, 16, 16)
    y = attention_block(x, g)
    assert x.shape == y.shape
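    # A corresponding smoke test for the 2D block (not part of the original
    # file); the channel counts and spatial sizes below are arbitrary
    # illustration values.
    attention_block_2d = GridAttentionBlock2d(8, 16, inter_channels=4)
    x2 = torch.randn(2, 8, 64, 64)
    g2 = torch.randn(2, 16, 32, 32)
    y2 = attention_block_2d(x2, g2)
    assert y2.shape == x2.shape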
Good question. I wanted to faithfully replicate the original code, so I didn't explore this too much. I can't speak for the original authors, but my guess is that there are scenarios where there is an off-by-one size mismatch due to the stride=2 convolution.
For example, this minor change requires the interpolation step (I changed one dimension to 31 instead of 32):
attention_block = GridAttentionBlock3d(1,1)
x = torch.randn(2,1,32,31,32)
g = torch.randn(2,1,16,16,16)
y = attention_block(x, g)
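To make the mismatch concrete, here is a rough sketch of the intermediate shapes for that example (it reuses the module's imports; the numbers are computed from the layer definitions above, not taken from the original thread):

theta_x = attention_block.theta(x)  # torch.Size([2, 1, 16, 15, 16]): the 31-voxel axis becomes floor((31 - 2) / 2) + 1 = 15
phi_g = attention_block.phi(g)      # torch.Size([2, 1, 16, 16, 16]): the 1x1x1 conv keeps g's spatial size
# theta_x + phi_g would fail here; interpolating phi_g to theta_x's spatial
# size, as forward() does, makes the addition valid:
phi_g = F.interpolate(phi_g, size=theta_x.shape[2:], mode="trilinear", align_corners=True)
assert (theta_x + phi_g).shape == theta_x.shape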
More generally, I suppose, x and g don't need to be sized in multiples of two of one another. You could imagine using this module in other situations outside of the proposed model where x and g are arbitrary sizes. In this case you would need the interpolation step to make sure that you can add theta_x and phi_g.
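As a hypothetical illustration of that more general case, the spatial sizes below are arbitrary and not related by a factor of two in any dimension, yet the block still runs because of the interpolation step:

# arbitrary, unrelated spatial sizes for x and g (illustration values only)
attention_block = GridAttentionBlock3d(1, 1)
x = torch.randn(2, 1, 40, 33, 27)
g = torch.randn(2, 1, 10, 12, 7)
y = attention_block(x, g)
assert y.shape == x.shape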
Yes, you are right, it would be needed if the size of x is not twice the size of g.
Thanks for your reply @jcreinhold!
Hello @jcreinhold
I would like to know why phi_g is followed by an upsample step. I checked the original code, but I don't see why this step would be needed, because phi_g and theta_x already have the same size before the upsampling, due to the strided theta convolution applied to x. Thanks for this implementation :)
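For reference, that observation does hold in the matched case used by the module's __main__ check; a small sketch (variable names here are just for illustration):

# matched case: x is exactly twice the spatial size of g, so after the
# stride-2 theta convolution the two tensors already agree
block = GridAttentionBlock3d(1, 1)
x = torch.randn(2, 1, 32, 32, 32)
g = torch.randn(2, 1, 16, 16, 16)
assert block.theta(x).shape[2:] == block.phi(g).shape[2:]  # both (16, 16, 16)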