Batch normalized LSTM Cell for TensorFlow
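The cell follows the recurrent batch normalization scheme of Cooijmans et al. (arxiv.org/abs/1603.09025). Writing BN(z) = γ ⊙ (z − μ̂) / √(σ̂² + ε) + β for batch normalization with learned scale γ and offset β, each time step computes

    (f, i, o, j) = BN(W_hh · h_{t-1}) + BN(W_xh · x_t)
    c_t = σ(f) ⊙ c_{t-1} + σ(i) ⊙ tanh(j)
    h_t = σ(o) ⊙ tanh(BN(c_t))

The input-to-hidden normalization carries no offset, and the forget-gate slice of the hidden-to-hidden offset is initialized to one in place of a separate bias term. Because activation statistics differ across early time steps, separate population statistics (μ̂, σ̂²) are stored for each of the first max_bn_steps steps; later steps reuse the last slot.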
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state""" | |
import tensorflow as tf, numpy as np | |
RNNCell = tf.nn.rnn_cell.RNNCell | |

class BNLSTMCell(RNNCell):
    '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025'''

    def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """
        * is_training_tensor is a boolean tensor: batch statistics are used (and population
          statistics updated) when it is True; population statistics are used when it is False
        * max_bn_steps is the maximum number of steps for which to store separate population stats
        """
        self._num_units = num_units
        self._training = is_training_tensor
        self._max_bn_steps = max_bn_steps
        self._activation = activation
        self._decay = decay
        self._initial_scale = initial_scale

    @property
    def state_size(self):
        return (self._num_units, self._num_units, 1)  # (c, h, per-example step counter)

    @property
    def output_size(self):
        return self._num_units

    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False):
        '''Assume 2d [batch, values] tensor'''
        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]

            scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                # initialize the forget-gate quarter of the offset to 1, standing in for the usual bias term
                offset = tf.get_variable('offset', [size], initializer=offset_initializer())
            else:
                offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer)

            # separate population statistics per time step; steps past max_bn_steps share the last slot
            pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer, trainable=False)
            pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False)

            step = tf.minimum(step, self._max_bn_steps - 1)
            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]

            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                # update this step's population stats with an exponential moving average,
                # then normalize with the batch stats
                pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
                pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
                with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics, population_statistics)

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state
            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))  # scalar step index for this batch
            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                [x_size, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable('W_hh',
                [self._num_units, 4 * self._num_units],
                initializer=orthogonal_lstm_initializer())

            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)

            # normalize the two streams separately; only the hidden-to-hidden stream carries an offset
            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)

            hidden = bn_xh + bn_hh
            f, i, o, j = tf.split(1, 4, hidden)  # TF 0.12 signature: split(axis, num_splits, value)

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            bn_new_c = self._batch_norm(new_c, 'c', _step)
            new_h = self._activation(bn_new_c) * tf.sigmoid(o)

            return new_h, (new_c, new_h, step + 1)

def orthogonal_lstm_initializer():
    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        # taken from https://github.com/cooijmanstim/recurrent-batch-normalization
        # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
        """benanne lasagne ortho init (faster than qr approach)"""
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return tf.constant(q[:shape[0], :shape[1]], dtype)
    return orthogonal

def offset_initializer():
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        size = shape[0]
        assert size % 4 == 0
        size = size // 4
        # ones for the forget-gate quarter, zeros for the rest (gate order is f, i, o, j)
        res = [np.ones(size), np.zeros(size * 3)]
        return tf.constant(np.concatenate(res, axis=0), dtype)
    return _initializer
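
For reference, a minimal usage sketch follows. The shapes, placeholder names, and the choice of tf.nn.dynamic_rnn are illustrative assumptions on top of the gist, using the same TensorFlow 0.12-era API as the cell itself. The only non-obvious part is the state: it is a (c, h, step) tuple, so the initial state must include a step counter starting at zero.

import numpy as np
import tensorflow as tf

batch_size, seq_len, input_dim, num_units = 32, 50, 64, 128

is_training = tf.placeholder(tf.bool, [])  # feed True during training, False at test time
inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])

cell = BNLSTMCell(num_units, is_training, max_bn_steps=seq_len)

# state is (c, h, step); the per-example step counter starts at zero
initial_state = (
    tf.zeros([batch_size, num_units]),  # c
    tf.zeros([batch_size, num_units]),  # h
    tf.zeros([batch_size, 1]),          # step counter, cast to int32 inside the cell
)

outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)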