@jcreinhold
Created July 15, 2020 14:53
Grid Attention Block in PyTorch
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
grid attention blocks for gated attention networks
Based on: https://github.com/ozan-oktay/Attention-Gated-Networks
Author: Jacob Reinhold (jacob.reinhold@jhu.edu)
"""

__all__ = ['GridAttentionBlock2d',
           'GridAttentionBlock3d']

from typing import List, Optional

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F

ACTIVATION = nn.GELU


class GridAttentionBlock(nn.Module):
    # subclasses supply the dimension-specific conv, norm, and interpolation mode
    _conv = None
    _norm = None
    _upsample = None

    def __init__(self, in_channels: int, gating_channels: int,
                 inter_channels: Optional[int] = None):
        super().__init__()
        if inter_channels is None:
            inter_channels = in_channels
        # output transform applied to the attended (gated) features
        self.W = nn.Sequential(
            self._conv(in_channels, in_channels, 1),
            self._norm(in_channels),
            ACTIVATION()
        )
        # theta downsamples the input features; phi and psi act on the gating path
        self.theta = self._conv(in_channels, inter_channels, 2, stride=2, bias=False)
        self.phi = self._conv(gating_channels, inter_channels, 1)
        self.psi = self._conv(inter_channels, 1, 1)

    def _interp(self, x: Tensor, size: List[int]) -> Tensor:
        return F.interpolate(x, size=size, mode=self._upsample, align_corners=True)

    def forward(self, x: Tensor, g: Tensor) -> Tensor:
        input_size = x.shape[2:]
        theta_x = self.theta(x)
        theta_x_size = theta_x.shape[2:]
        phi_g = self.phi(g)
        phi_g = self._interp(phi_g, theta_x_size)   # match theta_x's spatial size
        theta_phi_sum = theta_x + phi_g
        f = F.relu(theta_phi_sum, inplace=True)
        psi_f = self.psi(f)
        psi_f = torch.sigmoid(psi_f)                # attention coefficients in (0, 1)
        psi_f = self._interp(psi_f, input_size)     # upsample back to the input size
        y = psi_f * x                               # gate the input features
        W_y = self.W(y)
        return W_y


class GridAttentionBlock3d(GridAttentionBlock):
    _conv = nn.Conv3d
    _norm = nn.BatchNorm3d
    _upsample = "trilinear"


class GridAttentionBlock2d(GridAttentionBlock):
    _conv = nn.Conv2d
    _norm = nn.BatchNorm2d
    _upsample = "bilinear"


if __name__ == "__main__":
    attention_block = GridAttentionBlock3d(1, 1)
    x = torch.randn(2, 1, 32, 32, 32)
    g = torch.randn(2, 1, 16, 16, 16)
    y = attention_block(x, g)
    assert x.shape == y.shape
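
    # a 2d usage sketch (not in the original gist); channel counts here are
    # illustrative only -- x is the skip-connection feature map, g the coarser gate
    attention_block_2d = GridAttentionBlock2d(in_channels=32, gating_channels=64)
    x2 = torch.randn(2, 32, 64, 64)   # features to be gated
    g2 = torch.randn(2, 64, 32, 32)   # gating signal
    y2 = attention_block_2d(x2, g2)
    assert y2.shape == x2.shape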

josselineperdomo commented Jul 31, 2020

Hello @jcreinhold

I would like to know why phi_g is followed by an upsampling step. I checked the original code, but I don't see why this step is needed: before the upsampling, phi_g and theta_x already have the same size because of the strided theta convolution applied to x.

Thanks for this implementation :)

@jcreinhold
Author

Hi @josselineperdomo

Good question. I wanted to faithfully replicate the original code, so I didn't explore this too much. I can't speak for the original authors, but my guess is that it covers scenarios where the stride=2 convolution produces an off-by-one size mismatch.

For example, this minor change requires the interpolation step (I changed one dimension to 31 instead of 32):

attention_block = GridAttentionBlock3d(1,1)
x = torch.randn(2,1,32,31,32)
g = torch.randn(2,1,16,16,16)
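# theta(x) now has spatial size 16x15x16 while phi(g) is 16x16x16, so the addition only works after interpolating phi(g)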
y = attention_block(x, g)

More generally, the spatial size of x doesn't have to be exactly twice that of g. You could imagine using this module outside the proposed model, where x and g have arbitrary sizes; in that case you need the interpolation step to make sure theta_x and phi_g can be added.
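
For instance (a quick sketch, not in the gist above), sizes that aren't related by a factor of two still go through, but only because of the two interpolation calls:

attention_block = GridAttentionBlock3d(1,1)
x = torch.randn(2,1,40,40,40)  # theta(x) -> 20x20x20
g = torch.randn(2,1,9,9,9)     # phi(g) -> 9x9x9, interpolated up to 20x20x20
y = attention_block(x, g)
assert y.shape == x.shape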

@josselineperdomo

Yes, you are right, it would be needed whenever the size of x is not exactly twice the size of g.
Thanks for your reply @jcreinhold!
