Fast-slow LSTM with variational unit (VU)
import tensorflow as tf
import numpy as np

latent_dim = 2

#class FSRNNCell_VU(tf.contrib.rnn.RNNCell):
class FSRNNCell_VU(tf.compat.v1.nn.rnn_cell.RNNCell):
    def __init__(self, fast_cells, slow_cell, input_keep_prob=1.0, keep_prob=1.0, training=True):
        """Initialize the basic Fast-Slow RNN.

        Args:
            fast_cells: A list of RNN cells that will be used for the fast RNN.
                The cells must be callable, implement zero_state() and all have
                the same hidden size, for example tf.contrib.rnn.BasicLSTMCell.
            slow_cell: A single RNN cell for the slow RNN.
            input_keep_prob: Keep probability for the dropout applied to the inputs.
            keep_prob: Keep probability for the non-recurrent dropout. Any kind of
                recurrent dropout should be implemented in the RNN cells.
            training: If False, no dropout is applied.
        """
        self.fast_layers = len(fast_cells)
        assert self.fast_layers >= 2, 'At least two fast layers are needed'
        self.fast_cells = fast_cells
        self.slow_cell = slow_cell
        self.keep_prob = keep_prob
        self.input_keep_prob = input_keep_prob

        # VU: (mean, sigma) of the Normal distributions produced from the slow
        # and fast outputs; filled in by __call__.
        self.S_mean = None
        self.S_sigma = None
        self.F_mean = None
        self.F_sigma = None

        if not training:
            self.keep_prob = 1.0
            self.input_keep_prob = 1.0
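
    # One FS-RNN step: the input goes through the first fast cell (Fast_0),
    # whose output drives the slow cell; the slow output is fed back into the
    # second fast cell (Fast_1) and then through any remaining fast cells.
    # The variational unit (VU) projects the slow output and the final fast
    # output to the (mean, sigma) pairs of two Normal distributions, which are
    # returned alongside the usual FS-RNN output and states.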
    def __call__(self, inputs, state, scope='FS-RNN'):
        F_state = state[0]
        S_state = state[1]

        # VU: shared initializer for the mean/sigma projection layers.
        a_w = tf.random_normal_initializer(seed=10)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            inputs = tf.nn.dropout(inputs, self.input_keep_prob)

            with tf.variable_scope('Fast_0'):
                F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[0], inputs=inputs,
                                                      initial_state=F_state, time_major=True)
            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)

            with tf.variable_scope('Slow'):
                S_output, S_state = tf.nn.dynamic_rnn(cell=self.slow_cell, inputs=F_output_drop,
                                                      initial_state=S_state, time_major=True)
                # VU: project the slow output to the (mean, sigma) of a Normal distribution.
                self.S_mean = tf.layers.dense(S_output, latent_dim, activation=None,
                                              kernel_initializer=a_w, name='mean', trainable=True)
                self.S_sigma = tf.layers.dense(S_output, latent_dim, tf.nn.softplus,
                                               kernel_initializer=a_w, name='sigma', trainable=True)
            S_output_drop = tf.nn.dropout(S_output, self.keep_prob)

            with tf.variable_scope('Fast_1'):
                F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[1], inputs=S_output_drop,
                                                      initial_state=F_state, time_major=True)

            for i in range(2, self.fast_layers):
                with tf.variable_scope('Fast_' + str(i)):
                    # Input cannot be empty for many RNN cells.
                    #F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[i], inputs=F_output[:, 0:1] * 0.0, initial_state=F_state, time_major=True)
                    F_output, F_state = tf.nn.dynamic_rnn(cell=self.fast_cells[i], inputs=F_output,
                                                          initial_state=F_state, time_major=True)

            # VU: project the final fast output to the (mean, sigma) of a Normal distribution.
            self.F_mean = tf.layers.dense(F_output, latent_dim, activation=None,
                                          kernel_initializer=a_w, name='mean', trainable=True)
            self.F_sigma = tf.layers.dense(F_output, latent_dim, tf.nn.softplus,
                                           kernel_initializer=a_w, name='sigma', trainable=True)

            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)

            return F_output_drop, (F_state, S_state), (self.S_mean, self.S_sigma), (self.F_mean, self.F_sigma)
    def zero_state(self, batch_size, dtype):
        F_state = self.fast_cells[0].zero_state(batch_size, dtype)
        S_state = self.slow_cell.zero_state(batch_size, dtype)
        return (F_state, S_state)
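
# Demo: build one slow and three fast BasicLSTMCells, run the cell on an
# all-zeros input and an all-ones input (variables are shared via AUTO_REUSE),
# turn the returned (mean, sigma) pairs into Normal distributions, and compute
# KL(slow || fast).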
batch_size = 3
cell_size = 5
state_size = 7

# Create one Slow and three Fast cells.
slow = tf.contrib.rnn.BasicLSTMCell(cell_size)
fast = [tf.contrib.rnn.BasicLSTMCell(cell_size),
        tf.contrib.rnn.BasicLSTMCell(cell_size),
        tf.contrib.rnn.BasicLSTMCell(cell_size)]

# Create a single FS-RNN using the cells.
fs_lstm_vu = FSRNNCell_VU(fast, slow)

# Get the initial state and build the ops for a forward pass. The inputs of
# shape (batch_size, 1, state_size) are fed time-major, so their leading
# dimension is treated as time and the batch dimension is 1; hence zero_state(1, ...).
init_state = fs_lstm_vu.zero_state(1, tf.float32)
output, final_state, S_N_args, F_N_args = fs_lstm_vu(np.zeros((batch_size, 1, state_size), np.float32), init_state)
# The second call reuses the same variables (AUTO_REUSE); the session below
# evaluates the ops built from this all-ones input.
output, final_state, S_N_args, F_N_args = fs_lstm_vu(np.ones((batch_size, 1, state_size), np.float32), init_state)

# Build Normal distributions from the returned (mean, sigma) pairs.
S_norm_dist = tf.distributions.Normal(loc=S_N_args[0], scale=S_N_args[1])
F_norm_dist = tf.distributions.Normal(loc=F_N_args[0], scale=F_N_args[1])
F_norm_dist_sample_z = tf.squeeze(F_norm_dist.sample(1), axis=0)  # choosing an action
KL = tf.distributions.kl_divergence(S_norm_dist, F_norm_dist)  # to be added to the loss function (see the sketch at the end of this script); try to minimize KL
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print('output', sess.run(output))  # time-major output of shape (batch_size, 1, cell_size)
    #print(sess.run(F_norm_dist_sample_z))  # one sample from the fast distribution
    print('KL', sess.run(KL))
    print('S_N_args[0]', sess.run(S_N_args[0]))
    print('S_N_args[1]', sess.run(S_N_args[1]))
    print('F_N_args[0]', sess.run(F_N_args[0]))
    print('F_N_args[1]', sess.run(F_N_args[1]))
    # The slow and fast distribution parameters should differ element-wise.
    assert (sess.run(S_N_args[0]) != sess.run(F_N_args[0])).all()
    assert (sess.run(S_N_args[1]) != sess.run(F_N_args[1])).all()
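
# A minimal sketch (not from the original gist) of how the KL term above could
# be added to a training objective, as the comment on `KL` suggests. The names
# `kl_weight`, `task_loss` and `train_op`, the squared-output placeholder loss
# and the Adam optimizer are illustrative assumptions, not the author's setup.
kl_weight = 0.1                                # assumed weighting for the KL penalty
task_loss = tf.reduce_mean(tf.square(output))  # placeholder task loss on the FS-RNN output
total_loss = task_loss + kl_weight * tf.reduce_mean(KL)
train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(total_loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('total_loss before step', sess.run(total_loss))
    sess.run(train_op)
    print('total_loss after step', sess.run(total_loss))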