Batch-normalized LSTM cell for TensorFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state""" | |
import tensorflow as tf, numpy as np | |
RNNCell = tf.nn.rnn_cell.RNNCell | |
class BNLSTMCell(RNNCell):
    '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025

    The recurrent state is the tuple ``(c, h, step)``; ``step`` is a width-1
    counter that is incremented on every call. Separate population statistics
    (mean/variance) are kept for each of the first ``max_bn_steps`` steps;
    later steps reuse the statistics of step ``max_bn_steps - 1``.
    '''

    def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """
        Args:
            num_units: number of units in the cell state ``c`` and output ``h``.
            is_training_tensor: scalar boolean tensor used as the ``tf.cond``
                predicate: true -> normalize with batch statistics and update
                the population statistics; false -> normalize with the stored
                population statistics.
            max_bn_steps: maximum number of steps for which to store separate
                population stats.
            initial_scale: initial value of every batch-norm scale (gamma).
            activation: nonlinearity applied to the candidate and to the
                normalized cell state.
            decay: moving-average decay for the population statistics.
        """
        self._num_units = num_units
        self._training = is_training_tensor
        self._max_bn_steps = max_bn_steps
        self._activation = activation
        self._decay = decay
        # Previously hard-coded to 0.1, silently ignoring the argument.
        self._initial_scale = initial_scale

    @property
    def state_size(self):
        # (cell state c, output h, step counter)
        return (self._num_units, self._num_units, 1)

    @property
    def output_size(self):
        return self._num_units

    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False):
        '''Batch-normalize a 2d ``[batch, values]`` tensor with per-step population stats.

        Args:
            x: 2d tensor; normalization is over axis 0.
            name_scope: variable scope holding this normalizer's variables.
            step: scalar int tensor selecting the population-statistics row;
                clamped to ``max_bn_steps - 1``.
            epsilon: variance epsilon passed to ``tf.nn.batch_normalization``.
            no_offset: use a constant offset of 0 instead of a learned one.
            set_forget_gate_bias: initialize the learned offset with
                ``offset_initializer()`` instead of zeros.

        Returns:
            The normalized tensor, same shape as ``x``.
        '''
        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]
            scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                offset = tf.get_variable('offset', [size], initializer=offset_initializer())
            else:
                offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer())
            # One row of statistics per step, not trained by the optimizer.
            pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer(), trainable=False)
            pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False)
            # Steps beyond the stored range share the last row.
            step = tf.minimum(step, self._max_bn_steps - 1)
            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]
            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                # Exponential moving average of the population statistics;
                # the control dependency forces the update whenever this branch runs.
                pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
                pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
                with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics, population_statistics)

    def __call__(self, x, state, scope=None):
        '''Run one step of the cell.

        Args:
            x: 2d input tensor ``[batch, input_size]``.
            state: tuple ``(c, h, step)`` as described by ``state_size``.
            scope: optional variable scope name; defaults to the class name.

        Returns:
            ``(new_h, (new_c, new_h, step + 1))``. The carried cell state
            ``new_c`` is the un-normalized one; only the output path uses
            the batch-normalized cell state.
        '''
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state
            # Only example 0's counter is read, so all rows are assumed to be
            # at the same step.
            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))
            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                                   [x_size, 4 * self._num_units],
                                   initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable('W_hh',
                                   [self._num_units, 4 * self._num_units],
                                   initializer=orthogonal_lstm_initializer())
            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)
            # The recurrent term carries the (forget-gate-biased) offset; the
            # input term has none, so the sum has exactly one offset.
            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)
            hidden = bn_xh + bn_hh
            # TF >= 1.0 argument order: (value, num_or_size_splits, axis).
            # Gate order f, i, o, j matches offset_initializer's layout.
            f, i, o, j = tf.split(hidden, num_or_size_splits=4, axis=1)
            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            bn_new_c = self._batch_norm(new_c, 'c', _step)
            new_h = self._activation(bn_new_c) * tf.sigmoid(o)
            return new_h, (new_c, new_h, step + 1)
def orthogonal_lstm_initializer():
    """Build an initializer that yields (semi-)orthogonal weight matrices.

    The returned callable uses the ``(shape, dtype, partition_info)``
    signature that ``tf.get_variable`` calls initializers with. It samples a
    standard-normal matrix of shape ``(shape[0], prod(shape[1:]))``, takes a
    thin SVD, and keeps whichever singular-vector factor already has that
    shape.
    """
    # taken from https://github.com/cooijmanstim/recurrent-batch-normalization
    # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        """Benanne's Lasagne orthogonal init (faster than the QR approach)."""
        rows, cols = shape[0], np.prod(shape[1:])
        target = (rows, cols)
        gaussian = np.random.normal(0.0, 1.0, target)
        left, _, right = np.linalg.svd(gaussian, full_matrices=False)
        # Thin SVD: exactly one of the two factors matches the target shape
        # when rows != cols; for square inputs the left factor is used.
        basis = left if left.shape == target else right
        basis = basis.reshape(shape)
        return tf.constant(basis[:shape[0], :shape[1]], dtype)

    return orthogonal
def offset_initializer(forget_bias=1.0):
    """Build an initializer for a gate-stacked batch-norm offset vector.

    The vector is treated as four equal blocks, one per gate. The first block
    (the forget gate when gates are split in ``f, i, o, j`` order) is filled
    with ``forget_bias``; the remaining three blocks are zeros.

    Args:
        forget_bias: value for the first (forget-gate) block. The default of
            1.0 reproduces the previous hard-coded behaviour.

    Returns:
        A callable with the ``(shape, dtype, partition_info)`` initializer
        signature. It raises ``ValueError`` if ``shape[0]`` is not divisible
        by 4.
    """
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        size = shape[0]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if size % 4 != 0:
            raise ValueError(
                'offset size must be divisible by 4 (one block per gate), got {}'.format(size))
        gate_size = size // 4
        blocks = [np.ones(gate_size) * forget_bias, np.zeros(gate_size * 3)]
        return tf.constant(np.concatenate(blocks, axis=0), dtype)

    return _initializer
@batzner Have you tested this cell's performance? I tried implementing a similar one, but its performance on the training data is very poor.
My code looks like this:
def bn(name,
       inputs,
       trainable=True,
       training=True,
       reuse=tf.AUTO_REUSE):
    """Apply ``tf.layers.batch_normalization`` to *inputs* as layer *name*.

    Args:
        name: layer name; the same name with ``reuse`` enabled shares variables.
        inputs: tensor to normalize.
        trainable: whether the layer's variables are trainable.
        training: forwarded as the layer's ``training`` argument.
        reuse: forwarded as the layer's ``reuse`` argument.

    NOTE(review): per the TF 1.x docs for ``tf.layers.batch_normalization``,
    the moving mean/variance updates are placed in ``tf.GraphKeys.UPDATE_OPS``
    and must be run together with the train op. Confirm the caller does this,
    otherwise ``training=False`` uses un-updated statistics.
    """
    layer_options = {
        'name': name,
        'reuse': reuse,
        'training': training,
        'trainable': trainable,
    }
    return tf.layers.batch_normalization(inputs, **layer_options)
class BNLSTMCell(RNNCell):
    '''
    Applying batch normalization in LSTM cell

    Batch normalization is applied separately to ``inputs @ W_XH``, to
    ``hidden_state @ W_HH`` and to the new cell state before the output
    ``tanh``; a learned gate bias is added on top of the two normalized terms.

    Every Python-level call of ``__call__`` increments ``num_timestamp`` and
    uses it as a suffix of the batch-norm layer names. As a result, each
    statically unrolled step gets its own batch-norm layer and moving
    statistics. ``max_bn_steps`` optionally caps that suffix so that later
    steps share the last step's layers.

    NOTE(review): if the cell is traced only once (e.g. by a while-loop based
    RNN driver), every time step shares the step-0 layers. Confirm how callers
    unroll this cell.
    '''
    def __init__(self,
                 num_units,
                 training,
                 gate_initializer=tf.initializers.orthogonal(),
                 bias_initializer=tf.initializers.ones(),
                 max_bn_steps=None):
        '''
        :param num_units: size of the cell state and the hidden state.
        :param training: forwarded as ``training`` to every batch-norm layer.
        :param gate_initializer: initializer for ``W_XH`` and ``W_HH``.
        :param bias_initializer: initializer for the ``4 * num_units`` gate bias.
        :param max_bn_steps: if given, the number of distinct per-step
            batch-norm layers; calls after that reuse the last one. ``None``
            (default) creates one set of layers per call.
        :raises ValueError: if ``max_bn_steps`` is given and is less than 1.
        '''
        if max_bn_steps is not None and max_bn_steps < 1:
            raise ValueError('max_bn_steps must be >= 1, got {}'.format(max_bn_steps))
        self.num_timestamp = 0
        self.num_units = num_units
        self.training = training
        self.gate_initializer = gate_initializer
        self.bias_initializer = bias_initializer
        self.max_bn_steps = max_bn_steps

    @property
    def state_size(self):
        '''
        :return: (cell_size, hidden_size)
        '''
        return (self.num_units, self.num_units)

    @property
    def output_size(self):
        '''
        :return: output_size
        '''
        return self.num_units

    def _bn_step(self):
        '''Index used to suffix this call's batch-norm layer names.'''
        if self.max_bn_steps is None:
            return self.num_timestamp
        return min(self.num_timestamp, self.max_bn_steps - 1)

    def __call__(self, inputs, state, scope=None):
        '''
        :param inputs: 2d tensor with shape [batch_size, feature_size].
        :param state: Because self.state_size is tuple, so state is tuple with shape
                      [batch_size, s] for s in self.state_size.
        :param scope: scope for create subgraph, default value is class name.
        :return: output: 2d tensor with shape [batch_size, self.output_size]
                 new_state: Next state, [cell_state, hidden_state]. The carried
                 cell state is the un-normalized one.
        '''
        if scope is None:
            scope = type(self).__name__
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            cell_state, hidden_state = state
            feature_size = inputs.get_shape().as_list()[1]
            W_XH = tf.get_variable('W_XH', shape=[feature_size, 4 * self.num_units],
                                   initializer=self.gate_initializer)
            W_HH = tf.get_variable('W_HH', shape=[self.num_units, 4 * self.num_units],
                                   initializer=self.gate_initializer)
            bias = tf.get_variable('bias', shape=[4 * self.num_units],
                                   initializer=self.bias_initializer)
            # Per-step layer-name suffix (see class docstring).
            suffix = str(self._bn_step())
            hidden = (bn('bn_xWXH' + suffix, tf.matmul(inputs, W_XH), training=self.training)
                      + bn('bn_hWHH' + suffix, tf.matmul(hidden_state, W_HH), training=self.training)
                      + bias)
            # Gate order: forget, input, output, candidate.
            forget_gate, input_gate, output_gate, candidate_cell_state = tf.split(
                hidden, num_or_size_splits=4, axis=1)
            next_cell_state = (tf.math.sigmoid(forget_gate) * cell_state
                               + tf.math.sigmoid(input_gate) * tf.math.tanh(candidate_cell_state))
            bn_next_cell_state = bn('bn_next_cell_state' + suffix, next_cell_state,
                                    training=self.training)
            next_hidden_state = tf.math.sigmoid(output_gate) * tf.math.tanh(bn_next_cell_state)
            self.num_timestamp = self.num_timestamp + 1
            return next_hidden_state, (next_cell_state, next_hidden_state)
I have tried this method, but using population statistics (training=False) at test time gives worse results than using batch statistics.
I think BN-LSTM doesn't work.
@spitis Maybe I have made a silly mistake somewhere.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Thank you! I'm trying to figure out how to integrate batch normalization into a regular RNN cell, and this was really helpful.