Skip to content

Instantly share code, notes, and snippets.

Created February 2, 2017 03:05
Show Gist options
  • Save spitis/27ab7d2a30bbaf5ef431b4a02194ac60 to your computer and use it in GitHub Desktop.
Save spitis/27ab7d2a30bbaf5ef431b4a02194ac60 to your computer and use it in GitHub Desktop.
Batch normalized LSTM Cell for Tensorflow
"""Adapted from an earlier implementation (source link missing) to store separate population statistics per state."""
import tensorflow as tf, numpy as np
# Short alias for the TensorFlow RNN cell base class that BNLSTMCell extends.
RNNCell = tf.nn.rnn_cell.RNNCell
class BNLSTMCell(RNNCell):
    """Batch-normalized LSTM cell with separate population statistics per step.

    The recurrent state is a ``(c, h, step)`` triple: the cell state, the
    hidden state, and a step counter that selects which row of the stored
    population statistics is read and updated.
    """

    def __init__(self, num_units, is_training_tensor, max_bn_steps,
                 initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """Store the cell configuration.

        Args:
            num_units: width of the cell state and of the hidden state.
            is_training_tensor: predicate handed to ``tf.cond`` to choose
                between batch statistics and population statistics.
            max_bn_steps: maximum number of steps for which to store
                separate population stats; later steps reuse the last slot.
            initial_scale: initial value of every batch-norm scale variable.
            activation: non-linearity applied to the candidate values and to
                the normalized cell state.
            decay: weight kept by the old population statistic when it is
                blended with the current batch statistic.
        """
        self._num_units = num_units
        self._training = is_training_tensor
        self._max_bn_steps = max_bn_steps
        self._activation = activation
        self._decay = decay
        # Honour the caller's value (previously hard-coded to 0.1).
        self._initial_scale = initial_scale

    @property
    def state_size(self):
        """Sizes of the ``(c, h, step)`` state components."""
        return (self._num_units, self._num_units, 1)

    @property
    def output_size(self):
        """Size of the emitted hidden state."""
        return self._num_units

    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False,
                    set_forget_gate_bias=False):
        """Batch-normalize ``x`` using statistics stored for ``step``.

        Assume 2d [batch, values] tensor.

        Args:
            x: tensor to normalize; moments are taken over axis 0.
            name_scope: variable scope holding this normalizer's variables.
            step: scalar step index; clamped to ``max_bn_steps - 1``.
            epsilon: variance epsilon for ``tf.nn.batch_normalization``.
            no_offset: when true, no offset variable is created and a
                constant 0 is used instead.
            set_forget_gate_bias: when true (and ``no_offset`` is false), the
                offset variable is initialized with ``offset_initializer()``.

        Returns:
            The normalized tensor. Batch statistics are used (and the stored
            statistics for ``step`` updated) when the training predicate is
            true; the stored population statistics are used otherwise.
        """
        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]

            scale = tf.get_variable(
                'scale', [size],
                initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                offset = tf.get_variable(
                    'offset', [size], initializer=offset_initializer())
            else:
                # Only reached when no offset variable exists yet; creating
                # 'offset' twice in the same scope would be an error.
                offset = tf.get_variable(
                    'offset', [size], initializer=tf.zeros_initializer)

            # One row of statistics per step, up to max_bn_steps rows.
            pop_mean_all_steps = tf.get_variable(
                'pop_mean', [self._max_bn_steps, size],
                initializer=tf.zeros_initializer, trainable=False)
            pop_var_all_steps = tf.get_variable(
                'pop_var', [self._max_bn_steps, size],
                initializer=tf.ones_initializer(), trainable=False)

            # Steps beyond the stored range share the final row.
            step = tf.minimum(step, self._max_bn_steps - 1)

            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]
            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
                pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
                # Tie the moving-average update to the normalized output so
                # it runs whenever the training branch is evaluated.
                with tf.control_dependencies([pop_mean.assign(pop_mean_new),
                                              pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(
                        x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(
                    x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics, population_statistics)

    def __call__(self, x, state, scope=None):
        """Run one LSTM step.

        Args:
            x: input tensor whose second dimension is statically known.
            state: ``(c, h, step)`` tuple.
            scope: optional variable scope; defaults to the class name.

        Returns:
            ``(new_h, (new_c, new_h, step + 1))``.
        """
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state
            # Reduce the step tensor to a scalar index taken from its first
            # entry along axis 0.
            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))
            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable(
                'W_xh',
                [x_size, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable(
                'W_hh',
                [self._num_units, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())

            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)

            # Only the hidden-to-hidden term carries an offset, so the sum
            # below has a single bias.
            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)

            hidden = bn_xh + bn_hh

            # NOTE(review): (axis, num_split, value) is the pre-1.0 TensorFlow
            # argument order; newer releases expect tf.split(value, 4, 1) -
            # confirm against the targeted TensorFlow version.
            f, i, o, j = tf.split(1, 4, hidden)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            bn_new_c = self._batch_norm(new_c, 'c', _step)

            new_h = self._activation(bn_new_c) * tf.sigmoid(o)

            return new_h, (new_c, new_h, step + 1)
def orthogonal_lstm_initializer():
    """Return an initializer producing a matrix with orthonormal rows or columns."""

    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        """ benanne lasagne ortho init (faster than qr approach)"""
        # NOTE: the links that followed the original "taken from" comments
        # are missing from this copy of the file.
        # Collapse every trailing dimension so the SVD sees a 2-D matrix.
        flat_shape = (shape[0], int(np.prod(shape[1:])))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return tf.constant(q[:shape[0], :shape[1]], dtype)

    return orthogonal
def offset_initializer():
    """Return an initializer that sets the first quarter of a vector to 1.

    The remaining three quarters are 0. The returned callable takes
    ``(shape, dtype, partition_info)`` and reads only ``shape[0]``, which
    must be divisible by 4.
    """

    def _initializer(shape, dtype=tf.float32, partition_info=None):
        total = shape[0]
        # Explicit error instead of ``assert`` so the check survives ``-O``.
        if total % 4 != 0:
            raise ValueError(
                'offset size must be divisible by 4, got {}'.format(total))
        quarter = total // 4
        res = [np.ones(quarter), np.zeros(quarter * 3)]
        return tf.constant(np.concatenate(res, axis=0), dtype)

    return _initializer
Copy link

batzner commented Mar 17, 2018

Thank you! I'm trying to figure out how to integrate Batch Normalization into a regular RNN cell and this was really helpful.

Copy link

@batzner Have you tested this cell's performance? I tried to implement a similar one, but its performance on the training data is very poor.
My code looks like this:

def bn(name,
    return tf.layers.batch_normalization(inputs,
class BNLSTMCell(RNNCell):
      Applying batch normalization in LSTM cell
    def __init__(self,
        self.num_timestamp = 0
        self.num_units = num_units = training
        self.gate_initializer = gate_initializer
        self.bias_initializer = bias_initializer

    def state_size(self):
        :return: (cell_size, hidden_size)
        return (self.num_units, self.num_units)

    def output_size(self):
        :return: output_size
        return self.num_units

    def __call__(self, inputs, state, scope=None):
        :param inputs:  2d tensor with shape [batch_size, feature_size].
        :param state:  Because self.state_size is tuple, so state is tuple with shape
                       [batch_size, s] for s in self.state_size.
        :param scope: scope for create subgraph, default value is class name.
        :return: output: 2d tensor with shape [batch_size, self.output_size]
                 new_state: Next state, [cell_state, hidden_state]
        if scope is None:
            scope = type(self).__name__
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            cell_state, hidden_state = state
            feature_size = inputs.get_shape().as_list()[1]

            W_XH = tf.get_variable('W_XH', shape=[feature_size, 4 * self.num_units],\
            W_HH = tf.get_variable('W_HH', shape=[self.num_units, 4 * self.num_units],\
            bias = tf.get_variable('bias', shape=[4 * self.num_units], \

            hidden = bn('bn_xWXH' + str(self.num_timestamp), tf.matmul(inputs, W_XH), + \
                     bn('bn_hWHH' + str(self.num_timestamp), tf.matmul(hidden_state, W_HH), + \

            #hidden = tf.matmul(inputs, W_XH) + tf.matmul(hidden_state, W_HH) + bias

            forget, input, output, candidate_cell_state = tf.split(hidden, num_or_size_splits=4, axis=1)
            next_cell_state = tf.math.sigmoid(forget) * cell_state + \
                              tf.math.sigmoid(input) * tf.math.tanh(candidate_cell_state)

            bn_next_cell_state = bn('bn_next_cell_state' + str(self.num_timestamp), next_cell_state,

            #bn_next_cell_state = next_cell_state

            next_hidden_state = tf.math.sigmoid(output) * tf.math.tanh(bn_next_cell_state)
            self.num_timestamp = self.num_timestamp + 1
            print('training in __call__ function: {}'.format(

            return next_hidden_state, (next_cell_state, next_hidden_state)

Copy link

Copy link

I have tried this method, but using population statistics (training: False) at test time gives worse results than using batch statistics.
I think BN-LSTM doesn't work.

Copy link

@spitis Maybe I have made some stupid mistake.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment