gru.py — forked from danijar/gru.py, saved by @mlzxy on November 15, 2017.
Gated Recurrent Unit with layer normalization and Xavier initialization
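For reference, a sketch of the update the cell below computes, in the notation of the code (h is the previous state, x the input, LN the layer normalizer; each projection has its own Xavier-initialized weight matrix):

    update, reset = sigmoid(LN(h @ W1) + LN(x @ W2) + b_g)        # one 2*size projection, split into two gates; b_g starts at -1
    candidate     = tanh(LN((reset * h) @ W3) + LN(x @ W4) + b_c)
    new_state     = (1 - update) * h + update * candidate

The bias is added after the normalized projections are summed, since a bias added before layer normalization would be centered away.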
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class GRU(tf.contrib.rnn.RNNCell):
  """GRU cell with layer normalization and Xavier-initialized weights."""

  def __init__(
      self, size, activation=tf.tanh, reuse=None,
      normalizer=tf.contrib.layers.layer_norm,
      initializer=tf.contrib.layers.xavier_initializer()):
    super(GRU, self).__init__(_reuse=reuse)
    self._size = size
    self._activation = activation
    self._normalizer = normalizer
    self._initializer = initializer

  @property
  def state_size(self):
    return self._size

  @property
  def output_size(self):
    return self._size

  def call(self, input_, state):
    # Compute update and reset gates from a single projection of state and input.
    update, reset = tf.split(self._forward(
        'update_reset', [state, input_], 2 * self._size, tf.nn.sigmoid,
        bias_initializer=tf.constant_initializer(-1.)), 2, 1)
    # The candidate state sees the previous state gated by the reset gate.
    candidate = self._forward(
        'candidate', [reset * state, input_], self._size, self._activation)
    # Interpolate between the previous state and the candidate.
    state = (1 - update) * state + update * candidate
    return state, state

  def _forward(self, name, inputs, size, activation, **kwargs):
    with tf.variable_scope(name):
      return _forward(
          inputs, size, activation, normalizer=self._normalizer,
          weight_initializer=self._initializer, **kwargs)


def _forward(
    inputs, size, activation, normalizer=tf.contrib.layers.layer_norm,
    weight_initializer=tf.contrib.layers.xavier_initializer(),
    bias_initializer=tf.zeros_initializer()):
  if not isinstance(inputs, (tuple, list)):
    inputs = (inputs,)
  shapes, outputs = [], []
  # Project each input separately so that its output can be normalized
  # individually before the projections are summed.
  for index, input_ in enumerate(inputs):
    shapes.append(input_.shape[1:-1].as_list())
    input_ = tf.contrib.layers.flatten(input_)
    weight = tf.get_variable(
        'weight_{}'.format(index + 1), (int(input_.shape[1]), size),
        tf.float32, weight_initializer)
    output = tf.matmul(input_, weight)
    if normalizer:
      output = normalizer(output)
    outputs.append(output)
  output = sum(outputs)
  # Add the bias after normalization, so the normalizer cannot remove it.
  bias = tf.get_variable(
      'bias', (size,), tf.float32, bias_initializer)
  output += bias
  # Activation function.
  if activation:
    output = activation(output)
  # Restore leading shape dimensions that are consistent among inputs.
  dim = 0
  while dim < min(len(shape) for shape in shapes):
    none = shapes[0][dim] is None
    equal = all(shape[dim] == shapes[0][dim] for shape in shapes)
    if none or not equal:
      break
    dim += 1
  shape = output.shape.as_list()[:1] + shapes[0][:dim] + [-1]
  output = tf.reshape(output, shape)
  return output
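A minimal usage sketch, assuming TensorFlow 1.x with tf.contrib available; the batch size, sequence length, feature size, and cell size are illustrative (a statically known batch size is used because the final reshape in _forward reads the static batch dimension):

    import tensorflow as tf

    # Illustrative shapes: 32 sequences of 50 steps with 128 features each.
    inputs = tf.placeholder(tf.float32, [32, 50, 128])
    cell = GRU(256)  # layer norm and Xavier initialization are the defaults
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

The cell can be used anywhere a tf.contrib.rnn.RNNCell is expected, e.g. wrapped in tf.contrib.rnn.MultiRNNCell for a stacked RNN.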