@Luonic
Created June 15, 2018 09:57
Group normalization for TensorFlow with an adaptive group count and the 'channels_last' (NHWC) data format
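Here, "adaptive group count" means that the function lowers num_groups until it divides the channel count evenly. The sketch below isolates that rule in a hypothetical helper, adaptive_num_groups, which is not part of the gist. It also clamps the count to at least one group, so the modulo can never divide by zero.

```python
def adaptive_num_groups(channels, num_groups):
    """Largest group count <= num_groups that divides channels evenly.

    Hypothetical helper mirroring the gist's while-loop; the clamp to
    [1, channels] is an added safeguard against num_groups <= 0.
    """
    num_groups = max(1, min(num_groups, channels))
    while channels % num_groups != 0:
        num_groups -= 1
    return num_groups


# Requested group count of 32 applied to a few channel counts.
for c in (64, 48, 40, 30, 3):
    print(c, adaptive_num_groups(c, 32))
```

For example, a request for 32 groups stays at 32 for 64 channels. It becomes 24 for 48 channels and 20 for 40 channels.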
import tensorflow as tf

def group_normalization(input_tensor, num_groups, gamma=1.0, beta=0.0, epsilon=1e-5):
    """Group-normalize a 4-D NHWC ('channels_last') tensor."""
    channels_int = input_tensor.get_shape().as_list()[3]
    # Clamp to [1, channels], then shrink until num_groups divides the channel count.
    num_groups = max(1, min(num_groups, channels_int))
    while channels_int % num_groups != 0:
        num_groups -= 1
    # Dynamic batch/height/width so that unknown (None) dimensions still reshape.
    shape = tf.shape(input_tensor)
    grouped = tf.reshape(input_tensor, [shape[0], shape[1], shape[2], num_groups, channels_int // num_groups])
    # Per-sample, per-group statistics over H, W and the channels within each group.
    mean, var = tf.nn.moments(grouped, [1, 2, 4], keepdims=True)
    grouped = (grouped - mean) / tf.sqrt(var + epsilon)
    output = tf.reshape(grouped, shape)
    return output * gamma + beta
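The following framework-independent NumPy sketch of the per-sample, per-group computation can serve as a cross-check. It is not part of the original gist, and the name group_norm_reference is hypothetical. The sketch assumes that the group count already divides the channel count. It also assumes that each group is a contiguous block of channels, the standard grouping from the Group Normalization paper.

```python
import numpy as np


def group_norm_reference(x, num_groups, gamma=1.0, beta=0.0, epsilon=1e-5):
    """NumPy sketch of group normalization for NHWC arrays.

    Assumes c % num_groups == 0 and contiguous channel blocks per group.
    """
    n, h, w, c = x.shape
    grouped = x.reshape(n, h, w, num_groups, c // num_groups)
    # Statistics per sample and per group: reduce over H, W and channels-in-group.
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    var = grouped.var(axis=(1, 2, 4), keepdims=True)
    normalized = (grouped - mean) / np.sqrt(var + epsilon)
    return normalized.reshape(n, h, w, c) * gamma + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 4, 8))
y = group_norm_reference(x, num_groups=4)
# Each (sample, group) slice should now have mean close to 0 and variance close to 1.
g = y.reshape(2, 4, 4, 4, 2)
print(np.abs(g.mean(axis=(1, 2, 4))).max(), g.var(axis=(1, 2, 4)).mean())
```

With the default gamma=1.0 and beta=0.0, the output keeps the input shape. Each group's statistics are normalized independently of the batch size, which is the property that distinguishes group normalization from batch normalization.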