Gated Recurrent Unit with layer normalization and Xavier initialization
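For reference, the update rule implemented by the code below can be summarized as follows (this summary is a reading of the code, not text from the original gist; LN denotes layer normalization, applied separately to each input's projection before the projections are summed):

  [z_t; r_t] = \sigma\big(\mathrm{LN}(h_{t-1} W_h) + \mathrm{LN}(x_t W_x) + b\big), \quad b \text{ initialized to } -1
  \tilde{h}_t = \tanh\big(\mathrm{LN}((r_t \odot h_{t-1})\, W_c) + \mathrm{LN}(x_t\, U_c) + b_c\big)
  h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

The -1 gate bias keeps both gates mostly closed at initialization, so the cell starts out passing its state through largely unchanged.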
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

class GRU(tf.contrib.rnn.RNNCell):
  """Gated Recurrent Unit cell with layer normalization and Xavier-initialized
  weights, written against the TensorFlow 1.x tf.contrib API."""

  def __init__(
      self, size, activation=tf.tanh, reuse=None,
      normalizer=tf.contrib.layers.layer_norm,
      initializer=tf.contrib.layers.xavier_initializer()):
    super(GRU, self).__init__(_reuse=reuse)
    self._size = size
    self._activation = activation
    self._normalizer = normalizer
    self._initializer = initializer

  @property
  def state_size(self):
    return self._size

  @property
  def output_size(self):
    return self._size

  def call(self, input_, state):
    # Compute the update and reset gates jointly from the previous state and
    # the input. The -1 bias keeps both gates mostly closed at initialization.
    update, reset = tf.split(self._forward(
        'update_reset', [state, input_], 2 * self._size, tf.nn.sigmoid,
        bias_initializer=tf.constant_initializer(-1.)), 2, 1)
    # The candidate state sees the previous state only through the reset gate.
    candidate = self._forward(
        'candidate', [reset * state, input_], self._size, self._activation)
    # Interpolate between the previous state and the candidate.
    state = (1 - update) * state + update * candidate
    return state, state

  def _forward(self, name, inputs, size, activation, **kwargs):
    with tf.variable_scope(name):
      return _forward(
          inputs, size, activation, normalizer=self._normalizer,
          weight_initializer=self._initializer, **kwargs)

def _forward(
    inputs, size, activation, normalizer=tf.contrib.layers.layer_norm,
    weight_initializer=tf.contrib.layers.xavier_initializer(),
    bias_initializer=tf.zeros_initializer()):
  if not isinstance(inputs, (tuple, list)):
    inputs = (inputs,)
  shapes, outputs = [], []
  # Project each input with its own weight matrix and normalize the result
  # individually, so layer norm statistics are not mixed across inputs.
  for index, input_ in enumerate(inputs):
    shapes.append(input_.shape[1:-1].as_list())
    input_ = tf.contrib.layers.flatten(input_)
    weight = tf.get_variable(
        'weight_{}'.format(index + 1), (int(input_.shape[1]), size),
        tf.float32, weight_initializer)
    output = tf.matmul(input_, weight)
    if normalizer:
      output = normalizer(output)
    outputs.append(output)
  output = sum(outputs)
  # Add bias after the normalized projections have been summed.
  bias = tf.get_variable(
      'bias', (size,), tf.float32, bias_initializer)
  output += bias
  # Activation function.
  if activation:
    output = activation(output)
  # Restore leading shape dimensions that are consistent among all inputs;
  # stop at the first dimension that is dynamic (None) or differs.
  dim = 0
  while dim < min(len(shape) for shape in shapes):
    none = shapes[0][dim] is None
    equal = all(shape[dim] == shapes[0][dim] for shape in shapes)
    if none or not equal:
      break
    dim += 1
  shape = output.shape.as_list()[:1] + shapes[0][:dim] + [-1]
  output = tf.reshape(output, shape)
  return output
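
A minimal usage sketch (not part of the original gist; it assumes TensorFlow 1.x, where tf.contrib and tf.placeholder are still available, and the shapes are illustrative placeholders):

import tensorflow as tf

# [batch, time, features]; the sequence length and feature size are hypothetical.
sequences = tf.placeholder(tf.float32, [None, 50, 128])
cell = GRU(200)
# dynamic_rnn unrolls the cell over the time axis, reusing its variables
# across steps via the cell's _reuse mechanism.
outputs, final_state = tf.nn.dynamic_rnn(cell, sequences, dtype=tf.float32)
# outputs: [batch, time, 200]; final_state: [batch, 200]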