Batch normalized LSTM Cell for Tensorflow
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state"""
import tensorflow as tf, numpy as np
RNNCell = tf.nn.rnn_cell.RNNCell
class BNLSTMCell(RNNCell):
'''Batch normalized LSTM as described in arxiv.org/abs/1603.09025'''
def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
"""
* max bn steps is the maximum number of steps for which to store separate population stats
"""
self._num_units = num_units
self._training = is_training_tensor
self._max_bn_steps = max_bn_steps
self._activation = activation
self._decay = decay
self._initial_scale = 0.1
@property
def state_size(self):
return (self._num_units, self._num_units, 1)
@property
def output_size(self):
return self._num_units
def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False):
'''Assume 2d [batch, values] tensor'''
with tf.variable_scope(name_scope):
size = x.get_shape().as_list()[1]
scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale))
if no_offset:
offset = 0
elif set_forget_gate_bias:
offset = tf.get_variable('offset', [size], initializer=offset_initializer())
else:
offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer)
pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer, trainable=False)
pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False)
step = tf.minimum(step, self._max_bn_steps - 1)
pop_mean = pop_mean_all_steps[step]
pop_var = pop_var_all_steps[step]
batch_mean, batch_var = tf.nn.moments(x, [0])
def batch_statistics():
pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]):
return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)
def population_statistics():
return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)
return tf.cond(self._training, batch_statistics, population_statistics)
def __call__(self, x, state, scope=None):
with tf.variable_scope(scope or type(self).__name__):
c, h, step = state
_step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))
x_size = x.get_shape().as_list()[1]
W_xh = tf.get_variable('W_xh',
[x_size, 4 * self._num_units],
initializer=orthogonal_lstm_initializer())
W_hh = tf.get_variable('W_hh',
[self._num_units, 4 * self._num_units],
initializer=orthogonal_lstm_initializer())
hh = tf.matmul(h, W_hh)
xh = tf.matmul(x, W_xh)
bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)
hidden = bn_xh + bn_hh
f, i, o, j = tf.split(1, 4, hidden)
new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
bn_new_c = self._batch_norm(new_c, 'c', _step)
new_h = self._activation(bn_new_c) * tf.sigmoid(o)
return new_h, (new_c, new_h, step+1)
def orthogonal_lstm_initializer():
def orthogonal(shape, dtype=tf.float32, partition_info=None):
# taken from https://github.com/cooijmanstim/recurrent-batch-normalization
# taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
""" benanne lasagne ortho init (faster than qr approach)"""
flat_shape = (shape[0], np.prod(shape[1:]))
a = np.random.normal(0.0, 1.0, flat_shape)
u, _, v = np.linalg.svd(a, full_matrices=False)
q = u if u.shape == flat_shape else v # pick the one with the correct shape
q = q.reshape(shape)
return tf.constant(q[:shape[0], :shape[1]], dtype)
return orthogonal
def offset_initializer():
def _initializer(shape, dtype=tf.float32, partition_info=None):
size = shape[0]
assert size % 4 == 0
size = size // 4
res = [np.ones((size)), np.zeros((size*3))]
return tf.constant(np.concatenate(res, axis=0), dtype)
return _initializer
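
A minimal usage sketch (not part of the gist; the sizes, placeholder names, and tf.nn.dynamic_rnn wiring are assumptions): the cell's state carries a third component, the step counter that selects which row of per-time-step population statistics to use.

# Usage sketch, assuming TF 1.x graph mode and made-up shapes.
batch_size, max_steps, input_dim, num_units = 32, 100, 50, 128

inputs = tf.placeholder(tf.float32, [batch_size, max_steps, input_dim])
is_training = tf.placeholder(tf.bool, [])  # scalar bool: True while training, False at test time

cell = BNLSTMCell(num_units, is_training_tensor=is_training, max_bn_steps=max_steps)

# state_size is (num_units, num_units, 1): cell state, hidden state, step counter starting at 0
initial_state = (
    tf.zeros([batch_size, num_units]),  # c
    tf.zeros([batch_size, num_units]),  # h
    tf.zeros([batch_size, 1]),          # step counter
)

outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

Feeding is_training=True runs the branch that updates the per-step population statistics; feeding False normalizes with the stored statistics instead.
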
@batzner commented Mar 17, 2018

Thank you! I'm trying to figure out how to integrate Batch Normalization into a regular RNN cell and this was really helpful.

@liangsun-ponyai commented Apr 1, 2020

@batzner Have you tested this cell's performance? I have tried to implement a similar one, but its performance even on the training data is very poor.
My code looks like this:

def bn(name,
       inputs,
       trainable=True,
       training=True,
       reuse=tf.AUTO_REUSE):
    return tf.layers.batch_normalization(inputs,
                                         trainable=trainable,
                                         training=training,
                                         name=name,
                                         reuse=reuse)
class BNLSTMCell(RNNCell):
    '''
      Applying batch normalization in LSTM cell
    '''
    def __init__(self,
                 num_units,
                 training,
                 gate_initializer=tf.initializers.orthogonal(),
                 bias_initializer=tf.initializers.ones()):
        self.num_timestamp = 0
        self.num_units = num_units
        self.training = training
        self.gate_initializer = gate_initializer
        self.bias_initializer = bias_initializer

    @property
    def state_size(self):
        '''
        :return: (cell_size, hidden_size)
        '''
        return (self.num_units, self.num_units)

    @property
    def output_size(self):
        '''
        :return: output_size
        '''
        return self.num_units

    def __call__(self, inputs, state, scope=None):
        '''
        :param inputs:  2d tensor with shape [batch_size, feature_size].
        :param state:  Because self.state_size is tuple, so state is tuple with shape
                       [batch_size, s] for s in self.state_size.
        :param scope: scope for create subgraph, default value is class name.
        :return: output: 2d tensor with shape [batch_size, self.output_size]
                 new_state: Next state, [cell_state, hidden_state]
        '''
        if scope is None:
            scope = type(self).__name__
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            cell_state, hidden_state = state
            feature_size = inputs.get_shape().as_list()[1]

            W_XH = tf.get_variable('W_XH', shape=[feature_size, 4 * self.num_units],\
                                   initializer=self.gate_initializer)
            W_HH = tf.get_variable('W_HH', shape=[self.num_units, 4 * self.num_units],\
                                   initializer=self.gate_initializer)
            bias = tf.get_variable('bias', shape=[4 * self.num_units], \
                                   initializer=self.bias_initializer)

            hidden = bn('bn_xWXH' + str(self.num_timestamp), tf.matmul(inputs, W_XH), training=self.training) + \
                     bn('bn_hWHH' + str(self.num_timestamp), tf.matmul(hidden_state, W_HH), training=self.training) + \
                     bias

            #hidden = tf.matmul(inputs, W_XH) + tf.matmul(hidden_state, W_HH) + bias


            forget, input, output, candidate_cell_state = tf.split(hidden, num_or_size_splits=4, axis=1)
            next_cell_state = tf.math.sigmoid(forget) * cell_state + \
                              tf.math.sigmoid(input) * tf.math.tanh(candidate_cell_state)

            bn_next_cell_state = bn('bn_next_cell_state' + str(self.num_timestamp), next_cell_state, training=self.training)

            #bn_next_cell_state = next_cell_state

            next_hidden_state = tf.math.sigmoid(output) * tf.math.tanh(bn_next_cell_state)
            self.num_timestamp = self.num_timestamp + 1
            print('training in __call__ function: {}'.format(self.training))

            return next_hidden_state, (next_cell_state, next_hidden_state)
@liangsun-ponyai commented Apr 3, 2020

I have tried this method, but using population statistics (training=False) at test time gives worse results than using batch statistics.
I think BN-LSTM doesn't work.
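
One possible cause, sketched below as an assumption rather than a confirmed diagnosis: tf.layers.batch_normalization only updates its moving mean and variance when the ops it adds to tf.GraphKeys.UPDATE_OPS are actually run, so if the train op does not depend on them, the population statistics stay at their initial values and inference with training=False looks much worse than training.

# Sketch, assuming a TF 1.x training graph with a hypothetical `loss` tensor already defined.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # moving-average updates from tf.layers.batch_normalization
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)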

@liangsun-ponyai commented Apr 3, 2020

@spitis Maybe I have made some stupid mistake.
