@andreaskoepf
Last active April 17, 2016 14:33
Alternative implementation of torch.chunk() that always returns exactly nChunk chunks.
--[[
torch.chunk2()
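Returns a table with exactly nChunk entries, even when the tensor has fewer than nChunk elements along dimension dim (the surplus entries are empty tensors).
Behaviour of the original torch.chunk() function, which uses a fixed chunk size of ceil(n/nChunk) (n = tensor:size(dim)) and can therefore return fewer than nChunk chunks: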
th> torch.rand(11):chunk(5)
{
1 : DoubleTensor - size: 3
2 : DoubleTensor - size: 3
3 : DoubleTensor - size: 3
4 : DoubleTensor - size: 2
}
th> torch.rand(10):chunk(7)
{
1 : DoubleTensor - size: 2
2 : DoubleTensor - size: 2
3 : DoubleTensor - size: 2
4 : DoubleTensor - size: 2
5 : DoubleTensor - size: 2
}
th> torch.rand(5):chunk(10)
{
1 : DoubleTensor - size: 1
2 : DoubleTensor - size: 1
3 : DoubleTensor - size: 1
4 : DoubleTensor - size: 1
5 : DoubleTensor - size: 1
}
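The chunk counts above follow from that fixed-size rule. As a sketch that models only the sizes (no torch required), the hypothetical helper `chunkSizes` below reproduces them in plain Lua:

```lua
-- Sizes produced by a fixed-size split: every chunk has ceil(n/nChunk)
-- elements except possibly the last one, which receives the remainder.
-- chunkSizes is a hypothetical helper, not part of torch.
local function chunkSizes(n, nChunk)
   local splitSize = math.ceil(n / nChunk)
   local sizes = {}
   local lo = 0
   while lo < n do
      local size = math.min(splitSize, n - lo)
      sizes[#sizes + 1] = size
      lo = lo + size
   end
   return sizes
end

print(table.concat(chunkSizes(11, 5), ' '))  --> 3 3 3 2
print(table.concat(chunkSizes(10, 7), ' '))  --> 2 2 2 2 2
print(table.concat(chunkSizes(5, 10), ' '))  --> 1 1 1 1 1
```

With 11 elements and 5 requested chunks the fixed size is 3, so only four chunks fit.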
torch.chunk2() behaviour (chunk sizes differ by at most one; when n < nChunk the extra entries are empty tensors):
th> torch.rand(11):chunk2(5)
{
1 : DoubleTensor - size: 3
2 : DoubleTensor - size: 2
3 : DoubleTensor - size: 2
4 : DoubleTensor - size: 2
5 : DoubleTensor - size: 2
}
th> torch.rand(10):chunk2(7)
{
1 : DoubleTensor - size: 2
2 : DoubleTensor - size: 1
3 : DoubleTensor - size: 2
4 : DoubleTensor - size: 1
5 : DoubleTensor - size: 2
6 : DoubleTensor - size: 1
7 : DoubleTensor - size: 1
}
th> torch.rand(5):chunk2(10)
{
1 : DoubleTensor - size: 1
2 : DoubleTensor - empty
3 : DoubleTensor - size: 1
4 : DoubleTensor - empty
5 : DoubleTensor - size: 1
6 : DoubleTensor - empty
7 : DoubleTensor - size: 1
8 : DoubleTensor - empty
9 : DoubleTensor - size: 1
10 : DoubleTensor - empty
}
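chunk2() instead lets chunk i end at index min(ceil(i*n/nChunk), n), so every chunk has either floor(n/nChunk) or ceil(n/nChunk) elements, and when n < nChunk some consecutive boundaries coincide and produce empty chunks. A minimal plain-Lua sketch of just that arithmetic, using a hypothetical helper `chunk2Sizes` (not a torch function) in which a size of 0 stands for an empty tensor:

```lua
-- Chunk i spans elements lo+1 .. hi with hi = min(ceil(i*n/nChunk), n),
-- mirroring the loop in chunk2(); hi never decreases, so sizes are >= 0.
-- chunk2Sizes is a hypothetical helper, not part of torch.
local function chunk2Sizes(n, nChunk)
   local sizes = {}
   local lo = 0
   for i = 1, nChunk do
      local hi = math.min(math.ceil(i * n / nChunk), n)
      sizes[i] = hi - lo
      lo = hi
   end
   return sizes
end

print(table.concat(chunk2Sizes(11, 5), ' '))  --> 3 2 2 2 2
print(table.concat(chunk2Sizes(10, 7), ' '))  --> 2 1 2 1 2 1 1
print(table.concat(chunk2Sizes(5, 10), ' '))  --> 1 0 1 0 1 0 1 0 1 0
```

The printed sizes match the three chunk2() examples above.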
]]
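require 'torch'

-- This function mirrors the definition of torch.chunk() in torch7's Tensor.lua,
-- where `Tensor` is a local table of extra tensor methods. To run it standalone,
-- create that table here and register chunk2 on the tensor types below.
local Tensor = {}

function Tensor.chunk2(result, tensor, nChunk, dim)
   -- Accept both torch.chunk2([result,] tensor, nChunk [, dim]) and
   -- tensor:chunk2(nChunk [, dim]): shift the arguments when no result table is given.
   if torch.type(result) ~= 'table' then
      dim = nChunk
      nChunk = tensor
      tensor = result
      result = {}
   end
   dim = dim or 1
   local n = tensor:size(dim)
   local lo = 0
   local z -- empty tensor used when nChunk > tensor:size(dim); shared by all empty entries
   for i = 1, nChunk do -- (loop is skipped if nChunk < 1, returning an empty table)
      -- chunk i ends at index ceil(i*n/nChunk), so chunk sizes differ by at most one
      local hi = math.min(math.ceil(i * n / nChunk), n)
      if lo < hi then
         table.insert(result, tensor:narrow(dim, lo + 1, hi - lo))
         lo = hi
      else
         z = z or torch.Tensor():typeAs(tensor)
         table.insert(result, z)
      end
   end
   return result
end
torch.chunk2 = Tensor.chunk2

-- Expose chunk2 as a method (tensor:chunk2(nChunk, dim)) on the CPU tensor types.
for _, name in ipairs({'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}) do
   rawset(torch.getmetatable('torch.' .. name .. 'Tensor'), 'chunk2', Tensor.chunk2)
end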