@lantiga
Created February 6, 2018 08:25
Indexed convolutions

A convolution operator over a 1D tensor (B x C x L), where the list of neighbors for each element is provided through an indices tensor of shape L x K, with K the size of the convolution kernel. Each row of indices specifies the indices of the K neighbors of the corresponding element in the input; an index of -1 is treated as zero padding.

Note that the neighbors specified in indices are absolute, not relative: they must be listed explicitly for each element of the output.
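For instance, an ordinary 3-tap 1D convolution with zero padding can be expressed in this scheme by listing, for every output position i, the absolute indices i-1, i, i+1, with -1 at the borders. A sketch of a helper that builds such an indices tensor (hypothetical, not part of the gist):

```python
import torch

def conv1d_indices(length, ksize=3):
    """Build an (L x K) indices tensor that reproduces a standard
    zero-padded 1D convolution; -1 marks padded taps."""
    offsets = torch.arange(ksize) - ksize // 2           # e.g. [-1, 0, 1]
    idx = torch.arange(length).unsqueeze(1) + offsets    # absolute neighbor indices
    idx[(idx < 0) | (idx >= length)] = -1                # out-of-range -> zero padding
    return idx

print(conv1d_indices(5))
```

Each row lists the absolute neighbors of one output element, so arbitrary neighborhood structures (not just sliding windows) fit the same interface.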

One use case is convolutions over non-square lattices, such as images on hexagonal lattices coming from Cherenkov telescopes (http://www.isdc.unige.ch/%7Elyard/FirstLight/FirstLight_slowHD.mov).

Example:

```python
import torch

# a 1D input of 5 elements
input = torch.randn(1, 1, 5)

# this specifies the indices of the neighbors of
# each element of the input (a 3-element kernel here);
# a -1 corresponds to zero padding
indices = torch.ones(5, 3).type(torch.LongTensor)

weight = torch.randn(1, 1, 3)
bias = torch.randn(1)

# proposed API (indexed_conv is not part of torch.nn.functional)
output = torch.nn.functional.indexed_conv(input, indices, weight, bias)
```
```python
import torch
from torch.autograd import Variable


def prepare_mask(indices):
    # replace -1 (padding) entries with index 0 and build a
    # (K x L) mask that zeroes out the padded positions
    padded = indices == -1
    indices[padded] = 0
    mask = torch.FloatTensor([1, 0])
    mask = mask[..., padded.t().long()]
    return indices, mask


def indexed_conv(input, weight, bias, indices, mask):
    nbatch = input.shape[0]
    output_width = indices.shape[0]
    out_chans, in_chans, ksize = weight.shape
    if isinstance(input, Variable):
        mask = Variable(mask)
    # gather the neighbors of each output element (im2col-style),
    # zeroing out the padded taps, then convolve as a matmul
    col = input[..., indices.t()] * mask
    col = col.view(nbatch, -1, output_width)
    weight_col = weight.view(out_chans, -1)
    # bias is reshaped so it broadcasts per output channel
    out = torch.matmul(weight_col, col) + bias.view(-1, 1)
    return out


if __name__ == '__main__':
    # input = torch.randn(1, 2, 5)
    # weight = torch.randn(1, 2, 3)
    # bias = torch.randn(1)
    input = torch.ones(1, 2, 5)
    weight = torch.ones(1, 2, 3)
    bias = torch.zeros(1)
    indices = (5 * torch.rand(4, 3)).long()
    indices[0, 0] = -1
    indices, mask = prepare_mask(indices)
    print(input)
    print(indices)

    out = indexed_conv(input, weight, bias, indices, mask)

    input = Variable(input, requires_grad=True)
    weight = Variable(weight)
    bias = Variable(bias)
    out = indexed_conv(input, weight, bias, indices, mask)
    print(out)

    out.sum().backward()
    print(input.grad)
```
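As a sanity check (my own, not part of the gist): when the indices encode a standard zero-padded kernel, the gather-and-matmul trick used above should reproduce an ordinary `F.conv1d`. A self-contained sketch of that equivalence:

```python
import torch
import torch.nn.functional as F

# absolute taps for a zero-padded 3-tap conv over length 5
L, K = 5, 3
idx = torch.arange(L).unsqueeze(1) + torch.arange(K) - K // 2   # (L x K)
valid = (idx >= 0) & (idx < L)                                  # False where padded
idx = idx.clamp(0, L - 1)                                       # in-range stand-in

x = torch.randn(1, 2, L)
w = torch.randn(4, 2, K)
b = torch.randn(4)

# gather neighbors, zero out padded taps, convolve as a matmul
col = x[..., idx] * valid.float()                  # (1, 2, L, K)
col = col.permute(0, 1, 3, 2).reshape(1, 2 * K, L)
out = torch.matmul(w.view(4, -1), col) + b.view(1, 4, 1)

ref = F.conv1d(x, w, b, padding=1)
print(torch.allclose(out, ref, atol=1e-5))
```

The permute/reshape keeps the (channel, tap) flattening order consistent between the gathered columns and the flattened weight, which is what makes the single matmul equivalent to the convolution.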