Skip to content

Instantly share code, notes, and snippets.

Created August 23, 2017 22:00
Show Gist options
  • Save hoavt-54/beb79cea7fbb19cbf91e9aeefa168c16 to your computer and use it in GitHub Desktop.
Save hoavt-54/beb79cea7fbb19cbf91e9aeefa168c16 to your computer and use it in GitHub Desktop.
def __call__(self, inputs, ctx, state, scope=None, num_feats=49, ctx_dim=100):
    """Long short-term memory cell (LSTM) with soft attention over `ctx`.

    At each step the previous hidden state `h` scores each of `num_feats`
    context vectors; the softmax-weighted sum of the context vectors, `z`,
    is fed to the gate projection together with `inputs` and `h`.

    Args:
      inputs: input tensor for this step; projected by `_linear` together
        with `h` and `z`.
      ctx: context features. It is reshaped here to `[-1, ctx_dim]`, so
        `[batch_size, num_feats, ctx_dim]` and
        `[batch_size * num_feats, ctx_dim]` are both accepted. The last
        dimension must equal `ctx_dim`.
      state: `(c, h)` if `self._state_is_tuple`, otherwise `c` and `h`
        concatenated along axis 1.
      scope: optional variable scope; defaults to "basic_lstm_cell".
      num_feats: number of context vectors per example (default 49).
      ctx_dim: size of each context vector and of the attention projection
        of `h` (default 100).

    Returns:
      `(new_h, new_state)`, where `new_state` has the same format as `state`.
    """
    with vs.variable_scope(scope or "basic_lstm_cell"):
        if self._state_is_tuple:
            c, h = state  # [batch_size, hidden_dim]
        else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        # Flatten ctx so it lines up row-for-row with the tiled projection of
        # h below: [batch_size * num_feats, ctx_dim].
        ctx = array_ops.reshape(ctx, [-1, ctx_dim])

        # Attention parameters.
        Wd_att = vs.get_variable(
            "Wd_att", [h.get_shape()[1], ctx_dim], dtype=c.dtype)
        U_att = vs.get_variable("U_att", [ctx_dim, 1], dtype=c.dtype)
        c_att = vs.get_variable("c_att", [1], dtype=c.dtype)

        # Project h and repeat it once per context vector:
        # [batch_size, ctx_dim] -> [batch_size, num_feats * ctx_dim]
        # -> [batch_size * num_feats, ctx_dim] (each row is the full projection).
        pstate = math_ops.matmul(h, Wd_att)
        pstate = array_ops.tile(pstate, [1, num_feats])
        pstate = array_ops.reshape(pstate, [-1, ctx_dim])
        pstate = tanh(pstate + ctx)

        # Attention scores and normalized weights: [batch_size, num_feats].
        e_ti = math_ops.matmul(pstate, U_att) + c_att
        e_ti = array_ops.reshape(e_ti, [-1, num_feats])
        alpha = nn_ops.softmax(logits=e_ti)

        # Weighted sum of the context vectors. alpha is expanded to
        # [batch_size, num_feats, 1] and broadcast across ctx_dim, so every
        # feature of context vector k is scaled by alpha[:, k]. (Tiling alpha
        # by [1, ctx_dim] and reshaping to [-1, num_feats, ctx_dim] would
        # interleave the weights and scale features by the wrong entries.)
        alpha = array_ops.expand_dims(alpha, 2)
        ctx = array_ops.reshape(ctx, [-1, num_feats, ctx_dim])
        z = math_ops.reduce_sum(math_ops.multiply(alpha, ctx), 1)  # [batch_size, ctx_dim]

        # Parameters of gates are concatenated into one multiply for efficiency.
        concat = _linear([inputs, h, z], 4 * self._num_units, True, scope=scope)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                 self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment