Last active: November 10, 2016 06:32
Gist: vadimkantorov/83b06d470c9bdf67a89c15199d21e90d
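DynamicView: a Torch nn module that reshapes (views) its input tensor at runtime. On every forward pass it computes the target size from a user-provided function of the input's size. nn.View's target size is set at construction, but this one can depend on dimensions known only at runtime, such as the batch size. The commented-out code below uses it to compute total variation.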
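To make the size-function contract concrete, here is a plain-Lua sketch that needs no Torch. It applies the two size functions from the TotalVariation example to sample shapes and shows how a single -1 entry gets inferred. `resolveViewSize`, `toSingleChannel` and `toPerSample` are hypothetical names introduced for illustration. `resolveViewSize` only mimics nn.View's -1 inference rule and is not part of nn.

```lua
-- Plain-Lua sketch; no Torch needed. resolveViewSize is a hypothetical helper that
-- mimics how nn.View fills in a single -1 dimension; it is not part of nn.
local function resolveViewSize(inputSize, viewSize)
  local total = 1
  for _, d in ipairs(inputSize) do total = total * d end

  local known, inferAt, out = 1, nil, {}
  for i, d in ipairs(viewSize) do
    out[i] = d
    if d == -1 then
      assert(inferAt == nil, 'only one dimension can be -1')
      inferAt = i
    else
      known = known * d
    end
  end

  if inferAt then
    assert(total % known == 0, 'view size does not divide the input')
    out[inferAt] = math.floor(total / known)
  else
    assert(known == total, 'element count mismatch')
  end
  return out
end

-- The two size functions used by the TotalVariation example (hypothetical names).
-- batchSize is shared: the first function records it, the second reads it.
local batchSize
local function toSingleChannel(sz) batchSize = sz[1]; return {sz[1] * sz[2], 1, sz[3], sz[4]} end
local function toPerSample(sz) return {batchSize, -1} end

-- Example: a batch of 4 three-channel 32x32 images.
local first = resolveViewSize({4, 3, 32, 32}, toSingleChannel({4, 3, 32, 32}))
print(table.concat(first, 'x'))   -- 12x1x32x32
-- A 2x2 convolution with 2 output planes, stride 1 and no padding gives 12x2x31x31.
local convOut = {12, 2, 31, 31}
local second = resolveViewSize(convOut, toPerSample(convOut))
print(table.concat(second, 'x'))  -- 4x5766
```

Because `batchSize` is an upvalue shared by two closures, the second view is only correct if the first one has already run in the same forward pass. nn.Sequential guarantees that ordering.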
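require 'nn'
local unpack = unpack or table.unpack -- global in Lua 5.1/LuaJIT, table.unpack in Lua 5.2+
-- getSizeTable(sz) maps input:size() to a table of target sizes (one -1 is inferred, as in nn.View)
function DynamicView(getSizeTable)
  local module = nn.View(-1)
  module.updateOutput = function(self, input)
    -- recompute the size on every forward pass, then defer to nn.View's own updateOutput
    return nn.View.updateOutput(self:resetSize(unpack(getSizeTable(input:size()))), input)
  end
  return module
end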
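-- Example: total variation of a 4D (batch x channels x height x width) tensor.
-- Each channel is treated as a separate one-channel image and convolved with two fixed
-- 2x2 difference kernels; the absolute responses are averaged per batch element.
-- function TotalVariation() -- accepts 4D tensors
--   local sk = torch.Tensor(2, 1, 2, 2)
--   sk[1][1]:copy(torch.Tensor({{-1, 1}, {0, 0}})) -- horizontal: x[i][j+1] - x[i][j]
--   sk[2][1]:copy(torch.Tensor({{-1, 0}, {1, 0}})) -- vertical: x[i+1][j] - x[i][j]
--   local conv = cudnn.SpatialConvolution(1, 2, 2, 2, 1, 1, 0, 0):noBias()
--   conv.weight:copy(sk) -- still a trainable parameter; exclude it from optimization to keep the kernels fixed
--   local batchSize -- recorded by the first DynamicView, read by the second
--   return nn.Sequential()
--     :add(DynamicView(function(sz) batchSize = sz[1]; return {sz[1] * sz[2], 1, sz[3], sz[4]} end))
--     :add(conv)
--     :add(nn.Abs())
--     :add(DynamicView(function(sz) return {batchSize, -1} end))
--     :add(nn.Mean(2))
-- end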
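As a cross-check of what the commented TotalVariation pipeline computes, the plain-Lua sketch below evaluates the same quantity for one batch element stored as nested tables `sample[c][i][j]`. It assumes the convolution is a cross-correlation with stride 1 and no padding, as Torch's spatial convolutions are. Under that assumption, each kernel is evaluated only at the (H-1) x (W-1) valid positions, so the last row gets no horizontal difference and the last column gets no vertical one. The result is a normalized L1 (anisotropic) total variation. `meanAbsDifference` is a hypothetical name, not part of nn or cudnn.

```lua
-- Reference sketch, not Torch code: mean absolute finite difference for one sample
-- given as sample[c][i][j] (C x H x W nested tables).
-- Kernel {{-1, 1}, {0, 0}} yields x[i][j+1] - x[i][j] (horizontal);
-- kernel {{-1, 0}, {1, 0}} yields x[i+1][j] - x[i][j] (vertical).
-- Both are taken at the (H-1) x (W-1) valid positions of a 2x2, stride-1, unpadded
-- convolution, and all responses are averaged, as nn.Mean(2) does on the flattened view.
local function meanAbsDifference(sample)
  local sum, count = 0, 0
  for c = 1, #sample do
    local x = sample[c]
    local H, W = #x, #x[1]
    for i = 1, H - 1 do
      for j = 1, W - 1 do
        sum = sum + math.abs(x[i][j + 1] - x[i][j])  -- horizontal response
                  + math.abs(x[i + 1][j] - x[i][j])  -- vertical response
        count = count + 2
      end
    end
  end
  return sum / count
end

-- One channel, 2x2 image: one valid position, (|1 - 0| + |2 - 0|) / 2 responses.
print(meanAbsDifference({{{0, 1}, {2, 3}}}))  -- 1.5
-- A constant image has zero total variation.
print(meanAbsDifference({{{5, 5, 5}, {5, 5, 5}, {5, 5, 5}}}))
```

Averaging over all channels at once gives the same value as nn.Mean(2) on the `{batchSize, -1}` view, because every channel contributes the same number of responses.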