Last active
April 17, 2016 14:33
-
-
Save andreaskoepf/0412f4a5bfe0531f226071b34802376f to your computer and use it in GitHub Desktop.
Alternative implementation of torch.chunk() that always returns exactly nChunk chunks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[
torch.chunk2([result,] tensor, nChunk [, dim])

Splits `tensor` along dimension `dim` (default 1) into a table with exactly
`nChunk` entries, even when the tensor has fewer than `nChunk` entries in that
dimension. With n = tensor:size(dim), chunk i holds the elements
ceil((i-1)*n/nChunk)+1 .. ceil(i*n/nChunk), so chunk sizes differ by at most
one. Non-empty chunks are created with tensor:narrow(); a chunk that receives
no elements is a new, separate empty tensor of the same type as `tensor`.

If `result` is a table, it is cleared and then filled in place and returned.
If nChunk < 1 the returned table is empty.

Behaviour of the original torch.chunk() function:
th> torch.rand(11):chunk(5)
{
  1 : DoubleTensor - size: 3
  2 : DoubleTensor - size: 3
  3 : DoubleTensor - size: 3
  4 : DoubleTensor - size: 2
}
th> torch.rand(10):chunk(7)
{
  1 : DoubleTensor - size: 2
  2 : DoubleTensor - size: 2
  3 : DoubleTensor - size: 2
  4 : DoubleTensor - size: 2
  5 : DoubleTensor - size: 2
}
th> torch.rand(5):chunk(10)
{
  1 : DoubleTensor - size: 1
  2 : DoubleTensor - size: 1
  3 : DoubleTensor - size: 1
  4 : DoubleTensor - size: 1
  5 : DoubleTensor - size: 1
}
torch.chunk2() behaviour:
th> torch.rand(11):chunk2(5)
{
  1 : DoubleTensor - size: 3
  2 : DoubleTensor - size: 2
  3 : DoubleTensor - size: 2
  4 : DoubleTensor - size: 2
  5 : DoubleTensor - size: 2
}
th> torch.rand(10):chunk2(7)
{
  1 : DoubleTensor - size: 2
  2 : DoubleTensor - size: 1
  3 : DoubleTensor - size: 2
  4 : DoubleTensor - size: 1
  5 : DoubleTensor - size: 2
  6 : DoubleTensor - size: 1
  7 : DoubleTensor - size: 1
}
th> torch.rand(5):chunk2(10)
{
  1 : DoubleTensor - size: 1
  2 : DoubleTensor - empty
  3 : DoubleTensor - size: 1
  4 : DoubleTensor - empty
  5 : DoubleTensor - size: 1
  6 : DoubleTensor - empty
  7 : DoubleTensor - size: 1
  8 : DoubleTensor - empty
  9 : DoubleTensor - size: 1
  10 : DoubleTensor - empty
}
]]
function Tensor.chunk2(result, tensor, nChunk, dim)
   if torch.type(result) ~= 'table' then
      -- called as chunk2(tensor, nChunk [, dim]): shift the arguments left
      dim = nChunk
      nChunk = tensor
      tensor = result
      result = {}
   else
      -- reused output table: drop stale entries so the result holds exactly
      -- nChunk chunks instead of having the new ones appended after old ones
      for k in pairs(result) do
         result[k] = nil
      end
   end
   dim = dim or 1
   local n = tensor:size(dim)
   local lo = 0 -- number of elements already assigned to earlier chunks
   for i = 1, nChunk do -- loop is skipped if nChunk < 1 (empty table returned)
      local hi = math.min(math.ceil(i * n / nChunk), n)
      if lo < hi then
         result[i] = tensor:narrow(dim, lo + 1, hi - lo)
         lo = hi
      else
         -- no elements left for this slot; give every empty slot its own
         -- tensor so that resizing one of them cannot affect the others
         result[i] = tensor.new()
      end
   end
   return result
end
torch.chunk2 = Tensor.chunk2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment