Skip to content

Instantly share code, notes, and snippets.

@wwiiiii
Created February 19, 2018 04:17
Show Gist options
  • Save wwiiiii/67491cb54a37a8a5a7eaf626d71526e0 to your computer and use it in GitHub Desktop.
Save wwiiiii/67491cb54a37a8a5a7eaf626d71526e0 to your computer and use it in GitHub Desktop.
def _batch_normalization(inp, is_train, name=None, is_conv=True):
    """Batch-normalize ``inp`` with a learned per-channel scale and offset.

    When ``is_train`` evaluates to true, the moments of the current batch are
    used for normalization and are also folded into an exponential moving
    average (decay 0.99). Otherwise the moving-average moments are used.

    Args:
        inp: Tensor to normalize. Its last dimension determines the size of
            the ``gamma`` (scale) and ``beta`` (offset) variables.
        is_train: Predicate handed to ``tf.cond`` that selects between batch
            moments and moving-average moments.
            NOTE(review): presumably a scalar ``tf.bool`` tensor or
            placeholder -- confirm against callers.
        name: Suffix used in the variable names ``gamma_<name>`` and
            ``beta_<name>``. If ``None``, the current timestamp is used, so
            the variable names differ on every run; pass an explicit name
            when stable variable names are needed.
        is_conv: If true, moments are reduced over axes ``[0, 1, 2]``;
            otherwise over axis ``[0]`` only.
            NOTE(review): this looks like it assumes channels-last 4-D
            input for the conv case -- confirm.

    Returns:
        The output of ``tf.nn.batch_normalization`` applied to ``inp`` with
        the selected mean/variance, ``beta`` as offset, ``gamma`` as scale
        and a variance epsilon of 0.001.
    """
    reduce_axes = [0, 1, 2] if is_conv else [0]
    now_mean, now_var = tf.nn.moments(inp, axes=reduce_axes)

    if name is None:
        name = str(time.time())

    channels = inp.shape[-1]
    # Explicit initializers: without them tf.get_variable falls back to the
    # scope/default initializer, so the layer would not start out as the
    # usual scale=1 / offset=0 transform.
    gamma = tf.get_variable(
        'gamma_%s' % name,
        shape=[channels],
        initializer=tf.ones_initializer(),
    )
    beta = tf.get_variable(
        'beta_%s' % name,
        shape=[channels],
        initializer=tf.zeros_initializer(),
    )

    ema = tf.train.ExponentialMovingAverage(decay=0.99)

    def update():
        # Force the moving-average update to run before the batch moments
        # are handed out, so using them in training also refreshes the EMA.
        with tf.control_dependencies([ema.apply([now_mean, now_var])]):
            return tf.identity(now_mean), tf.identity(now_var)

    # NOTE(review): the inference branch reads ema.average(...), which only
    # resolves once ema.apply has been called on these tensors; this appears
    # to rely on tf.cond building the `update` branch first -- keep the
    # branch order as is.
    mean, var = tf.cond(
        is_train,
        update,
        lambda: (ema.average(now_mean), ema.average(now_var)),
    )
    return tf.nn.batch_normalization(inp, mean, var, beta, gamma, 0.001)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment