Recurrent Weighted Average (RWA) model as a TensorFlow RNNCell
##########################################################################################
# Author: Jared L. Ostmeyer
# Date Started: 2017-04-11
# Purpose: Recurrent weighted average cell for TensorFlow.
##########################################################################################

"""Implementation of a recurrent weighted average cell as a TensorFlow module. See
https://arxiv.org/abs/1703.01253 for a mathematical description of the model.
"""
import tensorflow as tf


class RWACell(tf.nn.rnn_cell.RNNCell):
    """Recurrent weighted average cell (https://arxiv.org/abs/1703.01253)."""

    def __init__(self, num_units):
        """Initialize the RWA cell.

        Args:
            num_units: int, The number of units in the RWA cell.
        """
        self.num_units = num_units
        self.activation = tf.tanh

    def zero_state(self, batch_size, dtype):
        """`zero_state` is overridden to return non-zero values and
        parameters that must be learned."""
        num_units = self.num_units
        activation = self.activation
        n = tf.zeros([batch_size, num_units], dtype=dtype)
        d = tf.zeros([batch_size, num_units], dtype=dtype)
        h = tf.zeros([batch_size, num_units], dtype=dtype)
        a_max = -float('inf')*tf.ones([batch_size, num_units], dtype=dtype)  # Start off with a large negative number with room for this value to decay
        s_0 = tf.get_variable(
            's_0', [num_units],
            initializer=tf.random_normal_initializer(stddev=1.0, dtype=dtype),
            dtype=dtype
        )
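        # The initial state is learnable: s_0 is passed through the activation
        # and broadcast across the batch, so h starts at f(s_0) rather than 0.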
        h += activation(tf.expand_dims(s_0, 0))
        return (n, d, h, a_max)

    def __call__(self, inputs, state, scope=None):
        num_inputs = inputs.get_shape()[1]
        num_units = self.num_units
        activation = self.activation
        x = inputs
        n, d, h, a_max = state
        if scope is not None:
            raise ValueError(
                "The argument `scope` for `RWACell.__call__` is not supported: "
                "the scope is hard-coded to make the initial state learnable. "
                "See `s_0` in `RWACell.zero_state`."
            )
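        # Three weight matrices: W_u acts on the input alone, while W_g and W_a
        # act on the concatenation [x, h] of the input and the previous output.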
        W_u = tf.get_variable(
            'W_u', [num_inputs, num_units],
            initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform', dtype=h.dtype),
            dtype=h.dtype
        )
        b_u = tf.get_variable(
            'b_u', [num_units],
            initializer=tf.constant_initializer(0.0), dtype=h.dtype
        )
        W_g = tf.get_variable(
            'W_g', [num_inputs+num_units, num_units],
            initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform', dtype=h.dtype),
            dtype=h.dtype
        )
        b_g = tf.get_variable(
            'b_g', [num_units],
            initializer=tf.constant_initializer(0.0), dtype=h.dtype
        )
        W_a = tf.get_variable(
            'W_a', [num_inputs+num_units, num_units],
            initializer=tf.variance_scaling_initializer(scale=1.0, mode='fan_avg', distribution='uniform', dtype=h.dtype),
            dtype=h.dtype
        )

        xh = tf.concat([x, h], 1)
        u = tf.matmul(x, W_u)+b_u
        g = tf.matmul(xh, W_g)+b_g
        a = tf.matmul(xh, W_a)  # The bias term, when factored out of the numerator and denominator, cancels and is unnecessary
        z = tf.multiply(u, tf.nn.tanh(g))
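        # Rescale the running numerator and denominator by exp(a_max - a_newmax)
        # so the largest exponent seen so far is always zero; this is the usual
        # running log-sum-exp trick and prevents overflow in the exponentials.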
        a_newmax = tf.maximum(a_max, a)
        exp_diff = tf.exp(a_max-a_newmax)
        exp_scaled = tf.exp(a-a_newmax)
        n = tf.multiply(n, exp_diff)+tf.multiply(z, exp_scaled)  # Numerically stable update of numerator
        d = tf.multiply(d, exp_diff)+exp_scaled  # Numerically stable update of denominator
        h = activation(tf.div(n, d))
        a_max = a_newmax

        return h, (n, d, h, a_max)

    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return (self.num_units, self.num_units, self.num_units, self.num_units)
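

if __name__ == '__main__':
    # Minimal usage sketch (an illustrative addition, not part of the original
    # gist; assumes the TensorFlow 1.x API, where `tf.placeholder` and
    # `tf.nn.dynamic_rnn` are available, and uses arbitrary example shapes).
    batch_size, max_steps, num_features, num_units = 32, 100, 8, 250

    inputs = tf.placeholder(tf.float32, [batch_size, max_steps, num_features])
    cell = RWACell(num_units)
    initial_state = cell.zero_state(batch_size, tf.float32)  # Creates the learnable `s_0`
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)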