Skip to content

Instantly share code, notes, and snippets.

@talolard
Created December 31, 2017 09:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save talolard/8f58dd0a2d36417338cb8054f71ae86b to your computer and use it in GitHub Desktop.
Example of gru implementation
'''
GRU layer implementation originally taken from https://github.com/ottokart/punctuator2
'''
class GRULayer(object):
    """Gated Recurrent Unit layer built on Theano symbolic ops.

    Notation follows "An Empirical Exploration of Recurrent Network
    Architectures" (Jozefowicz et al., 2015): a fused reset/update gate
    matrix of width 2*n_out plus separate candidate-state parameters.
    """

    def __init__(self, rng, n_in, n_out, minibatch_size):
        """Allocate shared parameters for the gates and candidate state.

        rng            -- random state forwarded to weights_Glorot
        n_in           -- input feature dimension
        n_out          -- hidden-state dimension
        minibatch_size -- number of rows in the initial hidden state h0
        """
        super(GRULayer, self).__init__()
        self.n_in = n_in
        self.n_out = n_out

        # Initial hidden state: one zero row per minibatch element.
        h0_value = np.zeros((minibatch_size, n_out)).astype(theano.config.floatX)
        self.h0 = theano.shared(value=h0_value, name='h0', borrow=True)

        # Gate parameters: reset and update gates share one fused weight
        # matrix (width 2 * n_out), split later by _slice in step().
        self.W_x = weights_Glorot(n_in, n_out * 2, 'W_x', rng)
        self.W_h = weights_Glorot(n_out, n_out * 2, 'W_h', rng)
        self.b = weights_const(1, n_out * 2, 'b', 0)

        # Candidate-activation (input) parameters.
        self.W_x_h = weights_Glorot(n_in, n_out, 'W_x_h', rng)
        self.W_h_h = weights_Glorot(n_out, n_out, 'W_h_h', rng)
        self.b_h = weights_const(1, n_out, 'b_h', 0)

        # All trainable parameters, in a fixed order for the optimizer.
        self.params = [self.W_x, self.W_h, self.b,
                       self.W_x_h, self.W_h_h, self.b_h]

    def step(self, x_t, h_tm1):
        """One recurrence step: map (x_t, h_tm1) to the next hidden state."""
        # Fused pre-activation for both gates, then split into reset/update.
        gates = T.nnet.sigmoid(
            T.dot(x_t, self.W_x) + T.dot(h_tm1, self.W_h) + self.b)
        r = _slice(gates, self.n_out, 0)
        z = _slice(gates, self.n_out, 1)

        # Candidate state computed from the reset-gated previous state.
        h_candidate = T.tanh(
            T.dot(x_t, self.W_x_h) + T.dot(h_tm1 * r, self.W_h_h) + self.b_h)

        # Convex blend of old state and candidate; z keeps the old state.
        h_t = z * h_tm1 + (1. - z) * h_candidate
        return h_t
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment