Created
January 18, 2017 10:14
-
-
Save yohokuno/3eebbf8b0d38e1d27902b585399da0f2 to your computer and use it in GitHub Desktop.
Residual RNN cell for TensorFlow 0.10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
from tensorflow.python.ops.rnn_cell import RNNCell
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
class ResidualRNNCell(RNNCell):
    """RNN cell composed sequentially of multiple simple cells with a
    residual connection: each layer's output is added to its input before
    being fed to the next layer.

    NOTE(review): the residual addition (`cur_inp += new_inp` below) requires
    every cell's output size to match the size of the input it receives —
    including the external input to the first layer. With mismatched sizes the
    addition fails at graph-construction time; confirm input dimensionality
    against the first cell's output size.
    """

    def __init__(self, cells, state_is_tuple=False):
        """Create a RNN cell composed sequentially of a number of RNNCells.

        Args:
          cells: list of RNNCells that will be composed in this order.
          state_is_tuple: If True, accepted and returned states are n-tuples,
            where `n = len(cells)`. By default (False), the states are all
            concatenated along the column axis.

        Raises:
          ValueError: if cells is empty (not allowed), or at least one of the
            cells returns a state tuple but the flag `state_is_tuple` is
            `False`.
        """
        if not cells:
            # Fixed: message previously said "MultiRNNCell" (copy-paste residue).
            raise ValueError("Must specify at least one cell for ResidualRNNCell.")
        self._cells = cells
        self._state_is_tuple = state_is_tuple
        if not state_is_tuple:
            # Tuple states (e.g. LSTMStateTuple) cannot be concatenated into a
            # single flat tensor, so they require state_is_tuple=True.
            if any(nest.is_sequence(c.state_size) for c in self._cells):
                raise ValueError("Some cells return tuples of states, but the flag "
                                 "state_is_tuple is not set. State sizes are: %s"
                                 % str([c.state_size for c in self._cells]))

    @property
    def state_size(self):
        """Tuple of per-cell sizes, or their sum when states are concatenated."""
        if self._state_is_tuple:
            return tuple(cell.state_size for cell in self._cells)
        else:
            return sum(cell.state_size for cell in self._cells)

    @property
    def output_size(self):
        """Output size of the last layer (the residual sum has this size)."""
        return self._cells[-1].output_size

    def __call__(self, inputs, state, scope=None):
        """Run this multi-layer cell on inputs, starting from state.

        Returns:
          A pair `(output, new_state)` where `output` is the input plus the
          accumulated residual outputs of every layer, and `new_state` is a
          tuple of per-cell states (if `state_is_tuple`) or their
          concatenation along axis 1.
        """
        with vs.variable_scope(scope or type(self).__name__):  # "ResidualRNNCell"
            cur_state_pos = 0
            cur_inp = inputs
            new_states = []
            for i, cell in enumerate(self._cells):
                with vs.variable_scope("Cell%d" % i):
                    if self._state_is_tuple:
                        if not nest.is_sequence(state):
                            raise ValueError(
                                "Expected state to be a tuple of length %d, but received: %s"
                                % (len(self.state_size), state))
                        cur_state = state[i]
                    else:
                        # Slice this layer's flat state out of the
                        # column-concatenated state tensor.
                        cur_state = tf.slice(
                            state, [0, cur_state_pos], [-1, cell.state_size])
                        cur_state_pos += cell.state_size
                    new_inp, new_state = cell(cur_inp, cur_state)
                    # Residual connection: add the layer's output to its input.
                    cur_inp += new_inp
                    new_states.append(new_state)
            new_states = (tuple(new_states) if self._state_is_tuple
                          else tf.concat(1, new_states))  # TF 0.x concat(axis, values)
        return cur_inp, new_states
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I ran an experiment in language modelling with LSTM using PennTree bank following Zaremba 2014's setting.
The result shows that the residual connection does not help generalization in this setting. However, it is interesting to see that residual LSTM seems to overfit more than stacked LSTM despite having exactly the same number of parameters.
The GNMT paper states that residual LSTMs scale better than stacked LSTMs with over 4 layers. This might mean we need larger data and networks to make a difference between them.
Zaremba 2014: https://arxiv.org/abs/1409.2329
GNMT: https://arxiv.org/abs/1609.08144