An implementation of Delta RNN in TensorFlow
# Copyright (C) 2017 by Akira TAMAMORI
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright (C) 2017 by NickShahML
# URL: https://github.com/NickShahML/tensorflow_with_latest_papers
#
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Changes:
# - 2017/04/30: fixed source code to run on TensorFlow v0.12
# - added variable scopes
# Notice:
# This file is tested on TensorFlow v0.12.

import tensorflow as tf
from tensorflow.python.ops.rnn_cell import RNNCell


class DeltaRNNCell(RNNCell):
    """Delta RNN cell.

    Alexander G. Ororbia II, Tomas Mikolov and David Reitter,
    "Learning Simpler Language Models with the
    Delta Recurrent Neural Network Framework,"
    https://arxiv.org/abs/1703.08864
    """

    def __init__(self, num_units):
        self._num_units = num_units

    @property
    def input_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    @property
    def state_size(self):
        return self._num_units

    def _outer_function(self, inner_function_output,
                        past_hidden_state, activation=tf.nn.relu,
                        wx_parameterization_gate=True, scope=None):
        """Simulates Equation 3 in the Delta RNN paper.

        r, the gate, can be parameterized in many different ways.
        """
        assert inner_function_output.get_shape().as_list() == \
            past_hidden_state.get_shape().as_list()

        with tf.variable_scope(scope or type(self).__name__):
            with tf.variable_scope("OuterFunction"):
                r_bias = tf.get_variable(
                    "outer_function_gate",
                    [self._num_units],
                    dtype=tf.float32, initializer=tf.zeros_initializer)

                # Equation 5 in the Delta RNN paper
                if wx_parameterization_gate:
                    r = self._W_x_inputs + r_bias
                else:
                    r = r_bias

                gate = tf.nn.sigmoid(r)
                output = activation(
                    (1.0 - gate) * inner_function_output
                    + gate * past_hidden_state)

        return output

    def _inner_function(self, inputs, past_hidden_state,
                        activation=tf.nn.tanh, scope=None):
        """Second-order function, as described in Equation 11 of the
        Delta RNN paper.

        The main goal of this function is to produce z_t.
        """
        with tf.variable_scope(scope or type(self).__name__):
            with tf.variable_scope("InnerFunction"):
                with tf.variable_scope("Vh"):
                    V_h = _linear(past_hidden_state, self._num_units, True)

                # We make this a private variable to be reused in the
                # _outer_function
                with tf.variable_scope("Wx"):
                    self._W_x_inputs = _linear(inputs, self._num_units, True)

                alpha = tf.get_variable(
                    "alpha", [self._num_units], dtype=tf.float32,
                    initializer=tf.constant_initializer(1.0))

                beta_one = tf.get_variable(
                    "beta_one", [self._num_units], dtype=tf.float32,
                    initializer=tf.constant_initializer(1.0))

                beta_two = tf.get_variable(
                    "beta_two", [self._num_units], dtype=tf.float32,
                    initializer=tf.constant_initializer(1.0))

                z_t_bias = tf.get_variable(
                    "z_t_bias", [self._num_units], dtype=tf.float32,
                    initializer=tf.constant_initializer(0.0))

                # Second-order cell calculations
                d_1_t = alpha * V_h * self._W_x_inputs
                d_2_t = beta_one * V_h + beta_two * self._W_x_inputs
                z_t = activation(d_1_t + d_2_t + z_t_bias)

        return z_t

    def __call__(self, inputs, state, scope=None):
        inner_function_output = self._inner_function(inputs, state)
        output = self._outer_function(inner_function_output, state)
        # There is only one hidden state to keep track of, so the output
        # doubles as the next state.
        return output, output


def _linear(args, output_size, bias, bias_start=0.0, scope=None):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

    Args:
      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
      output_size: int, second dimension of W[i].
      bias: boolean, whether to add a bias term or not.
      bias_start: starting value to initialize the bias; 0 by default.
      scope: VariableScope for the created subgraph; defaults to "Linear".

    Returns:
      A 2D Tensor with shape [batch x output_size] equal to
      sum_i(args[i] * W[i]), where the W[i]s are newly created matrices.

    Raises:
      ValueError: if some of the arguments have unspecified or wrong shapes.
    """
    if args is None or (isinstance(args, (list, tuple)) and not args):
        raise ValueError("`args` must be specified")
    if not isinstance(args, (list, tuple)):
        args = [args]

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape().as_list() for a in args]
    for shape in shapes:
        if len(shape) != 2:
            raise ValueError(
                "Linear is expecting 2D arguments: %s" % str(shapes))
        if not shape[1]:
            raise ValueError(
                "Linear expects shape[1] of arguments: %s" % str(shapes))
        else:
            total_arg_size += shape[1]

    # Now the computation.
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
        if len(args) == 1:
            res = tf.matmul(args[0], matrix)
        else:
            res = tf.matmul(tf.concat(1, args), matrix)
        if not bias:
            return res
        bias_term = tf.get_variable(
            "Bias", [output_size],
            initializer=tf.constant_initializer(bias_start))
        return res + bias_term
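

# A minimal usage sketch (not part of the original gist): drive the cell
# with tf.nn.dynamic_rnn on random data. Shapes and hyperparameters below
# are illustrative only; the API calls follow the TensorFlow v0.12
# interface this file targets.
if __name__ == "__main__":
    import numpy as np

    batch_size, num_steps, input_dim, num_units = 4, 10, 8, 16
    inputs = tf.placeholder(tf.float32, [batch_size, num_steps, input_dim])

    # dynamic_rnn zero-initializes the state via RNNCell.zero_state,
    # which DeltaRNNCell inherits (state_size == num_units).
    cell = DeltaRNNCell(num_units)
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.randn(batch_size, num_steps, input_dim)
        out = sess.run(outputs, feed_dict={inputs: batch})
        print(out.shape)  # (batch_size, num_steps, num_units)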