Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Efficient LSTM cell in Torch
Efficient LSTM in Torch using the nngraph library. This code was optimized
by Justin Johnson (@jcjohnson) using the trick of batching up the
LSTM GEMMs, as also seen in my efficient Python LSTM gist.
--- Build a single LSTM cell as an nngraph module.
--
-- The four gate pre-activations are computed with one Linear layer on the
-- input and one on the previous hidden state (each producing
-- `4 * rnn_size` columns) instead of eight separate Linear layers. The
-- summed result is then sliced along dimension 2 into:
--
--   columns [1,              rnn_size]      input gate   (sigmoid)
--   columns [rnn_size + 1,   2 * rnn_size]  forget gate  (sigmoid)
--   columns [2*rnn_size + 1, 3 * rnn_size]  output gate  (sigmoid)
--   columns [3*rnn_size + 1, 4 * rnn_size]  candidate    (tanh)
--
-- Because the slicing uses `nn.Narrow(2, ...)`, the graph presumably
-- expects batched 2-D inputs (batch x features) -- TODO confirm with callers.
--
-- NOTE(review): `LSTM` must already exist as a table in the enclosing
-- module; its declaration is not visible in this excerpt.
--
-- @tparam number input_size  width of the input `x`
-- @tparam number rnn_size    width of the cell and hidden state
-- @return nn.gModule taking `{x, prev_c, prev_h}` and producing
--         `{next_c, next_h}`
function LSTM.fast_lstm(input_size, rnn_size)
  -- Graph inputs.
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- Batched projections: all four gates in one matrix multiply each.
  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local all_input_sums = nn.CAddTable()({i2h, h2h})

  -- The first 3 * rnn_size columns all use a sigmoid, so apply it once to
  -- the whole chunk and slice the individual gates out afterwards.
  local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
  sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
  local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)

  -- The last rnn_size columns are the candidate cell update (tanh).
  local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
  in_transform = nn.Tanh()(in_transform)

  -- next_c = forget_gate * prev_c + in_gate * in_transform
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  -- next_h = out_gate * tanh(next_c)
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.