Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Batch-normalized LSTM cell for TensorFlow
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state"""
import tensorflow as tf, numpy as np
# Short alias for the TensorFlow RNN cell base class that BNLSTMCell subclasses.
RNNCell = tf.nn.rnn_cell.RNNCell
class BNLSTMCell(RNNCell):
    '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025

    Batch normalization is applied separately to the input-to-hidden
    projection, to the hidden-to-hidden projection and to the new cell
    state. Population mean/variance are kept in one row per time step, for
    up to ``max_bn_steps`` steps; every later step shares the last row.

    The recurrent state is a ``(c, h, step)`` tuple. ``step`` has width 1
    and is incremented by one on every call, so the cell knows which row of
    population statistics to use.
    '''

    def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """
        * num_units is the width of the cell state and of the output
        * is_training_tensor is the predicate given to tf.cond: when true the
          batch statistics are used (and the population statistics updated),
          otherwise the stored population statistics are used
        * max bn steps is the maximum number of steps for which to store separate population stats
        * initial_scale is the initial value of every batch-norm scale variable
        * activation is applied to the candidate values and to the normalized cell state
        * decay is the weight kept on the old population statistic when it is
          updated from a new batch
        """
        self._num_units = num_units
        self._training = is_training_tensor
        self._max_bn_steps = max_bn_steps
        self._activation = activation
        self._decay = decay
        # Use the constructor argument (default 0.1) instead of a fixed
        # constant so that callers can actually tune the initial scale.
        self._initial_scale = initial_scale

    @property
    def state_size(self):
        '''Sizes of the (c, h, step) state components.'''
        return (self._num_units, self._num_units, 1)

    @property
    def output_size(self):
        '''Width of the output ``h``.'''
        return self._num_units

    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False):
        '''Batch-normalize ``x`` with the statistics belonging to ``step``.

        Assume 2d [batch, values] tensor

        * name_scope is the variable scope holding this normalization's
          scale, offset and population statistics
        * step is the index of the population-statistics row to use; it is
          clamped to ``max_bn_steps - 1``
        * epsilon is forwarded to tf.nn.batch_normalization
        * no_offset uses a constant offset of 0 instead of a variable
        * set_forget_gate_bias creates the offset with offset_initializer()
          instead of zeros (only consulted when no_offset is false)

        Returns the normalized tensor.
        '''
        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]
            scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                offset = tf.get_variable('offset', [size], initializer=offset_initializer())
            else:
                offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer)

            # One row of population statistics per time step; not trained by
            # the optimizer, only updated through the assigns below.
            pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer, trainable=False)
            pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False)

            # Steps beyond the stored range reuse the last row.
            step = tf.minimum(step, self._max_bn_steps - 1)
            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]
            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                # Moving-average update of this step's population statistics,
                # forced to run before the normalized output is produced.
                pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
                pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
                # NOTE(review): relies on the sliced variable exposing
                # .assign -- confirm for the TensorFlow version in use.
                with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics, population_statistics)

    def __call__(self, x, state, scope=None):
        '''Run the cell for one time step.

        * x is the input; its second dimension must be statically known
        * state is the (c, h, step) tuple from the previous step
        * scope is the variable scope name; defaults to the class name

        Returns ``(new_h, (new_c, new_h, step + 1))``.
        '''
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state
            # Take the counter from the first row of `step` and reduce it to
            # an int32 index (assumes every row carries the same step --
            # TODO confirm against the caller).
            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))
            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                                   [x_size, 4 * self._num_units],
                                   initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable('W_hh',
                                   [self._num_units, 4 * self._num_units],
                                   initializer=orthogonal_lstm_initializer())
            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)
            # Only the hidden-to-hidden term carries an offset, so the summed
            # pre-activation has a single bias.
            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)
            hidden = bn_xh + bn_hh
            # NOTE(review): the positional order here is (axis, num_splits,
            # value), i.e. the pre-1.0 tf.split signature -- confirm against
            # the TensorFlow version in use.
            f, i, o, j = tf.split(1, 4, hidden)
            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            bn_new_c = self._batch_norm(new_c, 'c', _step)
            new_h = self._activation(bn_new_c) * tf.sigmoid(o)
            # The un-normalized cell state is what gets carried forward.
            return new_h, (new_c, new_h, step+1)
def orthogonal_lstm_initializer():
    """Return an initializer function that produces an orthogonal matrix."""
    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        """Orthogonal init via SVD of a Gaussian matrix.

        benanne lasagne ortho init (faster than qr approach); the trailing
        dimensions of ``shape`` are flattened for the decomposition.
        ``partition_info`` is accepted but not used.
        """
        # taken from https://github.com/cooijmanstim/recurrent-batch-normalization
        # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
        flat_shape = (shape[0], np.prod(shape[1:]))
        gaussian = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(gaussian, full_matrices=False)
        # Exactly one of the two SVD factors has the flattened shape
        # (either one when it is square); prefer u.
        if u.shape == flat_shape:
            basis = u
        else:
            basis = v
        basis = basis.reshape(shape)
        return tf.constant(basis[:shape[0], :shape[1]], dtype)
    return orthogonal
def offset_initializer():
    """Return an initializer for a 1-D offset split into four equal parts.

    The first quarter is filled with ones and the remaining three quarters
    with zeros.
    """
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        total = shape[0]
        # The vector must divide evenly into the four parts.
        assert total % 4 == 0
        quarter = total // 4
        values = np.concatenate([np.ones(quarter), np.zeros(3 * quarter)], axis=0)
        return tf.constant(values, dtype)
    return _initializer
@batzner

This comment has been minimized.

Copy link

batzner commented Mar 17, 2018

Thank you! I'm trying to figure out how to integrate Batch Normalization into a regular RNN cell and this was really helpful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.