@blorgblerg
Last active September 21, 2017 14:47
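
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.Variable)


def batch_normify(batch, depth=1):
    """Returns a batch-normalized copy of `batch`.

    Inputs:
    -batch: a batch tensor whose first axis (axis 0) is the batch axis
    -depth: the size of the last (feature) axis; one beta and one gamma
     are learned per feature

    Note: this always uses the current batch's statistics; it keeps no
    moving averages for use at inference time."""
    with tf.variable_scope('bn'):
        # Learned shift (beta) and scale (gamma), one value per feature.
        beta = tf.Variable(tf.constant(0.0, shape=[depth]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[depth]),
                            name='gamma', trainable=True)
        # Mean and (biased) variance across axis 0, the batch axis.
        # For NHWC conv activations, axes [0, 1, 2] give per-channel statistics.
        batch_mean, batch_var = tf.nn.moments(batch, [0], name='moments')
        batch_normed = tf.nn.batch_normalization(
            batch,
            batch_mean,
            batch_var,
            beta,    # offset
            gamma,   # scale
            0.0001,  # variance_epsilon, keeps the division well-defined
            name='batch_normification'
        )
    return batch_normed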
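For reference, the arithmetic that the function delegates to tf.nn.moments and tf.nn.batch_normalization can be sketched in plain Python for a 2-D batch (rows are examples, columns are features). This is an illustrative sketch, not the gist's code: batch_norm_reference is a hypothetical name, and it assumes a list-of-lists input and the same 0.0001 epsilon used above.

```python
import math


def batch_norm_reference(batch, beta, gamma, epsilon=0.0001):
    """Hypothetical helper: normalize a 2-D batch across axis 0, per column.

    Computes y = gamma * (x - mean) / sqrt(var + epsilon) + beta, where mean
    and var are taken over the batch axis, matching tf.nn.moments(batch, [0])
    followed by tf.nn.batch_normalization.
    """
    n = len(batch)
    depth = len(batch[0])
    # Per-feature mean over the batch axis (axis 0).
    mean = [sum(row[j] for row in batch) / n for j in range(depth)]
    # Population (biased) variance: divide by n, as tf.nn.moments does.
    var = [sum((row[j] - mean[j]) ** 2 for row in batch) / n
           for j in range(depth)]
    # Normalize, then apply the learned scale (gamma) and shift (beta).
    return [
        [gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + epsilon) + beta[j]
         for j in range(depth)]
        for row in batch
    ]
```

For example, batch_norm_reference([[1.0, 10.0], [3.0, 30.0]], [0.0, 0.0], [1.0, 1.0]) maps each column to roughly -1 and +1, since the columns have means 2 and 20 and variances 1 and 100. With gamma = 1 and beta = 0 (the initial values in batch_normify), every normalized column has mean 0 and variance close to 1.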