@spitis
Created February 2, 2017 03:05
Batch normalized LSTM Cell for Tensorflow
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state"""
import tensorflow as tf, numpy as np
RNNCell = tf.nn.rnn_cell.RNNCell
class BNLSTMCell(RNNCell):
'''Batch normalized LSTM as described in arxiv.org/abs/1603.09025'''
def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
"""
* max bn steps is the maximum number of steps for which to store separate population stats
"""
self._num_units = num_units
self._training = is_training_tensor
self._max_bn_steps = max_bn_steps
self._activation = activation
self._decay = decay
self._initial_scale = 0.1
@property
def state_size(self):
return (self._num_units, self._num_units, 1)
@property
def output_size(self):
return self._num_units
    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False):
        '''Assume 2d [batch, values] tensor'''

        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]

            scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                offset = tf.get_variable('offset', [size], initializer=offset_initializer())
            else:
                offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer)

            pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer, trainable=False)
            pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False)

            step = tf.minimum(step, self._max_bn_steps - 1)

            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]

            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
                pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
                with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics, population_statistics)
    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state

            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))

            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                [x_size, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable('W_hh',
                [self._num_units, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())

            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)

            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)

            hidden = bn_xh + bn_hh

            # forget, input, output gate pre-activations and candidate cell state
            f, i, o, j = tf.split(1, 4, hidden)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            bn_new_c = self._batch_norm(new_c, 'c', _step)

            new_h = self._activation(bn_new_c) * tf.sigmoid(o)
            return new_h, (new_c, new_h, step + 1)
def orthogonal_lstm_initializer():
    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        # taken from https://github.com/cooijmanstim/recurrent-batch-normalization
        # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
        """ benanne lasagne ortho init (faster than qr approach)"""
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return tf.constant(q[:shape[0], :shape[1]], dtype)
    return orthogonal

def offset_initializer():
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        size = shape[0]
        assert size % 4 == 0
        size = size // 4
        res = [np.ones((size)), np.zeros((size * 3))]
        return tf.constant(np.concatenate(res, axis=0), dtype)
    return _initializer
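A minimal usage sketch, assuming the TF 1.x graph API (the batch size, sequence length, and placeholder names below are illustrative, not part of the gist). The cell's state is the tuple (c, h, step), so cell.zero_state starts the per-step counter at zero, and is_training_tensor switches between batch and population statistics:

batch_size, seq_len, input_dim, num_units = 32, 100, 128, 256  # example values only

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
is_training = tf.placeholder(tf.bool, [])  # feed True for training, False for evaluation

# max_bn_steps caps how many per-time-step population statistics are stored;
# steps beyond it reuse the statistics of the last stored step.
cell = BNLSTMCell(num_units, is_training_tensor=is_training, max_bn_steps=seq_len)

# zero_state builds the (c, h, step) tuple, so the step counter starts at zero.
initial_state = cell.zero_state(batch_size, tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)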
@batzner commented Mar 17, 2018
Thank you! I'm trying to figure out how to integrate Batch Normalization into a regular RNN cell and this was really helpful.

@liangsun-ponyai
@batzner Have you tried testing this cell's performance? I implemented a similar one, but its performance is very poor even on the training data. My code is as follows:

def bn(name,
       inputs,
       trainable=True,
       training=True,
       reuse=tf.AUTO_REUSE):
    return tf.layers.batch_normalization(inputs,
                                         trainable=trainable,
                                         training=training,
                                         name=name,
                                         reuse=reuse)
class BNLSTMCell(RNNCell):
    '''
      Applying batch normalization in LSTM cell
    '''
    def __init__(self,
                 num_units,
                 training,
                 gate_initializer=tf.initializers.orthogonal(),
                 bias_initializer=tf.initializers.ones()):
        self.num_timestamp = 0
        self.num_units = num_units
        self.training = training
        self.gate_initializer = gate_initializer
        self.bias_initializer = bias_initializer

    @property
    def state_size(self):
        '''
        :return: (cell_size, hidden_size)
        '''
        return (self.num_units, self.num_units)

    @property
    def output_size(self):
        '''
        :return: output_size
        '''
        return self.num_units

    def __call__(self, inputs, state, scope=None):
        '''
        :param inputs:  2d tensor with shape [batch_size, feature_size].
        :param state:  Because self.state_size is tuple, so state is tuple with shape
                       [batch_size, s] for s in self.state_size.
        :param scope: scope for create subgraph, default value is class name.
        :return: output: 2d tensor with shape [batch_size, self.output_size]
                 new_state: Next state, [cell_state, hidden_state]
        '''
        if scope is None:
            scope = type(self).__name__
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            cell_state, hidden_state = state
            feature_size = inputs.get_shape().as_list()[1]

            W_XH = tf.get_variable('W_XH', shape=[feature_size, 4 * self.num_units],\
                                   initializer=self.gate_initializer)
            W_HH = tf.get_variable('W_HH', shape=[self.num_units, 4 * self.num_units],\
                                   initializer=self.gate_initializer)
            bias = tf.get_variable('bias', shape=[4 * self.num_units], \
                                   initializer=self.bias_initializer)

            hidden = bn('bn_xWXH' + str(self.num_timestamp), tf.matmul(inputs, W_XH), training=self.training) + \
                     bn('bn_hWHH' + str(self.num_timestamp), tf.matmul(hidden_state, W_HH), training=self.training) + \
                     bias

            #hidden = tf.matmul(inputs, W_XH) + tf.matmul(hidden_state, W_HH) + bias


            forget, input, output, candidate_cell_state = tf.split(hidden, num_or_size_splits=4, axis=1)
            next_cell_state = tf.math.sigmoid(forget) * cell_state + \
                              tf.math.sigmoid(input) * tf.math.tanh(candidate_cell_state)

            bn_next_cell_state = bn('bn_next_cell_state' + str(self.num_timestamp), next_cell_state, training=self.training)

            #bn_next_cell_state = next_cell_state

            next_hidden_state = tf.math.sigmoid(output) * tf.math.tanh(bn_next_cell_state)
            self.num_timestamp = self.num_timestamp + 1
            print('training in __call__ function: {}'.format(self.training))

            return next_hidden_state, (next_cell_state, next_hidden_state)
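One thing worth double-checking (my guess, not something established in this thread): tf.layers.batch_normalization only updates its moving mean and variance through the ops it adds to tf.GraphKeys.UPDATE_OPS. If those ops are never run, the population statistics stay at their initial values, so evaluation with training=False looks much worse than training even when the model itself is fine. A typical way to wire them into the train op (loss and the learning rate below are placeholders for whatever the model actually uses):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # moving-average update ops from batch_normalization
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

Also note that self.num_timestamp is a Python-side counter: with tf.nn.dynamic_rnn the cell body is built only once inside a while loop, so every time step ends up sharing a single set of BN statistics, rather than the per-step statistics stored in the gist above.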

@liangsun-ponyai
I have tried this method, but using population statistics (training=False) at test time gives worse results than using batch statistics.
I don't think BN-LSTM works.

@liangsun-ponyai
@spitis Maybe I have made some stupid mistake.
