Skip to content

Instantly share code, notes, and snippets.

Created February 2, 2017 03:05
What would you like to do?
Batch-normalized LSTM cell for TensorFlow
"""adapted from to store separate population statistics per state"""
import tensorflow as tf, numpy as np
# Alias for TensorFlow's RNN cell base class; BNLSTMCell below subclasses it.
# NOTE(review): `tf.nn.rnn_cell` is not present in every TensorFlow 1.x
# release (some expose the cell classes only under `tf.contrib.rnn`) --
# confirm against the installed version.
RNNCell = tf.nn.rnn_cell.RNNCell
class BNLSTMCell(RNNCell):
    """Batch-normalized LSTM cell.

    Presumably follows Cooijmans et al., "Recurrent Batch Normalization"
    (arXiv:1603.09025); the original reference link is missing from this
    copy.  The input-to-hidden and hidden-to-hidden projections are
    batch-normalized separately, and the cell state is batch-normalized
    before the output nonlinearity.  Population statistics are stored
    separately for each of the first ``max_bn_steps`` time steps.

    The state is a tuple ``(c, h, step)``: cell state, hidden state, and a
    size-1 column holding the current time step, incremented on every call.
    """

    def __init__(self, num_units, is_training_tensor, max_bn_steps,
                 initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """Create the cell.

        Args:
          num_units: size of the cell state and of the hidden state.
          is_training_tensor: boolean scalar tensor used as the ``tf.cond``
            predicate: when true, batch statistics are used and the population
            statistics are updated; otherwise the stored population
            statistics are used.
          max_bn_steps: the maximum number of steps for which to store
            separate population stats; later steps share the statistics of
            step ``max_bn_steps - 1``.
          initial_scale: initial value of every batch-norm scale (gamma).
          activation: nonlinearity applied to the candidate and to the
            normalized cell state.
          decay: moving-average decay used to update population statistics.
        """
        self._num_units = num_units
        self._training = is_training_tensor
        self._max_bn_steps = max_bn_steps
        self._activation = activation
        self._decay = decay
        # Previously hard-coded to 0.1, silently ignoring the argument.
        self._initial_scale = initial_scale

    @property
    def state_size(self):
        """Sizes of ``(c, h, step)``."""
        return (self._num_units, self._num_units, 1)

    @property
    def output_size(self):
        """Size of the emitted hidden state."""
        return self._num_units

    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False,
                    set_forget_gate_bias=False):
        """Batch-normalize ``x`` using per-step population statistics.

        Args:
          x: 2-d ``[batch, values]`` tensor; the second dimension must be
            statically known because it sizes the variables.
          name_scope: variable scope that holds this normalizer's variables.
          step: scalar int32 tensor, the current time step.
          epsilon: variance epsilon passed to ``tf.nn.batch_normalization``.
          no_offset: if true, no learned offset (beta) is used.
          set_forget_gate_bias: if true (and ``no_offset`` is false), the
            offset is initialized by ``offset_initializer()``: ones in the
            first quarter of the values, zeros elsewhere.

        Returns:
          The normalized tensor, same shape as ``x``.
        """
        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]
            scale = tf.get_variable(
                'scale', [size],
                initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                offset = tf.get_variable(
                    'offset', [size], initializer=offset_initializer())
            else:
                # Without this `else`, the forget-gate offset would be
                # re-requested (duplicate variable) and `no_offset` ignored.
                offset = tf.get_variable(
                    'offset', [size], initializer=tf.zeros_initializer())

            pop_mean_all_steps = tf.get_variable(
                'pop_mean', [self._max_bn_steps, size],
                initializer=tf.zeros_initializer(), trainable=False)
            pop_var_all_steps = tf.get_variable(
                'pop_var', [self._max_bn_steps, size],
                initializer=tf.ones_initializer(), trainable=False)

            # Steps beyond the last stored slot reuse that slot's statistics.
            step = tf.minimum(step, self._max_bn_steps - 1)
            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]

            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                # Update this step's moving averages, then normalize with the
                # statistics of the current batch.
                pop_mean_new = (pop_mean * self._decay
                                + batch_mean * (1 - self._decay))
                pop_var_new = (pop_var * self._decay
                               + batch_var * (1 - self._decay))
                with tf.control_dependencies([pop_mean.assign(pop_mean_new),
                                              pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(
                        x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(
                    x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics,
                           population_statistics)

    def __call__(self, x, state, scope=None):
        """Run one time step.

        Args:
          x: 2-d ``[batch, input_size]`` tensor; ``input_size`` must be
            statically known.
          state: tuple ``(c, h, step)`` matching ``state_size``.
          scope: optional variable scope name; defaults to the class name.

        Returns:
          ``(new_h, (new_c, new_h, step + 1))``.
        """
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state
            # `step` holds one counter per example; only the first row is
            # read, so all examples are assumed to be at the same time step.
            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))

            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable(
                'W_xh', [x_size, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable(
                'W_hh', [self._num_units, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())

            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)

            # Normalize the two projections separately; only the recurrent
            # one carries an offset, which doubles as the gate bias.
            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)
            hidden = bn_xh + bn_hh

            # TF 1.x argument order (value, num_or_size_splits, axis); the
            # forget gate comes first, matching offset_initializer().
            f, i, o, j = tf.split(value=hidden, num_or_size_splits=4, axis=1)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            # The cell state is normalized only on its way to the output; the
            # unnormalized new_c is what is carried to the next step.
            bn_new_c = self._batch_norm(new_c, 'c', _step)
            new_h = self._activation(bn_new_c) * tf.sigmoid(o)

            return new_h, (new_c, new_h, step + 1)
def orthogonal_lstm_initializer():
    """Return an initializer that produces (semi-)orthogonal weight matrices.

    The returned callable follows the ``tf.get_variable`` initializer
    signature ``(shape, dtype, partition_info)``.  It uses benanne's Lasagne
    orthogonal init, which takes the SVD of a Gaussian matrix (faster than a
    QR-based approach).  Random numbers come from NumPy's global RNG
    (``np.random``), not from TensorFlow random ops.
    """
    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        # Flatten all trailing dimensions so the SVD sees a 2-d matrix;
        # int() because np.prod of an empty list is the float 1.0.
        flat_shape = (shape[0], int(np.prod(shape[1:])))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # u is (rows, k) and v is (k, cols) with k = min(rows, cols), so at
        # least one of them has the flattened shape; pick that one.
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return tf.constant(q[:shape[0], :shape[1]], dtype)
    return orthogonal
def offset_initializer():
    """Return an initializer for a 4-gate LSTM offset (bias) vector.

    The first quarter of the vector is set to 1 and the remaining three
    quarters to 0; with BNLSTMCell's ``f, i, o, j`` gate order, that first
    quarter is the forget gate.  The returned callable follows the
    ``tf.get_variable`` initializer signature.

    The returned callable raises ValueError if ``shape[0]`` is not divisible
    by 4.
    """
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        size = shape[0]
        # Explicit check instead of `assert`, which is stripped under -O.
        if size % 4 != 0:
            raise ValueError(
                'offset_initializer expects a size divisible by 4, got %d'
                % size)
        size = size // 4
        res = [np.ones((size)), np.zeros((size * 3))]
        return tf.constant(np.concatenate(res, axis=0), dtype)
    return _initializer
Copy link

batzner commented Mar 17, 2018

Thank you! I'm trying to figure out how to integrate Batch Normalization into a regular RNN cell and this was really helpful.

Copy link

@batzner Have you tested this cell's performance? I implemented a similar one, but its performance on the training data is very poor.
My code looks like this:

# NOTE(review): commenter's batch-norm helper; this copy is truncated -- the
# rest of the parameter list and of the tf.layers.batch_normalization(...)
# call are missing. The call sites in the cell below pass a layer name and the
# tensor to normalize; further arguments (presumably a training flag) are cut
# off too -- confirm against the original comment.
# Per the tf.layers.batch_normalization documentation, its moving mean and
# variance are only updated when the ops collected in tf.GraphKeys.UPDATE_OPS
# are run (e.g. as control dependencies of the train op); otherwise
# training=False keeps using the initial statistics -- worth checking given
# the poor results reported in this thread.
def bn(name,
    return tf.layers.batch_normalization(inputs,
# NOTE(review): commenter's alternative BN-LSTM cell. This copy is damaged:
# the docstring quote marks, the rest of the __init__ signature, the
# initializer arguments of the get_variable calls, and the tails of the bn()
# and print() calls are missing, so it does not run as shown.
# It batch-normalizes inputs*W_XH, hidden_state*W_HH and the new cell state
# with bn(), and carries the unnormalized cell state forward.
class BNLSTMCell(RNNCell):
      Applying batch normalization in LSTM cell
    def __init__(self,
        # Python-side counter appended to the bn() layer names in __call__.
        self.num_timestamp = 0
        # NOTE(review): garbled line -- presumably originally
        # `self.num_units = num_units` followed by `self.training = training`;
        # confirm against the original comment.
        self.num_units = num_units = training
        self.gate_initializer = gate_initializer
        self.bias_initializer = bias_initializer

    # NOTE(review): RNNCell implementations normally expose state_size and
    # output_size as properties; the @property decorators may have been lost
    # in this copy.
    def state_size(self):
        :return: (cell_size, hidden_size)
        return (self.num_units, self.num_units)

    def output_size(self):
        :return: output_size
        return self.num_units

    def __call__(self, inputs, state, scope=None):
        :param inputs:  2d tensor with shape [batch_size, feature_size].
        :param state:  Because self.state_size is tuple, so state is tuple with shape
                       [batch_size, s] for s in self.state_size.
        :param scope: scope for create subgraph, default value is class name.
        :return: output: 2d tensor with shape [batch_size, self.output_size]
                 new_state: Next state, [cell_state, hidden_state]
        if scope is None:
            scope = type(self).__name__
        # AUTO_REUSE lets repeated calls share W_XH / W_HH / bias.
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            cell_state, hidden_state = state
            feature_size = inputs.get_shape().as_list()[1]

            # NOTE(review): the initializer arguments of these three calls are
            # cut off in this copy (presumably gate_initializer /
            # bias_initializer from __init__).
            W_XH = tf.get_variable('W_XH', shape=[feature_size, 4 * self.num_units],\
            W_HH = tf.get_variable('W_HH', shape=[self.num_units, 4 * self.num_units],\
            bias = tf.get_variable('bias', shape=[4 * self.num_units], \

            # One bn() layer per call, named with self.num_timestamp.
            # NOTE(review): the remaining bn() arguments (and, judging by the
            # commented-out line below, the `+ bias` term) are cut off here.
            hidden = bn('bn_xWXH' + str(self.num_timestamp), tf.matmul(inputs, W_XH), + \
                     bn('bn_hWHH' + str(self.num_timestamp), tf.matmul(hidden_state, W_HH), + \

            #hidden = tf.matmul(inputs, W_XH) + tf.matmul(hidden_state, W_HH) + bias

            forget, input, output, candidate_cell_state = tf.split(hidden, num_or_size_splits=4, axis=1)
            next_cell_state = tf.math.sigmoid(forget) * cell_state + \
                              tf.math.sigmoid(input) * tf.math.tanh(candidate_cell_state)

            # The normalized cell state is used only to compute the output;
            # the unnormalized next_cell_state is returned as state.
            bn_next_cell_state = bn('bn_next_cell_state' + str(self.num_timestamp), next_cell_state,

            #bn_next_cell_state = next_cell_state

            next_hidden_state = tf.math.sigmoid(output) * tf.math.tanh(bn_next_cell_state)
            # NOTE(review): this counter advances once per Python-level call of
            # __call__. If the cell is driven by a loop construct that calls
            # __call__ only once (presumably e.g. tf.nn.dynamic_rnn), every
            # time step shares one set of bn() statistics -- confirm how the
            # cell is unrolled.
            self.num_timestamp = self.num_timestamp + 1
            # Debug output; printed each time __call__ runs in Python (graph
            # construction in graph mode), not per session run. The rest of
            # this call is cut off in this copy.
            print('training in __call__ function: {}'.format(

            return next_hidden_state, (next_cell_state, next_hidden_state)

Copy link

Copy link

I have tried this method, but using population statistics (training: False) at test time gives worse results than using batch statistics.
I think BN-LSTM doesn't work.

Copy link

@spitis Maybe I have made some stupid mistake.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment