@vadimkantorov
Last active September 22, 2021 07:51
Compact Bilinear Pooling in PyTorch using the new FFT support
# References:
# [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847
# [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
# [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf
# [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060
# [5] Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling
# TODO: migrate to use of new native complex64 types
# TODO: change strided x coo matmul to torch.matmul(): M[sparse_coo] @ M[strided] -> M[strided]
import torch

# Note: torch.rfft / torch.irfft were removed in PyTorch 1.8; this code targets earlier versions.
class CompactBilinearPooling(torch.nn.Module):
    def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
        super().__init__()
        self.out_channels = out_channels
        self.sum_pool = sum_pool
        # count sketch matrix: row i has a single nonzero entry rand_s[i] (+1 or -1) in column rand_h[i]
        generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_channels), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
        self.tensor_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,), dtype = torch.float32) - 1, in_channels1, out_channels), requires_grad = False)
        self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,), dtype = torch.float32) - 1, in_channels2, out_channels), requires_grad = False)

    def forward(self, x1, x2):
        # x1: (B, in_channels1, H, W), x2: (B, in_channels2, H, W)
        fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1)
        fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1)
        # torch.rfft does not yet support torch.complex64 outputs, so we do the complex product manually
        fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
        cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.out_channels, )) * self.out_channels
        # sum_pool = True returns (B, out_channels); otherwise (B, out_channels, H, W)
        return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
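
For readers checking the FFT step against [3], here is a minimal pure-Python sketch (not part of the gist). It assumes the combined hash for the outer product is (h1(i) + h2(j)) mod d with sign s1(i) * s2(j), as in the tensor sketch of [3]. All function names are made up for illustration, and the naive O(d^2) DFT is only there to keep the example dependency-free. It compares the count sketch of the outer product computed directly with the one obtained by multiplying the DFTs of the two individual count sketches.

```python
import cmath
import random


def count_sketch(x, h, s, d):
    """Project x to d dimensions: out[h[i]] += s[i] * x[i]."""
    out = [0.0] * d
    for i, xi in enumerate(x):
        out[h[i]] += s[i] * xi
    return out


def dft(a, inverse=False):
    """Naive discrete Fourier transform; the inverse includes the 1/n factor."""
    n = len(a)
    sign = 1 if inverse else -1
    out = [
        sum(a[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]
    return [v / n for v in out] if inverse else out


def sketch_via_fft(x, y, h1, s1, h2, s2, d):
    """Elementwise product of the two length-d spectra, then inverse DFT."""
    fa = dft(count_sketch(x, h1, s1, d))
    fb = dft(count_sketch(y, h2, s2, d))
    product = [p * q for p, q in zip(fa, fb)]
    return [v.real for v in dft(product, inverse=True)]


def sketch_of_outer_product(x, y, h1, s1, h2, s2, d):
    """Count sketch of the full outer product, computed the slow way."""
    out = [0.0] * d
    for i, xi in enumerate(x):
        for j, yj in enumerate(y):
            out[(h1[i] + h2[j]) % d] += s1[i] * s2[j] * xi * yj
    return out


rng = random.Random(0)
n1, n2, d = 5, 7, 8
x = [rng.gauss(0, 1) for _ in range(n1)]
y = [rng.gauss(0, 1) for _ in range(n2)]
h1 = [rng.randrange(d) for _ in range(n1)]
h2 = [rng.randrange(d) for _ in range(n2)]
s1 = [rng.choice([-1.0, 1.0]) for _ in range(n1)]
s2 = [rng.choice([-1.0, 1.0]) for _ in range(n2)]

direct = sketch_of_outer_product(x, y, h1, s1, h2, s2, d)
via_fft = sketch_via_fft(x, y, h1, s1, h2, s2, d)
max_abs_diff = max(abs(a - b) for a, b in zip(direct, via_fft))
```

The two results agree up to floating-point error because the modulo in the combined hash makes the target a circular convolution of length d, which is what a product of length-d spectra computes.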
@vadimkantorov
Author

Just install PyTorch from the master branch. Version 0.4 probably has FFT support as well.

@ayumiymk

Thanks for the code. I have a question: other implementations, such as the Torch and TensorFlow versions, zero-pad the tensor before feeding it into the FFT, but I don't see any zero-padding in this code.

Thanks very much!

@hj0921

hj0921 commented Mar 19, 2021

hello,

In torch.stack([torch.arange(in_features), rand_h]), in_features is not defined. How can I fix it?

thanks!

@vadimkantorov
Author

Thanks for noting this. Fixed! It should have been in_channels.

@vadimkantorov
Author

Some ways to improve the code: use the new PyTorch fft module and native complex support, and figure out dense x sparse matmul (currently I'm materializing the sparse sketch as a dense matrix).
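
A minimal sketch of what the first of those changes could look like, assuming PyTorch 1.8 or later, where torch.fft.rfft returns native complex tensors. The class name CompactBilinearPoolingFFT and the helper _make_sketch are hypothetical and not part of the gist. The sketch matrices are built dense and stored as buffers here, which is a substitution for the gist's sparse constructor and non-trainable Parameters, so the dense x sparse question is left open.

```python
import torch


class CompactBilinearPoolingFFT(torch.nn.Module):
    """Hypothetical torch.fft variant of the gist's module (illustration only)."""

    def __init__(self, in_channels1, in_channels2, out_channels, sum_pool=True):
        super().__init__()
        self.out_channels = out_channels
        self.sum_pool = sum_pool
        # Buffers instead of Parameter(requires_grad=False): an assumption of this sketch.
        self.register_buffer("tensor_sketch1", self._make_sketch(in_channels1, out_channels))
        self.register_buffer("tensor_sketch2", self._make_sketch(in_channels2, out_channels))

    @staticmethod
    def _make_sketch(in_channels, out_channels):
        # Row i gets a single random sign in a random column, as in the gist.
        rand_h = torch.randint(out_channels, size=(in_channels,))
        rand_s = 2 * torch.randint(2, size=(in_channels,), dtype=torch.float32) - 1
        sketch = torch.zeros(in_channels, out_channels)
        sketch[torch.arange(in_channels), rand_h] = rand_s
        return sketch

    def forward(self, x1, x2):
        # x1: (B, C1, H, W), x2: (B, C2, H, W); channels are moved last for the matmul.
        fft1 = torch.fft.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), dim=-1)
        fft2 = torch.fft.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), dim=-1)
        # Native complex tensors make the manual real/imaginary product unnecessary.
        cbp = torch.fft.irfft(fft1 * fft2, n=self.out_channels, dim=-1)
        cbp = cbp * self.out_channels  # keeps the gist's extra scaling factor
        return cbp.sum(dim=[1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)


torch.manual_seed(0)
pool = CompactBilinearPoolingFFT(in_channels1=4, in_channels2=5, out_channels=8)
x1 = torch.randn(2, 4, 3, 3)
x2 = torch.randn(2, 5, 3, 3)
pooled = pool(x1, x2)  # shape (2, 8) with sum_pool=True
```

Passing n=self.out_channels to irfft matters when out_channels is odd, since the one-sided spectrum alone does not determine the output length.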
