Skip to content

Instantly share code, notes, and snippets.

@evanthebouncy
Created December 5, 2015 10:53
Show Gist options
  • Save evanthebouncy/4e71fb923697c6fcd7ef to your computer and use it in GitHub Desktop.
Save evanthebouncy/4e71fb923697c6fcd7ef to your computer and use it in GitHub Desktop.
Trying to understand bidirectional RNNs
# Experiment: run tf.nn.bidirectional_rnn over a short random sequence with a
# sequence_length shorter than the unrolled length, and print the outputs.
# Written against the TensorFlow 0.6-era API (LSTMCell takes input_size, and
# bidirectional_rnn takes a Python list with one input tensor per time step).
import numpy as np
import tensorflow as tf

batch_size = 1
input_size = 2
num_units = 3
input_length = 8
# Value fed to sequence_length for every example in the batch; the RNN is
# expected to treat steps past this as padding (see bidirectional_rnn docs).
valid_steps = 5

with tf.Session() as sess:
    with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        # The cells are given their own initializer; presumably it takes
        # precedence over the scope-level constant initializer for the LSTM
        # weights -- TODO confirm against LSTMCell's variable_scope usage.
        initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=123)
        cell_fw = tf.nn.rnn_cell.LSTMCell(
            num_units, input_size, initializer=initializer)
        cell_bw = tf.nn.rnn_cell.LSTMCell(
            num_units, input_size, initializer=initializer)

        # One *distinct* placeholder per time step. `n * [placeholder]` would
        # repeat a single placeholder object n times, so feeding step i would
        # overwrite the value for every step.
        input_seq = [
            tf.placeholder(tf.float32, shape=(batch_size, input_size))
            for _ in range(input_length)
        ]
        sequence_length = tf.placeholder(tf.int64, shape=(batch_size,))

        outputs = tf.nn.bidirectional_rnn(
            cell_fw, cell_bw, input_seq, dtype=tf.float32,
            sequence_length=sequence_length)

    sess.run(tf.initialize_all_variables())

    inputs_values = [np.random.randn(batch_size, input_size)
                     for _ in range(input_length)]
    # Pair each step's placeholder with its own random input.
    feed_dict = dict(zip(input_seq, inputs_values))
    feed_dict[sequence_length] = [valid_steps] * batch_size

    # Fetch `outputs` as-is rather than wrapping it in another list: older
    # Session.run releases may not accept nested fetch lists.
    res = sess.run(outputs, feed_dict=feed_dict)
    print(res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment