Skip to content

Instantly share code, notes, and snippets.

@albanD
Created August 21, 2015 13:13
Show Gist options
  • Save albanD/954021a4be9e1ccab753 to your computer and use it in GitHub Desktop.
Save albanD/954021a4be9e1ccab753 to your computer and use it in GitHub Desktop.
require 'nn'
require 'stn'
------
-- Configuration: placeholders to adapt to your own setup.
-- Path to the serialized localization network (loaded with torch.load).
local LOCNET_PATH = 'your_locnet.t7'
-- Height and width of the sampling grid, which is also the spatial size of
-- the module's output. Must be real numbers (previously an undefined global
-- `input_size` was passed here). 32x32 is only a placeholder.
local input_height, input_width = 32, 32
-- Restrict the possible transformation. These are passed straight to
-- nn.AffineTransformMatrixGenerator; `false` behaves like the previous
-- (undefined, hence nil) values under Lua truthiness.
-- NOTE(review): per the stn docs, all-false selects the fully parametrised
-- affine transform; confirm against your stn version.
local use_rot, use_sca, use_tra = false, false, false

------
--- Build a spatial transformer module.
-- The returned module takes a tensor in bdhw layout (the layout used by the
-- default torch modules, per the original comments) and returns bdhw.
-- @param locnet  localization network; it receives the raw bdhw input and its
--   output feeds nn.AffineTransformMatrixGenerator (assumed to produce the
--   parameter count that module expects; confirm).
-- @tparam number height  grid/output height
-- @tparam number width   grid/output width
-- @param use_rot  forwarded to nn.AffineTransformMatrixGenerator
-- @param use_sca  forwarded to nn.AffineTransformMatrixGenerator
-- @param use_tra  forwarded to nn.AffineTransformMatrixGenerator
-- @return nn.Sequential wrapping the whole transformer
local function build_spatial_transformer(locnet, height, width,
                                         use_rot, use_sca, use_tra)
  -- Image branch: leaves the data untouched and only permutes bdhw -> bhwd.
  -- Swap (3,4): b,d,w,h; then swap (2,4): b,h,w,d.
  local image_branch = nn.Transpose({3, 4}, {2, 4})

  -- Grid branch: predicts the transform parameters, turns them into an
  -- affine matrix, then into a sampling grid of size height x width.
  local grid_branch = nn.Sequential()
  grid_branch:add(locnet)
  grid_branch:add(nn.AffineTransformMatrixGenerator(use_rot, use_sca, use_tra))
  grid_branch:add(nn.AffineGridGeneratorBHWD(height, width))

  -- ConcatTable feeds the same input to both branches and emits
  -- {image, grid}, which is what the bilinear sampler consumes.
  local branches = nn.ConcatTable()
  branches:add(image_branch)
  branches:add(grid_branch)

  local st = nn.Sequential()
  st:add(branches)
  st:add(nn.BilinearSamplerBHWD())
  -- Inverse permutation back to bdhw.
  -- Swap (2,4): b,d,w,h; then swap (3,4): b,d,h,w.
  st:add(nn.Transpose({2, 4}, {3, 4}))
  return st
end

------
-- Prepare your localization network
local localization_network = torch.load(LOCNET_PATH)

------
-- Wrap the st in one module
local st_module = build_spatial_transformer(localization_network,
                                            input_height, input_width,
                                            use_rot, use_sca, use_tra)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment