Skip to content

Instantly share code, notes, and snippets.

@skaae
Created September 24, 2015 12:57
Show Gist options
  • Save skaae/d9c82bdeb9ce5ba0bf6e to your computer and use it in GitHub Desktop.
Save skaae/d9c82bdeb9ce5ba0bf6e to your computer and use it in GitHub Desktop.
class GruDenseLayer(lasagne.layers.Layer):
    """Linear projection of the input into three stacked gate blocks.

    Computes ``flatten(input) . [W_resetgate | W_updategate | W_hidden_update]``
    (plus the stacked biases, when biases are enabled) with a single matrix
    product.  No nonlinearity is applied.  The output columns are laid out as
    ``[reset gate | update gate | hidden update]``, each ``num_units`` wide.
    NOTE(review): presumably meant to precompute the input-to-gate terms of a
    GRU outside the recurrence -- confirm against the consuming layer.

    Parameters
    ----------
    incoming
        Input layer (or whatever ``lasagne.layers.Layer`` accepts); forwarded
        unchanged to the base class.  Inputs with more than two dimensions
        are flattened to two dimensions before the product.
    num_units : int
        Width of each gate block; the output has ``3 * num_units`` columns.
    b_resetgate, b_updategate, b_hidden_update : initializer or None
        Bias initializers, each of shape ``(num_units,)``.  Either all three
        are given (biases are added) or all three are ``None`` (no bias; the
        default).
    W_resetgate, W_updategate, W_hidden_update : initializer
        Weight initializers, each of shape ``(num_inputs, num_units)`` where
        ``num_inputs`` is the product of the non-batch input dimensions.
    **kwargs
        Forwarded to ``lasagne.layers.Layer``.

    Raises
    ------
    ValueError
        If only some of the three bias initializers are ``None``.
    """

    def __init__(self, incoming, num_units,
                 b_resetgate=None,
                 b_updategate=None,
                 b_hidden_update=None,
                 W_resetgate=init.GlorotUniform(),
                 W_updategate=init.GlorotUniform(),
                 W_hidden_update=init.GlorotUniform(),
                 **kwargs):
        # Validate the bias configuration before registering any parameter,
        # so a bad call fails fast and the parameter order stays
        # [W_reset, W_update, W_hidden, b_reset, b_update, b_hidden].
        biases = (b_resetgate, b_updategate, b_hidden_update)
        use_bias = all(b is not None for b in biases)
        if not use_bias and any(b is not None for b in biases):
            raise ValueError(
                "GruDenseLayer: b_resetgate, b_updategate and "
                "b_hidden_update must either all be given or all be None")

        super(GruDenseLayer, self).__init__(incoming, **kwargs)
        self.num_units = num_units
        # All non-batch dimensions are folded into one feature axis.
        num_inputs = int(np.prod(self.input_shape[1:]))

        self.W_resetgate = self.add_param(
            W_resetgate, (num_inputs, num_units),
            name="W_resetgate")
        self.W_updategate = self.add_param(
            W_updategate, (num_inputs, num_units),
            name="W_updategate")
        self.W_hidden_update = self.add_param(
            W_hidden_update, (num_inputs, num_units),
            name="W_hidden_update")
        # Stack along the output axis so a single dot product serves all
        # three gates; column order is [reset | update | hidden update].
        self.W_stacked = T.concatenate(
            [self.W_resetgate, self.W_updategate, self.W_hidden_update],
            axis=1)

        if use_bias:
            self.b_resetgate = self.add_param(
                b_resetgate, (num_units, ),
                name="b_resetgate")
            self.b_updategate = self.add_param(
                b_updategate, (num_units, ),
                name="b_updategate")
            self.b_hidden_update = self.add_param(
                b_hidden_update, (num_units, ),
                name="b_hidden_update")
            # Same block order as W_stacked.
            self.b_stacked = T.concatenate(
                [self.b_resetgate, self.b_updategate, self.b_hidden_update],
                axis=0)
        else:
            # Keep the attributes defined so callers can test for None
            # instead of hitting AttributeError.
            self.b_resetgate = None
            self.b_updategate = None
            self.b_hidden_update = None
            self.b_stacked = None

    def get_output_shape_for(self, input_shape):
        """Return ``(batch, 3 * num_units)`` for the given input shape."""
        return (input_shape[0], 3 * self.num_units)

    def get_output_for(self, input, **kwargs):
        """Return ``flatten(input) . W_stacked (+ b_stacked)``.

        Inputs with more than two dimensions are flattened to
        ``(input.shape[0], -1)`` first.  ``kwargs`` are accepted for the
        Lasagne call convention and ignored.
        """
        if input.ndim > 2:
            # Treat everything after the first axis as one feature vector.
            input = input.flatten(2)
        activation = T.dot(input, self.W_stacked)
        if self.b_stacked is not None:
            # Broadcast the (3 * num_units,) bias across the batch axis.
            activation = activation + self.b_stacked.dimshuffle('x', 0)
        return activation
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment