Skip to content

Instantly share code, notes, and snippets.

@skaae
Created April 30, 2015 13:45
Show Gist options
  • Save skaae/ea4320e17379d408e693 to your computer and use it in GitHub Desktop.
Save skaae/ea4320e17379d408e693 to your computer and use it in GitHub Desktop.
--[[
LSTM cell. Modified from
https://github.com/oxford-cs-ml-2015/practical6/blob/master/LSTM.lua
--]]
local LSTM = {}

--- Build the computation graph for one timestep of one LSTM layer.
-- All four gate pre-activations are produced by a single pair of Linear
-- layers (input-to-hidden and hidden-to-hidden), summed, and then split
-- into the individual gates.
--
-- The returned gModule takes the inputs {x, prev_c, prev_h} and produces
-- the outputs {next_c, next_h}, in that order.
--
-- @tparam table opt options table; `opt.rnn_size` (required) is the size
--   of the hidden state h and the cell state c
-- @tparam[opt] number input_size size of the input x; defaults to
--   `opt.rnn_size`, which is what the original implementation assumed
-- @return nn.gModule implementing one LSTM step
function LSTM.lstm(opt, input_size)
    -- nn.gModule and the node-call syntax (nn.Identity()()) come from
    -- nngraph; require caches the module, so repeated calls are cheap.
    require('nngraph')

    assert(type(opt) == 'table' and opt.rnn_size ~= nil,
           'LSTM.lstm: opt.rnn_size is required')
    local rnn_size = opt.rnn_size
    input_size = input_size or rnn_size

    local x = nn.Identity()()
    local prev_c = nn.Identity()()
    local prev_h = nn.Identity()()

    -- Calculate all four gates in one go: 4*rnn_size pre-activations.
    local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
    local gates = nn.CAddTable()({i2h, h2h})

    -- View the pre-activations as (n_gates, rnn_size): a batched input
    -- becomes (batch_size, 4, rnn_size), an unbatched one (4, rnn_size).
    local reshaped_gates = nn.Reshape(4, rnn_size)(gates)
    -- Slice along the gate dimension. With nInputDims = 2, SplitTable
    -- uses dimension 1 for an unbatched (4, rnn_size) tensor and
    -- dimension 2 for a batched (batch_size, 4, rnn_size) tensor; a bare
    -- SplitTable(2) would split an unbatched input across hidden units.
    local sliced_gates = nn.SplitTable(1, 2)(reshaped_gates)

    -- Use SelectTable to fetch each gate and apply its nonlinearity.
    local in_gate      = nn.Sigmoid()(nn.SelectTable(1)(sliced_gates))
    local in_transform = nn.Tanh()(nn.SelectTable(2)(sliced_gates))
    local forget_gate  = nn.Sigmoid()(nn.SelectTable(3)(sliced_gates))
    local out_gate     = nn.Sigmoid()(nn.SelectTable(4)(sliced_gates))

    -- next_c = forget_gate * prev_c + in_gate * in_transform
    local next_c = nn.CAddTable()({
        nn.CMulTable()({forget_gate, prev_c}),
        nn.CMulTable()({in_gate, in_transform}),
    })
    -- next_h = out_gate * tanh(next_c)
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM
@AjayTalati
Copy link

Hi Soren,

Thanks a lot for posting this 😄

Just in case you have not seen it yet, Sergey Zagoruyko has implemented your code for the Penn Treebank:

wojzaremba/lstm#12

Regards,

Aj

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment