Skip to content

Instantly share code, notes, and snippets.

@szagoruyko
Created June 11, 2015 13:43
Show Gist options
  • Save szagoruyko/6da251c360340cd3c48a to your computer and use it in GitHub Desktop.
Save szagoruyko/6da251c360340cd3c48a to your computer and use it in GitHub Desktop.
BidirectionalSequencer.lua
-- Register the class 'nn.BidirectionalSequencer' with torch, deriving from 'nn.Container'.
-- torch.class returns the new class table and the parent class table; the latter is
-- kept in `parent` so the constructor can call the base-class initializer.
local BidirectionalSequencer, parent = torch.class('nn.BidirectionalSequencer', 'nn.Container')
--- Constructor.
-- Wraps each of the two given modules in its own nn.Sequencer and stores the
-- wrappers as this container's first and second child modules.
-- @param module_forward   module wrapped by the first sequencer (fed the sequence as given)
-- @param module_backward  module wrapped by the second sequencer (fed the reversed sequence)
-- @param nOutputSize      accepted for interface compatibility; not used by this constructor
function BidirectionalSequencer:__init(module_forward, module_backward, nOutputSize)
   parent.__init(self)

   -- Keep direct handles on the raw (unwrapped) modules.
   self.module_forward, self.module_backward = module_forward, module_backward

   -- Child 1 handles the forward direction, child 2 the backward direction.
   local forward_sequencer = nn.Sequencer(module_forward)
   local backward_sequencer = nn.Sequencer(module_backward)
   self.modules[1] = forward_sequencer
   self.modules[2] = backward_sequencer

   -- Per-step result tables, filled lazily by updateOutput / updateGradInput.
   self.output, self.gradInput = {}, {}
end
--- Forward pass over a sequence.
-- Runs the first sequencer on `input` and the second sequencer on `input` in
-- reversed order, then, for every step i, concatenates the two results along
-- dimension 2: columns [1, of_ndim] hold the forward sequencer's i-th output and
-- columns [of_ndim+1, of_ndim+ob_ndim] hold the backward sequencer's i-th output.
-- Note that the backward half of output[i] is the i-th step of the *reversed*
-- sequence (it is not re-reversed here); updateGradInput/accGradParameters
-- slice gradOutput with the same convention.
-- Assumes every per-step output is 2D (batch x features), since sizes are read
-- from dimensions 1 and 2 — TODO confirm for non-batched use.
-- @param input  table (sequence) of tensors, one per time step
-- @return self.output, a table with exactly #input tensors
function BidirectionalSequencer:updateOutput(input)
   local nStep = #input

   -- Build the time-reversed view of the sequence for the backward sequencer.
   local reverse_input = {}
   for i, v in ipairs(input) do
      reverse_input[nStep - i + 1] = v
   end

   local of = self.modules[1]:updateOutput(input)
   local ob = self.modules[2]:updateOutput(reverse_input)
   self.of = of
   self.ob = ob

   -- self.output is reused across calls: drop steps left over from a previous,
   -- longer sequence so the returned table never contains stale tensors.
   for i = #self.output, nStep + 1, -1 do
      self.output[i] = nil
   end

   -- Nothing to concatenate for an empty sequence (of[1] would be nil).
   if nStep == 0 then
      return self.output
   end

   local bs = of[1]:size(1)
   local of_ndim = of[1]:size(2)
   local ob_ndim = ob[1]:size(2)

   for i, v in ipairs(input) do
      local out = self.output[i]
      if not out then
         -- Allocate a buffer of the same tensor type as the input step.
         out = v.new()
         self.output[i] = out
      end
      out:resize(bs, of_ndim + ob_ndim)
      out:narrow(2, 1, of_ndim):copy(of[i])
      out:narrow(2, of_ndim + 1, ob_ndim):copy(ob[i])
   end

   return self.output
end
--- Backward pass with respect to the input sequence.
-- Splits every gradOutput[i] along dimension 2 into the part belonging to the
-- forward sequencer (first of_ndim columns) and the part belonging to the
-- backward sequencer (next ob_ndim columns), back-propagates each part through
-- its sequencer, and sums the two resulting input gradients per step. The
-- backward sequencer saw the sequence reversed, so its gradient for original
-- step i is found at position #input - i + 1.
-- Relies on the child sequencers' `output` from the preceding updateOutput call
-- to determine the two feature widths.
-- @param input       table (sequence) of tensors, as passed to updateOutput
-- @param gradOutput  table of tensors, one per step, laid out like self.output
-- @return self.gradInput, a table of tensors sized like the corresponding input steps
function BidirectionalSequencer:updateGradInput(input, gradOutput)
   local nStep = #input

   local reverse_input = {}
   for i, v in ipairs(input) do
      reverse_input[nStep - i + 1] = v
   end

   local of_ndim = self.modules[1].output[1]:size(2)
   local ob_ndim = self.modules[2].output[1]:size(2)

   local forward_gradOutput = {}
   local backward_gradOutput = {}
   for i, v in ipairs(gradOutput) do
      forward_gradOutput[i] = v:narrow(2, 1, of_ndim)
      -- The backward slice is ob_ndim wide (not of_ndim): the two modules may
      -- have different output sizes.
      backward_gradOutput[i] = v:narrow(2, of_ndim + 1, ob_ndim)
   end

   local forward_gradInput = self.modules[1]:updateGradInput(input, forward_gradOutput)
   local backward_gradInput = self.modules[2]:updateGradInput(reverse_input, backward_gradOutput)

   -- self.gradInput is reused across calls: drop steps left over from a longer
   -- previous sequence.
   for i = #self.gradInput, #forward_gradInput + 1, -1 do
      self.gradInput[i] = nil
   end

   for i, v in ipairs(forward_gradInput) do
      if not self.gradInput[i] then
         self.gradInput[i] = input[1].new()
      end
      local gi = self.gradInput[i]
      gi:resizeAs(input[i])
      gi:copy(v)
      -- Undo the time reversal applied to the backward sequencer's input.
      gi:add(backward_gradInput[nStep - i + 1])
   end

   return self.gradInput
end
--- Accumulates parameter gradients in both child sequencers.
-- Splits every gradOutput[i] along dimension 2 into the forward part (first
-- of_ndim columns) and the backward part (next ob_ndim columns) and forwards
-- each part, together with the matching input ordering, to the corresponding
-- sequencer's accGradParameters.
-- Relies on the child sequencers' `output` from the preceding updateOutput call
-- to determine the two feature widths.
-- @param input       table (sequence) of tensors, as passed to updateOutput
-- @param gradOutput  table of tensors, one per step, laid out like self.output
-- @param scale       passed through unchanged to both child sequencers
function BidirectionalSequencer:accGradParameters(input, gradOutput, scale)
   local nStep = #input

   -- The backward sequencer consumed the sequence in reversed order.
   local reverse_input = {}
   for i, v in ipairs(input) do
      reverse_input[nStep - i + 1] = v
   end

   local of_ndim = self.modules[1].output[1]:size(2)
   local ob_ndim = self.modules[2].output[1]:size(2)

   local forward_gradOutput = {}
   local backward_gradOutput = {}
   for i, v in ipairs(gradOutput) do
      forward_gradOutput[i] = v:narrow(2, 1, of_ndim)
      -- The backward slice is ob_ndim wide (not of_ndim): the two modules may
      -- have different output sizes.
      backward_gradOutput[i] = v:narrow(2, of_ndim + 1, ob_ndim)
   end

   self.modules[1]:accGradParameters(input, forward_gradOutput, scale)
   self.modules[2]:accGradParameters(reverse_input, backward_gradOutput, scale)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment