Recurrent Weighted Average (RWA) model as a Tensorflow RNNCell
##########################################################################################
# Author: Jared L. Ostmeyer
# Date Started: 2017-04-11
# Purpose: Recurrent weighted average cell for TensorFlow.
##########################################################################################

"""Implementation of the recurrent weighted average cell as a TensorFlow module. See
https://arxiv.org/abs/1703.01253 for a mathematical description of the model.
"""
import tensorflow as tf
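
# For reference, the recurrence implemented by this cell, following the paper
# linked above (f denotes the activation, [x; h] concatenation, * the
# elementwise product):
#
#   u_t = W_u x_t + b_u                              (update term)
#   g_t = W_g [x_t; h_{t-1}] + b_g                   (gate)
#   a_t = W_a [x_t; h_{t-1}]                         (attention score, no bias)
#   z_t = u_t * tanh(g_t)
#   h_t = f( sum_{i<=t} z_i exp(a_i) / sum_{i<=t} exp(a_i) )
#
# The state tuple (n, d, h, a_max) carries the running numerator, denominator,
# hidden state, and the running maximum of the attention scores.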
class RWACell(tf.nn.rnn_cell.RNNCell):
    """Recurrent weighted average cell (https://arxiv.org/abs/1703.01253)"""

    def __init__(self, num_units):
        """Initialize the RWA cell.

        Args:
            num_units: int, The number of units in the RWA cell.
        """
        self.num_units = num_units
        self.activation = tf.tanh
    def zero_state(self, batch_size, dtype):
        """`zero_state` is overridden to return a non-zero initial state: the
        initial hidden state is computed from `s_0`, a parameter that must be
        learned."""
        num_units = self.num_units
        activation = self.activation

        n = tf.zeros([batch_size, num_units], dtype=dtype)
        d = tf.zeros([batch_size, num_units], dtype=dtype)
        h = tf.zeros([batch_size, num_units], dtype=dtype)
        a_max = -float('inf')*tf.ones([batch_size, num_units], dtype=dtype)  # Start at negative infinity so the first attention score always becomes the running maximum

        s_0 = tf.get_variable(
            's_0', [num_units],
            initializer=tf.random_normal_initializer(stddev=1.0, dtype=dtype),
            dtype=dtype
        )
        h += activation(tf.expand_dims(s_0, 0))

        return (n, d, h, a_max)
    def __call__(self, inputs, state, scope=None):
        num_inputs = inputs.get_shape()[1]
        num_units = self.num_units
        activation = self.activation
        x = inputs
        n, d, h, a_max = state

        if scope is not None:
            raise ValueError(
                "The argument `scope` for `RWACell.__call__` is not supported: "
                "the variable scope is hard-coded so that the initial state "
                "remains learnable. See `s_0` in `RWACell.zero_state`."
            )
        W_u = tf.get_variable(
            'W_u', [num_inputs, num_units],
            initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform', dtype=h.dtype),
            dtype=h.dtype
        )
        b_u = tf.get_variable(
            'b_u', [num_units],
            initializer=tf.constant_initializer(0.0), dtype=h.dtype
        )
        W_g = tf.get_variable(
            'W_g', [num_inputs+num_units, num_units],
            initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform', dtype=h.dtype),
            dtype=h.dtype
        )
        b_g = tf.get_variable(
            'b_g', [num_units],
            initializer=tf.constant_initializer(0.0), dtype=h.dtype
        )
        W_a = tf.get_variable(
            'W_a', [num_inputs+num_units, num_units],
            initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform', dtype=h.dtype),
            dtype=h.dtype
        )
        xh = tf.concat([x, h], 1)

        u = tf.matmul(x, W_u)+b_u
        g = tf.matmul(xh, W_g)+b_g
        a = tf.matmul(xh, W_a)  # The bias term cancels when factored out of the numerator and denominator, so it is unnecessary
        z = tf.multiply(u, tf.nn.tanh(g))

        # Whenever the running maximum grows, rescale the running numerator and
        # denominator by exp(a_max-a_newmax), the same trick used for a
        # numerically stable softmax: no exponential ever exceeds exp(0) = 1.
        a_newmax = tf.maximum(a_max, a)
        exp_diff = tf.exp(a_max-a_newmax)
        exp_scaled = tf.exp(a-a_newmax)

        n = tf.multiply(n, exp_diff)+tf.multiply(z, exp_scaled)  # Numerically stable update of the numerator
        d = tf.multiply(d, exp_diff)+exp_scaled  # Numerically stable update of the denominator
        h = activation(tf.div(n, d))
        a_max = a_newmax

        return h, (n, d, h, a_max)
    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return (self.num_units, self.num_units, self.num_units, self.num_units)
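
# A minimal usage sketch, assuming TensorFlow 1.x: the cell is meant to be
# driven by `tf.nn.dynamic_rnn`, with `zero_state` supplying the learnable
# initial state. The tensor sizes below are hypothetical placeholders.
if __name__ == '__main__':
    batch_size, max_steps, num_features, num_units = 32, 100, 8, 250  # hypothetical sizes
    x = tf.placeholder(tf.float32, [batch_size, max_steps, num_features])
    cell = RWACell(num_units)
    state = cell.zero_state(batch_size, tf.float32)  # creates the learnable parameter `s_0`
    outputs, final_state = tf.nn.dynamic_rnn(cell, x, initial_state=state)  # outputs: [batch_size, max_steps, num_units]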