"""Automatically build a vanilla or Cudnn RNN.""" | |
import numpy as np | |
import tensorflow as tf | |
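# Note: this gist targets the TensorFlow 1.x API (tf.contrib, tf.placeholder,
# tf.Session) and, per the comment inside `gru` below, was written against
# TF 1.9.0; it will not run unmodified on TensorFlow 2.x.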
def gru(inputs,
        num_layers,
        num_units,
        direction='unidirectional',
        input_mode='linear_input',
        dropout=0.0,
        dtype=tf.float32,
        initial_state=None,
        sequence_length=None,
        use_gpu_if_available=True,
        name='gru'):
  """Op-autoselecting GRU.

  Builds a Cudnn GRU when a CUDA-capable GPU is available, and an equivalent
  stack of Cudnn-compatible cells otherwise, so checkpoints remain
  interchangeable between the two. `inputs` must be time-major with shape
  [time, batch, input_size]. Note that `dropout` only takes effect on the
  Cudnn path and `sequence_length` only on the vanilla path.

  Returns (output, states) if unidirectional, or
  (output, fw_states, bw_states) if bidirectional.
  """
  with tf.variable_scope(name):
    # Check whether a CUDA-capable GPU is available.
    if use_gpu_if_available and tf.test.is_gpu_available(cuda_only=True):
      output, states = tf.contrib.cudnn_rnn.CudnnGRU(
          num_layers, num_units, direction=direction, input_mode=input_mode,
          dropout=dropout, dtype=dtype)(inputs, initial_state=initial_state)
      # TODO: make this general enough to support LSTM (the only RNN op that
      # has two elements in the state tuple).
      states = states[0]
      if direction == 'unidirectional':
        states = tf.unstack(states, num=num_layers)
        return output, tuple(states)
      # Cudnn stacks bidirectional states layer by layer with the forward
      # direction first, so even indices hold forward states and odd indices
      # hold backward states.
      fw_states = []
      bw_states = []
      for i, state in enumerate(tf.unstack(states, num=2 * num_layers)):
        if i % 2 == 0:
          fw_states.append(state)
        else:
          bw_states.append(state)
      return output, tuple(fw_states), tuple(bw_states)
    # As of TensorFlow 1.9.0, there is a bug where the outermost variable
    # scope is not set by the Cudnn-compatible RNN cells to match that of the
    # Cudnn RNN ops, so set it explicitly for seamless saving/restoring.
    with tf.variable_scope('cudnn_gru'):

      def cell():
        return tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell(num_units)

      # Unidirectional RNNs must be wrapped in a MultiRNNCell, even for a
      # single layer.
      if direction == 'unidirectional':
        cell = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(num_layers)])
        return tf.nn.dynamic_rnn(
            cell, inputs, initial_state=initial_state, dtype=dtype,
            sequence_length=sequence_length, time_major=True)
      fw_cells = [cell() for _ in range(num_layers)]
      bw_cells = [cell() for _ in range(num_layers)]
      initial_states_fw = None
      initial_states_bw = None
      if initial_state is not None:
        initial_states_fw, initial_states_bw = initial_state
      # stack_bidirectional_dynamic_rnn already returns the triple
      # (outputs, fw_states, bw_states), matching the Cudnn branch above.
      return tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
          fw_cells, bw_cells, inputs, initial_states_fw=initial_states_fw,
          initial_states_bw=initial_states_bw, dtype=dtype,
          sequence_length=sequence_length, time_major=True)
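

# A minimal usage sketch (illustrative only: the shapes and names below are
# assumptions, not part of the gist). `gru` expects time-major inputs of
# shape [time, batch, input_size]:
#
#   inputs = tf.placeholder(tf.float32, shape=[None, None, 128])
#   outputs, states = gru(inputs, num_layers=2, num_units=256)
#   bi_outputs, fw_states, bw_states = gru(
#       inputs, num_layers=2, num_units=256, direction='bidirectional',
#       name='bigru')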
# NOTE: the code below is just for testing; no need to copy it.
import os
import tempfile

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Silence TensorFlow's C++ logging.
def _test_unidirectional_gru():
  print("Testing unidirectional GRU")
  with tempfile.TemporaryDirectory() as tmpdirname:
    # Save checkpoints under a prefix inside the temporary directory so that
    # all checkpoint files are cleaned up with it.
    save_path = os.path.join(tmpdirname, 'model')
    fixed_inputs = np.random.randn(5, 3, 7)

    # Test Cudnn functionality.
    graph = tf.Graph()
    with graph.as_default():
      inputs = tf.placeholder(tf.float32, shape=[5, 3, 7])
      gru_op = gru(inputs, 2, 11)
    with tf.Session(graph=graph) as sess:
      sess.run(tf.global_variables_initializer())
      feed_dict = {inputs: fixed_inputs}
      cudnn_output, cudnn_states = sess.run(gru_op, feed_dict=feed_dict)
      tf.train.Saver().save(sess, save_path)

    # Test vanilla portability and functionality.
    graph = tf.Graph()
    with graph.as_default():
      inputs = tf.placeholder(tf.float32, shape=[5, 3, 7])
      gru_op = gru(inputs, 2, 11, use_gpu_if_available=False)
    with tf.Session(graph=graph) as sess:
      tf.train.Saver().restore(sess, save_path)
      feed_dict = {inputs: fixed_inputs}
      vanilla_output, vanilla_states = sess.run(gru_op, feed_dict=feed_dict)

    # Test equivalence.
    print("Output MSE:", ((cudnn_output - vanilla_output) ** 2).mean())
    print("States MSE:")
    for cudnn_s, vanilla_s in zip(cudnn_states, vanilla_states):
      print(" ", ((cudnn_s - vanilla_s) ** 2).mean())
def _test_bidirectional_gru():
  print("Testing bidirectional GRU")
  with tempfile.TemporaryDirectory() as tmpdirname:
    # Use a checkpoint prefix inside the temporary directory, as above.
    save_path = os.path.join(tmpdirname, 'model')
    fixed_inputs = np.random.randn(5, 3, 7)

    # Test Cudnn functionality.
    graph = tf.Graph()
    with graph.as_default():
      inputs = tf.placeholder(tf.float32, shape=[5, 3, 7])
      gru_op = gru(inputs, 2, 11, direction='bidirectional')
    with tf.Session(graph=graph) as sess:
      sess.run(tf.global_variables_initializer())
      feed_dict = {inputs: fixed_inputs}
      cudnn_output, cudnn_fw, cudnn_bw = sess.run(gru_op, feed_dict=feed_dict)
      tf.train.Saver().save(sess, save_path)

    # Test vanilla portability and functionality.
    graph = tf.Graph()
    with graph.as_default():
      inputs = tf.placeholder(tf.float32, shape=[5, 3, 7])
      gru_op = gru(
          inputs, 2, 11, direction='bidirectional', use_gpu_if_available=False)
    with tf.Session(graph=graph) as sess:
      tf.train.Saver().restore(sess, save_path)
      feed_dict = {inputs: fixed_inputs}
      vanilla_output, vanilla_fw, vanilla_bw = sess.run(
          gru_op, feed_dict=feed_dict)

    # Test equivalence.
    print("Output MSE:", ((cudnn_output - vanilla_output) ** 2).mean())
    print("FW States MSE:")
    for cudnn_s, vanilla_s in zip(cudnn_fw, vanilla_fw):
      print(" ", ((cudnn_s - vanilla_s) ** 2).mean())
    print("BW States MSE:")
    for cudnn_s, vanilla_s in zip(cudnn_bw, vanilla_bw):
      print(" ", ((cudnn_s - vanilla_s) ** 2).mean())
if __name__ == '__main__':
  _test_unidirectional_gru()
  _test_bidirectional_gru()