@ChuaCheowHuan
Last active April 20, 2020 06:39
Fast-slow LSTM: a TensorFlow 1.x implementation of the Fast-Slow RNN cell (Mujika et al., 2017), modified so that each sub-cell is driven by tf.nn.dynamic_rnn over the whole sequence instead of by a single-step cell call.
import numpy as np
import tensorflow as tf  # TF 1.x: tf.contrib and tf.nn.dynamic_rnn are used below


#class FSRNNCell(tf.contrib.rnn.RNNCell):
class FSRNNCell(tf.compat.v1.nn.rnn_cell.RNNCell):

    def __init__(self, fast_cells, slow_cell, input_keep_prob=1.0,
                 keep_prob=1.0, training=True):
        """Initialize the basic Fast-Slow RNN.

        Args:
          fast_cells: A list of RNN cells that will be used for the fast RNN.
            The cells must be callable, implement zero_state() and all have the
            same hidden size, e.g. tf.contrib.rnn.BasicLSTMCell.
          slow_cell: A single RNN cell for the slow RNN.
          input_keep_prob: Keep probability for dropout applied to the inputs.
          keep_prob: Keep probability for the non-recurrent dropout. Any kind of
            recurrent dropout should be implemented in the RNN cells.
          training: If False, no dropout is applied.
        """
        self.fast_layers = len(fast_cells)
        assert self.fast_layers >= 2, 'At least two fast layers are needed'
        self.fast_cells = fast_cells
        self.slow_cell = slow_cell
        self.keep_prob = keep_prob
        self.input_keep_prob = input_keep_prob
        if not training:  # disable both dropouts at inference time
            self.keep_prob = 1.0
            self.input_keep_prob = 1.0
    def __call__(self, inputs, state, scope='FS-RNN'):
        """Run the FS-RNN over a full time-major sequence.

        inputs: (time_steps, batch_size, input_size); state: (F_state, S_state).
        """
        F_state = state[0]
        S_state = state[1]
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            inputs = tf.nn.dropout(inputs, self.input_keep_prob)

            # Fast_0 reads the input sequence.
            with tf.variable_scope('Fast_0'):
                #F_output, F_state = self.fast_cells[0](inputs, F_state)
                F_output, F_state = tf.nn.dynamic_rnn(
                    cell=self.fast_cells[0], inputs=inputs,
                    initial_state=F_state, time_major=True)
            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)

            # The slow cell reads the first fast cell's output.
            with tf.variable_scope('Slow'):
                #S_output, S_state = self.slow_cell(F_output_drop, S_state)
                S_output, S_state = tf.nn.dynamic_rnn(
                    cell=self.slow_cell, inputs=F_output_drop,
                    initial_state=S_state, time_major=True)
            S_output_drop = tf.nn.dropout(S_output, self.keep_prob)

            # Fast_1 feeds the slow cell's output back into the fast chain.
            with tf.variable_scope('Fast_1'):
                #F_output, F_state = self.fast_cells[1](S_output_drop, F_state)
                F_output, F_state = tf.nn.dynamic_rnn(
                    cell=self.fast_cells[1], inputs=S_output_drop,
                    initial_state=F_state, time_major=True)

            # The remaining fast cells refine the fast output.
            for i in range(2, self.fast_layers):
                with tf.variable_scope('Fast_' + str(i)):
                    # Input cannot be empty for many RNN cells
                    #F_output, F_state = self.fast_cells[i](F_output[:, 0:1] * 0.0, F_state)
                    #F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[i], inputs=F_output[:, 0:1] * 0.0, initial_state=F_state, time_major=True)
                    F_output, F_state = tf.nn.dynamic_rnn(
                        cell=self.fast_cells[i], inputs=F_output,
                        initial_state=F_state, time_major=True)

            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)
            return F_output_drop, (F_state, S_state)
    def zero_state(self, batch_size, dtype):
        F_state = self.fast_cells[0].zero_state(batch_size, dtype)
        S_state = self.slow_cell.zero_state(batch_size, dtype)
        return (F_state, S_state)
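
# Note (a sketch, not part of the original gist): zero_state() returns a pair
# (fast_state, slow_state); with BasicLSTMCell each element is an
# LSTMStateTuple(c, h). To carry hidden state across consecutive sequence
# segments (truncated-BPTT style), feed the returned final state back in:
#   state = cell.zero_state(batch_size, tf.float32)
#   for segment in segments:          # each segment: (T, B, D), time-major
#       seg_output, state = cell(segment, state)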
# Create one Slow and three Fast cells.
slow = tf.contrib.rnn.BasicLSTMCell(32)    # slow cell size
fast = [tf.contrib.rnn.BasicLSTMCell(32),
        tf.contrib.rnn.BasicLSTMCell(32),
        tf.contrib.rnn.BasicLSTMCell(32)]  # fast cell sizes

# Create a single FS-RNN using the cells.
fs_lstm = FSRNNCell(fast, slow)

# Get the initial state and build the op that runs one full sequence.
init_state = fs_lstm.zero_state(1, tf.float32)  # batch_size = 1
# Inputs are time-major: (time_steps, batch_size, input_size) = (12, 1, 11).
output, final_state = fs_lstm(np.zeros((12, 1, 11), np.float32), init_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output).shape)  # (12, 1, 32) = (time_steps, batch_size, cell_size)
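
As a minimal training sketch (an assumption, not part of the original gist), the time-major output above can be projected to the input dimension and fit against dummy targets with standard TF 1.x ops; targets, projection, and the Adam learning rate below are illustrative only.

targets = tf.placeholder(tf.float32, [12, 1, 11])   # hypothetical targets
projection = tf.layers.dense(output, 11)            # (12, 1, 32) -> (12, 1, 11)
loss = tf.reduce_mean(tf.squared_difference(projection, targets))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, loss_val = sess.run([train_op, loss],
                           feed_dict={targets: np.zeros((12, 1, 11), np.float32)})
    print(loss_val)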