@thomwolf
Created October 3, 2017 09:54
A PyTorch LSTM cell with a hard-sigmoid recurrent activation
import torch.nn.functional as F


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
    """
    hx, cx = hidden
    # Compute all four gate pre-activations in one pass:
    # W_ih @ input + W_hh @ h_{t-1} (+ biases), then split along dim 1.
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = hard_sigmoid(ingate)
    forgetgate = hard_sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = hard_sigmoid(outgate)

    # Standard LSTM state updates, with hard-sigmoid gates in place of sigmoid.
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)
    return hy, cy


def hard_sigmoid(x):
    """
    Computes element-wise hard sigmoid of x.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    x = (0.2 * x) + 0.5
    # The two threshold passes clamp the result to [0, 1],
    # i.e. this is equivalent to (0.2 * x + 0.5).clamp(0, 1).
    x = F.threshold(-x, -1, -1)
    x = F.threshold(-x, 0, 0)
    return x
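
As a quick sanity check, the cell can be driven step by step over a random sequence. This is only a minimal sketch: the sizes, the random weight initialization, and the names input_size / hidden_size below are illustrative assumptions, not part of the gist; the only constraint taken from the code above is that the weight matrices stack the four gates along the first dimension, as gates.chunk(4, 1) expects.

import torch

# Hypothetical sizes, for illustration only.
input_size, hidden_size, batch_size, seq_len = 10, 20, 4, 5

# Randomly initialized parameters with 4 * hidden_size rows
# (input, forget, cell, output gates stacked along dim 0).
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)
b_ih = torch.zeros(4 * hidden_size)
b_hh = torch.zeros(4 * hidden_size)

hx = torch.zeros(batch_size, hidden_size)
cx = torch.zeros(batch_size, hidden_size)

for t in range(seq_len):
    x_t = torch.randn(batch_size, input_size)
    hx, cx = LSTMCell(x_t, (hx, cx), w_ih, w_hh, b_ih, b_hh)

print(hx.shape, cx.shape)  # torch.Size([4, 20]) torch.Size([4, 20])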