Low Rank Bilinear Pooling implementation in PyTorch
# Hadamard Product for Low-Rank Bilinear Pooling, Kim et al., https://arxiv.org/abs/1610.04325
# Original implementation in LuaTorch: https://github.com/jnhwkim/MulLowBiVQA
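# Per output channel k, the layer computes f_k = p_k^T (sigma(U^T x) * sigma(V^T y)),
# a low-rank (rank = hidden_dim) factorization of a full bilinear form x^T W_k y;
# the paper uses a tanh nonlinearity, while this implementation defaults to identity.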
import torch
class LowRankBilinearPooling(torch.nn.Module):
    def __init__(self, in_channels1, in_channels2, hidden_dim, out_channels, nonlinearity = torch.nn.Identity, sum_pool = True):
        super().__init__()
        self.nonlinearity = nonlinearity()  # instantiate the module class (e.g. torch.nn.Tanh)
        self.sum_pool = sum_pool
        self.proj1 = torch.nn.Linear(in_channels1, hidden_dim, bias = False)
        self.proj2 = torch.nn.Linear(in_channels2, hidden_dim, bias = False)
        self.proj = torch.nn.Linear(hidden_dim, out_channels)

    def forward(self, x1, x2):
        # x1: (batch, spatial1, in_channels1), x2: (batch, spatial2, in_channels2)
        x1_ = self.nonlinearity(self.proj1(x1))  # (batch, spatial1, hidden_dim)
        x2_ = self.nonlinearity(self.proj2(x2))  # (batch, spatial2, hidden_dim)
        # Hadamard product over all pairs of spatial positions: (batch, spatial1, spatial2, out_channels)
        lrbp = self.proj(x1_.unsqueeze(-2) * x2_.unsqueeze(1))
        return lrbp.sum(dim = (1, 2)) if self.sum_pool else lrbp
if __name__ == '__main__':
    batch_size = 16
    spatial_dim = 32
    hidden_dim1 = 64
    hidden_dim2 = 64
    hidden_dim = 128
    output_dim = 256
    # currently supports just 1 spatial dim
    x1 = torch.rand(batch_size, spatial_dim, hidden_dim1)
    x2 = torch.rand(batch_size, spatial_dim, hidden_dim2)
    lrbp = LowRankBilinearPooling(hidden_dim1, hidden_dim2, hidden_dim, output_dim)
    y = lrbp(x1, x2)
    print(y.shape)
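    # Sanity check (a minimal sketch, not part of the original gist): with the
    # default Identity nonlinearity, the layer is a rank-`hidden_dim` factorization
    # of a full bilinear form. Reconstruct the implicit bilinear weight tensor
    # M[k, i, j] = sum_h P[k, h] * W1[h, i] * W2[h, j] from the learned projections
    # and compare against the module output for one pair of feature vectors.
    W1, W2 = lrbp.proj1.weight, lrbp.proj2.weight         # (hidden_dim, in_channels)
    P, b = lrbp.proj.weight, lrbp.proj.bias               # (out_channels, hidden_dim), (out_channels,)
    M = torch.einsum('kh,hi,hj->kij', P, W1, W2)          # implicit full bilinear weights
    u, v = x1[0, 0], x2[0, 0]                             # one feature vector from each input
    expected = torch.einsum('kij,i,j->k', M, u, v) + b
    actual = lrbp(u[None, None], v[None, None])[0]        # single spatial position, sum-pooled
    print(torch.allclose(expected, actual, atol = 1e-4))  # True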