Skip to content

Instantly share code, notes, and snippets.

@Cospel
Created January 26, 2020 07:57
Show Gist options
  • Save Cospel/c895fcd717da7260d76c5d6397cada94 to your computer and use it in GitHub Desktop.
Save Cospel/c895fcd717da7260d76c5d6397cada94 to your computer and use it in GitHub Desktop.
batch_norm.py
import tensorflow as tf
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """Drop-in BatchNormalization whose momentum is always 0.9.

    The module assigns this class over ``tf.keras.layers.BatchNormalization``
    so that models built from that namespace get this variant instead of the
    stock layer.

    Args:
        momentum: Accepted only so the constructor signature matches the
            parent layer. Whatever the caller passes is ignored; the parent
            is always constructed with ``momentum=0.9``.
        name: Optional layer name, forwarded to the parent constructor.
        **kwargs: All other parent-layer arguments, forwarded unchanged.
    """

    def __init__(self, momentum=0.9, name=None, **kwargs):
        # Deliberately discard the caller's momentum and force 0.9.
        super().__init__(name=name, momentum=0.9, **kwargs)

    def call(self, inputs, training=None):
        """Delegate the forward pass to the parent layer unchanged."""
        parent_call = super().call
        return parent_call(inputs=inputs, training=training)

    def get_config(self):
        """Return the parent layer's serialization config unchanged."""
        return super().get_config()
# Redirect every later ``tf.keras.layers.BatchNormalization`` lookup to the
# fixed-momentum ``BatchNormalization`` subclass defined in this module.
# This rebinds a module attribute, so it is a process-wide monkeypatch.
tf.keras.layers.BatchNormalization = BatchNormalization


def build_base_model(input_shape=(224, 224, 3)):
    """Build a headless ImageNet MobileNetV2 using the patched BatchNormalization.

    Args:
        input_shape: Image shape ``(height, width, channels)`` passed to
            ``MobileNetV2``. The default, ``(224, 224, 3)``, is the input
            resolution MobileNetV2 uses when no shape is specified.

    Returns:
        The ``tf.keras`` MobileNetV2 model built with ``include_top=False``,
        i.e. without its classification head.
    """
    # ``layers=tf.keras.layers`` makes MobileNetV2 build its layers from the
    # namespace patched above, so its BatchNormalization layers are the
    # fixed-momentum subclass rather than the stock Keras layer.
    return tf.keras.applications.MobileNetV2(
        weights="imagenet",
        input_shape=input_shape,
        include_top=False,
        layers=tf.keras.layers,
    )


# The original snippet passed ``input_shape=self.shape`` here, but there is
# no ``self`` at module scope, so importing the file raised NameError.
# NOTE(review): pass the input shape the surrounding code actually needs.
base_model = build_base_model()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment