Created
March 1, 2016 20:48
-
-
Save aidangomez/7657ff6c053d0ccc2f4d to your computer and use it in GitHub Desktop.
A modified version of tensorflow/python/ops/rnn.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def cw_rnn(cells, strides, inputs, initial_states=None, dtype=None,
           sequence_length=None, scope=None):
  """Creates a clockwork recurrent neural network from the given RNNCells.

  Cell c is advanced only on time steps divisible by strides[c]; on all
  other steps its previous output and state are carried forward unchanged.
  The per-step output is the concatenation of all cell outputs.

  Args:
    cells: A length-C list of RNNCell instances.
    strides: A length-C list of ints, the clock period of each cell.
    inputs: A length-C list of lists, with len(inputs[c]) == T // strides[c],
      containing tensors of shape [batch_size, cell.input_size].
    initial_states: (optional) A length-C list of initial states, one per
      cell, each of shape [batch_size x cell.state_size].
    dtype: (optional) The data type for the initial states.  Required if
      initial_states is not provided.
    sequence_length: (optional) Specifies the length of each sequence in
      inputs.  An int32 or int64 vector (tensor) of size [batch_size].
      Values in [0, T).
    scope: VariableScope for the created subgraph; defaults to "RNN".

  Returns:
    A pair (outputs, state) where:
      outputs is a length-T list of concatenated per-cell outputs
      state is the final state of the last time step

  Raises:
    TypeError: If "cells" is not a list of RNNCells, or "inputs" is not a
      list of lists.
    ValueError: If inputs is empty, or dtype is missing while
      initial_states is not provided.
  """
  if not isinstance(cells, list):
    raise TypeError("cells must be a list")
  if not isinstance(cells[0], rnn_cell.RNNCell):
    raise TypeError("cell must be an instance of RNNCell")
  if not isinstance(inputs, list):
    # Fixed message: the check is for a list, not a tuple.
    raise TypeError("inputs must be a list")
  # Check emptiness BEFORE indexing inputs[0]; the original indexed first
  # and would raise IndexError instead of the documented ValueError.
  if not inputs:
    raise ValueError("inputs must not be empty")
  if not isinstance(inputs[0], list):
    raise TypeError("inputs elements must be a list")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "RNN"):
    # if varscope.caching_device is None:
    #   varscope.set_caching_device(lambda op: op.device)
    fixed_batch_size = inputs[0][0].get_shape().with_rank_at_least(1)[0]
    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      # Fixed: take the dynamic shape of the first tensor, not of the
      # Python list inputs[0].
      batch_size = array_ops.shape(inputs[0][0])[0]

    if initial_states is not None:
      states = initial_states
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, dtype must be.")
      states = [cell.zero_state(batch_size, dtype) for cell in cells]

    # Placeholder outputs reused on steps where a cell is not clocked.
    previous_cell_output = [
        array_ops.zeros(array_ops.pack([batch_size, cell.output_size]),
                        inputs[0][0].dtype)
        for cell in cells]
    # Fixed: the original wrapped these set_shape calls in a generator
    # expression that was never consumed, so they never ran.
    for c, cell in enumerate(cells):
      previous_cell_output[c].set_shape(
          tensor_shape.TensorShape([fixed_batch_size.value,
                                    cell.output_size]))

    if sequence_length is not None:
      sequence_length = math_ops.to_int32(sequence_length)
      # Only the max is needed (to zero outputs past every sequence end);
      # the original also built unused min_sequence_length / zero_outputs.
      max_sequence_length = math_ops.reduce_max(sequence_length)

    # Width of the concatenated per-step output across all cells.
    total_output_size = reduce(lambda acc, cell: acc + cell.output_size,
                               cells, 0)

    for time in xrange(len(inputs[0])):
      if time > 0:
        vs.get_variable_scope().reuse_variables()
      # Advance cell c only when its clock fires; otherwise carry its
      # previous output and state forward unchanged.
      cell_calls = []
      for c, cell in enumerate(cells):
        if time % strides[c] == 0:
          cell_calls.append(
              cell(inputs[c][time // strides[c]], states[c], scope=str(c)))
        else:
          cell_calls.append((previous_cell_output[c], states[c]))

      def output_state():
        """Records per-cell outputs/states and concatenates the outputs."""
        temp_outs = []
        for c, call_cell in enumerate(cell_calls):
          (cell_out, cell_state) = call_cell
          previous_cell_output[c] = cell_out
          # Fixed: the original never wrote back into `states`, so every
          # step re-used the initial state and no recurrence happened.
          states[c] = cell_state
          temp_outs.append(cell_out)
        # Consistent array_ops usage (original mixed in tf.concat here).
        return array_ops.concat(1, temp_outs), cell_state

      # Fixed: the original tested tensor truthiness (`if sequence_length:`),
      # which is invalid once sequence_length is a tensor.
      if sequence_length is not None:
        zero_output_state = (
            array_ops.zeros(
                array_ops.pack([batch_size, total_output_size]),
                inputs[0][0].dtype),
            # NOTE(review): shaped from cells[0].state_size only; assumes
            # all cells share a state size — confirm against callers.
            array_ops.zeros(
                array_ops.pack([batch_size, cells[0].state_size]),
                states[0].dtype))
        (output, state) = control_flow_ops.cond(
            time >= max_sequence_length,
            lambda: zero_output_state, output_state)
      else:
        # Fixed: without sequence_length the original never computed
        # output/state, so the final return raised NameError.
        (output, state) = output_state()

      output.set_shape([batch_size, total_output_size])
      outputs.append(output)

    return (outputs, state)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment