@yohokuno
Created January 18, 2017 10:14
Residual RNN cell for TensorFlow 0.10
import tensorflow as tf
from tensorflow.python.ops.rnn_cell import RNNCell
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


class ResidualRNNCell(RNNCell):
    """RNN cell composed sequentially of multiple simple cells with residual connections."""

    def __init__(self, cells, state_is_tuple=False):
        """Create an RNN cell composed sequentially of a number of RNNCells.

        Args:
            cells: list of RNNCells that will be composed in this order.
            state_is_tuple: If True, accepted and returned states are n-tuples, where
                `n = len(cells)`. By default (False), the states are all
                concatenated along the column axis.

        Raises:
            ValueError: if cells is empty (not allowed), or at least one of the cells
                returns a state tuple but the flag `state_is_tuple` is `False`.
        """
        if not cells:
            raise ValueError("Must specify at least one cell for ResidualRNNCell.")
        self._cells = cells
        self._state_is_tuple = state_is_tuple
        if not state_is_tuple:
            if any(nest.is_sequence(c.state_size) for c in self._cells):
                raise ValueError("Some cells return tuples of states, but the flag "
                                 "state_is_tuple is not set. State sizes are: %s"
                                 % str([c.state_size for c in self._cells]))

    @property
    def state_size(self):
        if self._state_is_tuple:
            return tuple(cell.state_size for cell in self._cells)
        else:
            return sum([cell.state_size for cell in self._cells])

    @property
    def output_size(self):
        return self._cells[-1].output_size

    def __call__(self, inputs, state, scope=None):
        """Run this multi-layer cell on inputs, starting from state."""
        with vs.variable_scope(scope or type(self).__name__):  # "ResidualRNNCell"
            cur_state_pos = 0
            cur_inp = inputs
            new_states = []
            for i, cell in enumerate(self._cells):
                with vs.variable_scope("Cell%d" % i):
                    if self._state_is_tuple:
                        if not nest.is_sequence(state):
                            raise ValueError(
                                "Expected state to be a tuple of length %d, but received: %s"
                                % (len(self.state_size), state))
                        cur_state = state[i]
                    else:
                        cur_state = tf.slice(
                            state, [0, cur_state_pos], [-1, cell.state_size])
                        cur_state_pos += cell.state_size
                    new_inp, new_state = cell(cur_inp, cur_state)
                    # Residual connection: add each layer's output to its input.
                    cur_inp += new_inp
                    new_states.append(new_state)
        new_states = (tuple(new_states) if self._state_is_tuple
                      else tf.concat(1, new_states))
        return cur_inp, new_states
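
A minimal usage sketch, assuming the TensorFlow 0.10 API (tf.nn.rnn_cell.BasicLSTMCell, tf.nn.dynamic_rnn); the sizes below are illustrative, not from the gist. Note that the input dimension must equal each cell's output size, otherwise the residual addition cur_inp += new_inp is not shape-compatible.

# Hypothetical usage sketch for TensorFlow 0.10.
num_units, num_layers = 650, 2
batch_size, num_steps = 20, 35

cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)]
cell = ResidualRNNCell(cells)  # state_is_tuple=False by default

# Inputs must have last dimension num_units so the residual addition is valid.
inputs = tf.placeholder(tf.float32, [batch_size, num_steps, num_units])
initial_state = cell.zero_state(batch_size, tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)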

yohokuno commented Jan 18, 2017

I ran a language modelling experiment with LSTMs on the Penn Treebank, following the setting of Zaremba 2014.
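
For reference, a rough sketch of what the medium and large configurations mean in the Zaremba 2014 setup; these are my assumed values for the standard configurations, not taken from this gist.

# Assumed standard Zaremba 2014 PTB configurations (illustrative only);
# the exact settings behind the numbers below may differ.
medium_config = dict(num_layers=2, hidden_size=650, num_steps=35,
                     batch_size=20, keep_prob=0.5)    # 50% dropout
large_config = dict(num_layers=2, hidden_size=1500, num_steps=35,
                    batch_size=20, keep_prob=0.35)    # 65% dropout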

config  layers    train  valid  test
medium  stack     48.45  86.16  82.07
medium  residual  30.24  87.69  83.70
large   stack     55     37.87  82.62
large   residual  26.11  93.71  82.62

The result shows that residual connections do not help generalization in this setting. However, it is interesting that the residual LSTM seems to overfit more than the stacked LSTM, despite having exactly the same number of parameters.

The GNMT paper states that residual LSTMs scale better than stacked LSTMs beyond 4 layers. This might mean larger data and larger networks are needed to see a difference between them.

Zaremba 2014: https://arxiv.org/abs/1409.2329
GNMT: https://arxiv.org/abs/1609.08144
