Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Efficient LSTM cell in Torch
--[[
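Efficient LSTM in Torch using the nngraph library. This code was optimized
by Justin Johnson (@jcjohnson) based on the trick of batching the LSTM
GEMMs: the input and hidden projections each compute all four gate
pre-activations in a single matrix multiply, as also seen in my efficient
Python LSTM gist.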
--]]
--- Build a single LSTM time step as an nngraph module.
-- Each of the two linear projections emits the pre-activations for all four
-- gates at once (width 4 * rnn_size), laid out as
--   [ input gate | forget gate | output gate | candidate ]
-- so each projection is one matrix multiply.
-- @tparam number input_size width of the input x
-- @tparam number rnn_size width of the cell state and hidden state
-- @return an nn.gModule taking {x, prev_c, prev_h} and producing
--   {next_c, next_h}
function LSTM.fast_lstm(input_size, rnn_size)
  -- Graph entry points, in the order the returned gModule expects them.
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- Take `count` consecutive rnn_size-wide blocks, starting at block `first`
  -- (1-based), out of a concatenated pre-activation node.
  -- NOTE(review): slicing is along dimension 2, which presumably assumes
  -- batched (batch x features) inputs -- confirm against callers.
  local function slice(node, first, count)
    return nn.Narrow(2, (first - 1) * rnn_size + 1, count * rnn_size)(node)
  end

  -- Two projections, created input-side first, then summed into one node.
  local from_input = nn.Linear(input_size, 4 * rnn_size)(x)
  local from_hidden = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local preact = nn.CAddTable()({from_input, from_hidden})

  -- The first three blocks share one sigmoid; the fourth goes through tanh.
  local gates = nn.Sigmoid()(slice(preact, 1, 3))
  local input_gate = slice(gates, 1, 1)
  local forget_gate = slice(gates, 2, 1)
  local output_gate = slice(gates, 3, 1)
  local candidate = nn.Tanh()(slice(preact, 4, 1))

  -- next_c = forget_gate * prev_c + input_gate * candidate
  local retained = nn.CMulTable()({forget_gate, prev_c})
  local written = nn.CMulTable()({input_gate, candidate})
  local next_c = nn.CAddTable()({retained, written})

  -- next_h = output_gate * tanh(next_c)
  local next_h = nn.CMulTable()({output_gate, nn.Tanh()(next_c)})

  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.