Skip to content

Instantly share code, notes, and snippets.

@AjayTalati
Forked from skaae/LSTM.lua
Last active August 29, 2015 14:20
Show Gist options
  • Save AjayTalati/7b6c8012d5d45e0cc9de to your computer and use it in GitHub Desktop.
Save AjayTalati/7b6c8012d5d45e0cc9de to your computer and use it in GitHub Desktop.
--[[
LSTM cell. Modified from
https://github.com/oxford-cs-ml-2015/practical6/blob/master/LSTM.lua
--]]
-- Make the module self-contained: nn.gModule and the `module(...)(node)`
-- call syntax used below are provided by nngraph, which builds on nn.
-- Both requires are no-ops if the caller has already loaded the packages.
require 'nn'
require 'nngraph'

local LSTM = {}

--- Create one timestep of one LSTM layer as an nngraph gModule.
--
-- The returned module takes the table `{x, prev_c, prev_h}` and returns
-- `{next_c, next_h}`, computing:
--
--   gates      = Linear_x(x) + Linear_h(prev_h)   (all four gates at once)
--   i, g, f, o = slices 1..4 of gates             (in, in_transform, forget, out)
--   next_c     = sigmoid(f) .* prev_c + sigmoid(i) .* tanh(g)
--   next_h     = sigmoid(o) .* tanh(next_c)
--
-- Minibatched (batch_size x dim) and single-sample (dim) inputs are both
-- accepted.
--
-- @tparam table opt options table
-- @tparam number opt.rnn_size size of the hidden state and cell state
-- @tparam[opt] number opt.input_size size of x; defaults to opt.rnn_size
--   (the original behaviour, where x is assumed to already be rnn_size wide)
-- @return an nn.gModule mapping {x, prev_c, prev_h} to {next_c, next_h}
function LSTM.lstm(opt)
  assert(type(opt) == 'table', 'LSTM.lstm: opt must be a table')
  local rnn_size = opt.rnn_size
  assert(type(rnn_size) == 'number' and rnn_size > 0,
    'LSTM.lstm: opt.rnn_size must be a positive number')
  local input_size = opt.input_size or rnn_size
  assert(type(input_size) == 'number' and input_size > 0,
    'LSTM.lstm: opt.input_size must be a positive number')

  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- Calculate all four gates in one go: one affine map from x, one from prev_h.
  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local gates = nn.CAddTable()({i2h, h2h})

  -- Reshape to (n_gates, rnn_size), or (batch_size, n_gates, rnn_size) for a
  -- minibatch, then slice along the n_gates dimension. With nInputDims = 2,
  -- SplitTable splits dimension 1 of an unbatched 2D tensor and dimension 2
  -- of a batched 3D tensor; a plain SplitTable(2) would split the wrong
  -- dimension for single-sample input.
  local reshaped_gates = nn.Reshape(4, rnn_size)(gates)
  local sliced_gates = nn.SplitTable(1, 2)(reshaped_gates)

  -- Fetch each gate slice and apply its nonlinearity.
  local in_gate = nn.Sigmoid()(nn.SelectTable(1)(sliced_gates))
  local in_transform = nn.Tanh()(nn.SelectTable(2)(sliced_gates))
  local forget_gate = nn.Sigmoid()(nn.SelectTable(3)(sliced_gates))
  local out_gate = nn.Sigmoid()(nn.SelectTable(4)(sliced_gates))

  -- next_c = forget_gate .* prev_c + in_gate .* in_transform
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform}),
  })
  -- next_h = out_gate .* tanh(next_c)
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment