@aidangomez · Created March 1, 2016 20:48
A modified version of tensorflow/python/ops/rnn.py: a multi-rate RNN in which each of C cells is stepped only every strides[c] time steps (a clockwork-RNN-style schedule), with the idle cells re-emitting their previous output.
from six.moves import xrange
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs


def cw_rnn(cells, strides, inputs, initial_states=None, dtype=None,
           sequence_length=None, scope=None):
  """Creates a multi-rate recurrent neural network from the RNNCells "cells".

  Args:
    cells: C instances of RNNCell.
    strides: C ints; cells[c] is stepped only when time % strides[c] == 0.
    inputs: A length-C list of lists, with len(inputs[c]) == T // strides[c],
      each containing tensors of shape [batch_size, cell.input_size].
    initial_states: (optional) A length-C list of initial states, one per
      cell, each a tensor of shape [batch_size x cell.state_size].
    dtype: (optional) The data type for the initial states. Required if
      initial_states is not provided.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) of size [batch_size]. Values in [0, T).
    scope: VariableScope for the created subgraph; defaults to "RNN".

  Returns:
    A pair (outputs, state) where:
      outputs is a length-T list of outputs, the C cells' outputs at each
        step concatenated along dimension 1
      state is the final state

  Raises:
    TypeError: If "cells" or "inputs" does not have the expected structure.
    ValueError: If inputs is an empty list.
  """
  if not isinstance(cells, list):
    raise TypeError("cells must be a list")
  if not isinstance(cells[0], rnn_cell.RNNCell):
    raise TypeError("cells elements must be instances of RNNCell")
  if not isinstance(inputs, list):
    raise TypeError("inputs must be a list")
  # Check for emptiness before indexing inputs[0].
  if not inputs:
    raise ValueError("inputs must not be empty")
  if not isinstance(inputs[0], list):
    raise TypeError("inputs elements must be lists")
  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "RNN") as varscope:
    # if varscope.caching_device is None:
    #   varscope.set_caching_device(lambda op: op.device)
    fixed_batch_size = inputs[0][0].get_shape().with_rank_at_least(1)[0]
    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      # inputs[0] is a Python list; take the dynamic batch size from its
      # first tensor.
      batch_size = array_ops.shape(inputs[0][0])[0]
    if initial_states is not None:
      states = initial_states
    else:
      if not dtype:
        raise ValueError("If no initial_states are provided, dtype must be.")
      states = [cell.zero_state(batch_size, dtype) for cell in cells]
    # Cells that are not stepped at a given time re-emit their most recent
    # output; seed those outputs with zeros.
    previous_cell_output = [
        array_ops.zeros(array_ops.pack([batch_size, cell.output_size]),
                        inputs[0][0].dtype) for cell in cells]
    # A for loop, not a generator expression, so set_shape actually runs.
    for c, cell in enumerate(cells):
      previous_cell_output[c].set_shape(
          tensor_shape.TensorShape([fixed_batch_size.value, cell.output_size]))
    if sequence_length is not None:  # Prepare variables
      sequence_length = math_ops.to_int32(sequence_length)
      zero_outputs = [
          array_ops.zeros(array_ops.pack([batch_size, cell.output_size]),
                          inputs[0][0].dtype) for cell in cells]
      for c, cell in enumerate(cells):
        zero_outputs[c].set_shape(
            tensor_shape.TensorShape([fixed_batch_size.value,
                                      cell.output_size]))
      # min_sequence_length is kept for parity with rnn.py but unused below.
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)
    total_output_size = sum(cell.output_size for cell in cells)
    for time in xrange(len(inputs[0])):
      if time > 0: vs.get_variable_scope().reuse_variables()
      # Clockwork scheduling: step cell c only when time is a multiple of
      # strides[c]; otherwise carry its previous output and state forward.
      cell_calls = []
      for c, cell in enumerate(cells):
        if time % strides[c] == 0:
          cell_calls.append(cell(inputs[c][time // strides[c]], states[c],
                                 scope=str(c)))
        else:
          cell_calls.append((previous_cell_output[c], states[c]))

      def output_state():
        temp_outs = []
        for c, call_cell in enumerate(cell_calls):
          (output, state) = call_cell
          previous_cell_output[c] = output
          states[c] = state  # carry each cell's state to the next step
          temp_outs.append(output)
        output = array_ops.concat(1, temp_outs)
        return output, state

      if sequence_length is not None:
        # Past the longest sequence in the batch, emit zeros instead of
        # stepping the cells.
        zero_output_state = (
            array_ops.zeros(array_ops.pack([batch_size, total_output_size]),
                            inputs[0][0].dtype),
            array_ops.zeros(array_ops.pack([batch_size, cells[0].state_size]),
                            states[0].dtype))
        (output, state) = control_flow_ops.cond(
            time >= max_sequence_length,
            lambda: zero_output_state, output_state)
      else:
        (output, state) = output_state()
      output.set_shape([fixed_batch_size.value, total_output_size])
      outputs.append(output)
    return (outputs, state)
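
For reference, a minimal usage sketch against the same TF-0.x-era API the gist targets. The cell sizes, placeholder shapes, and names (batch_size, input_size, T) are illustrative assumptions, not part of the gist.

import tensorflow as tf
from tensorflow.python.ops import rnn_cell

batch_size, input_size, T = 32, 16, 8  # assumed shapes
strides = [1, 4]  # fast cell steps every time step, slow cell every 4th
cells = [rnn_cell.BasicLSTMCell(64), rnn_cell.BasicLSTMCell(32)]
# One input list per cell, with T // strides[c] tensors each.
inputs = [[tf.placeholder(tf.float32, [batch_size, input_size])
           for _ in range(T // s)] for s in strides]
outputs, state = cw_rnn(cells, strides, inputs, dtype=tf.float32)
# outputs is a length-T list of [batch_size, 64 + 32] tensors.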