Compact Bilinear Pooling in PyTorch using the new FFT support

import torch

class CompactBilinearPooling(torch.nn.Module):
    def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
        super(CompactBilinearPooling, self).__init__()
        self.output_dim = output_dim
        self.sum_pool = sum_pool

        # Count Sketch projection matrix: a dense (input_dim, output_dim) matrix with a
        # single random sign rand_s[i] placed at random column rand_h[i] of each row i
        generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense()

        # fixed (non-trainable) sketch matrices, one per input feature map
        self.sketch1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False)
        self.sketch2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False)

    def forward(self, x1, x2):
        # sketch each input along the channel dimension, then FFT along the sketch dimension
        fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch1), signal_ndim = 1)
        fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch2), signal_ndim = 1)
        # elementwise complex product: (re1 + i*im1) * (re2 + i*im2)
        fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
        # inverse FFT of the product gives the circular convolution of the two sketches
        cbp = torch.irfft(fft_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim
        return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
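
A minimal usage sketch (the batch size, channel counts, and output dimension below are illustrative, not part of the gist): pooling two conv feature maps of shape (batch, channels, height, width) into one descriptor.

x1 = torch.randn(4, 512, 14, 14)   # e.g. conv features from two CNN branches
x2 = torch.randn(4, 512, 14, 14)
cbp = CompactBilinearPooling(input_dim1 = 512, input_dim2 = 512, output_dim = 8000)
y = cbp(x1, x2)
print(y.shape)  # torch.Size([4, 8000]) with sum_pool = True; (4, 8000, 14, 14) otherwise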
@SkyFlyboy


commented Apr 20, 2018

Thanks for your code! How do I install the new FFT support?

@vadimkantorov

Owner Author

commented Jun 5, 2018

Just install PyTorch from the master branch; even the 0.4 release probably already has FFT support.
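
A quick way to check whether an installed build exposes the ops this gist relies on (torch.rfft and torch.irfft shipped from roughly 0.4 until their removal in PyTorch 1.8):

import torch

print(torch.__version__)
# the gist needs torch.rfft / torch.irfft, available roughly from 0.4 up to 1.7
print(hasattr(torch, 'rfft') and hasattr(torch, 'irfft'))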

@ayumiymk


commented Dec 12, 2018

Thanks for your code. I have a question: in the other implementations, such as the Torch version and the TensorFlow version, there is zero-padding before feeding the tensor into the FFT, but I don't see any zero-padding in this code. Why is that?

Thanks very much!
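
A small self-contained check of why the gist can skip padding, assuming a pre-1.8 PyTorch build that still ships torch.rfft: the elementwise FFT product computes a circular convolution of length output_dim exactly, which is what composing the two count sketches requires, so no explicit zero-padding is needed in this circular formulation.

import torch

d = 8
a, b = torch.randn(d), torch.randn(d)

# direct circular convolution: c[k] = sum_j a[j] * b[(k - j) mod d]
c_direct = torch.stack([sum(a[j] * b[(k - j) % d] for j in range(d)) for k in range(d)])

# FFT route used by the gist: elementwise complex product, then inverse FFT
fa, fb = torch.rfft(a, signal_ndim = 1), torch.rfft(b, signal_ndim = 1)
prod = torch.stack([fa[..., 0] * fb[..., 0] - fa[..., 1] * fb[..., 1],
                    fa[..., 0] * fb[..., 1] + fa[..., 1] * fb[..., 0]], dim = -1)
c_fft = torch.irfft(prod, signal_ndim = 1, signal_sizes = (d,))

print(torch.allclose(c_direct, c_fft, atol = 1e-5))  # True: circular, no padding needed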
