Skip to content

Instantly share code, notes, and snippets.

Last active July 6, 2019 08:33
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save eliorc/7095070fb371a41eb3151d4cf73b25d2 to your computer and use it in GitHub Desktop.
import tensorflow as tf
class LayerNormalization(tf.keras.layers.Layer):
    """
    Apply layer normalization

    Each input vector along the last axis is shifted and scaled to zero mean
    and unit variance, then passed through a learned per-feature affine
    transform (``gamma * normalized + beta``).

    Arguments
    ---------

    - `epsilon` (``float``): Small number to avoid division by zero
    - `name` (``str``): Layer name


    Input shape
    -----------

    Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when
    using this layer as the first layer in a model.


    Output shape
    ------------

    Same shape as input.


    Examples
    --------

    .. code-block:: python3

        import tensorflow as tf
        import tavolo as tvl

        model = tf.keras.Sequential([SomeLayer(),
                                     tvl.normalization.LayerNormalization()])  # Apply layer normalization on SomeLayer's output


    References
    ----------

    `Layer Normalization`_ (Ba, Kiros and Hinton, 2016)

    .. _Layer Normalization:
        https://arxiv.org/abs/1607.06450
    """

    def __init__(self, epsilon: float = 1e-8,
                 name: str = 'layer_normalization',
                 **kwargs):
        """
        :param epsilon: Small number to avoid division by zero
        :param name: Layer name
        :param kwargs: Extra keyword arguments forwarded to ``tf.keras.layers.Layer``
        """
        super().__init__(name=name, **kwargs)

        self.epsilon = epsilon
        # Trainable parameters; created lazily in ``build`` once the size of
        # the last input axis is known.
        self.beta, self.gamma = None, None

    def build(self, input_shape):
        """
        Create the trainable offset (``beta``) and scale (``gamma``) weights.

        :param input_shape: Shape of the input; only the last axis is used
        """
        # One parameter per feature of the last axis
        params_shape = input_shape[-1:]

        # Initialize beta and gamma. With beta=0 and gamma=1 the affine
        # transform starts out as the identity, so the layer initially
        # performs plain normalization.
        self.beta = self.add_weight(name='beta',
                                    shape=params_shape,
                                    initializer='zeros')
        self.gamma = self.add_weight(name='gamma',
                                     shape=params_shape,
                                     initializer='ones')

        super().build(input_shape)

    def compute_mask(self, inputs, mask=None):
        """
        Pass the incoming mask through unchanged.

        :param inputs: Layer inputs (unused)
        :param mask: Incoming mask, if any
        :return: The same ``mask`` that was passed in
        """
        return mask

    def call(self, inputs,
             **kwargs) -> tf.Tensor:
        """
        Normalize ``inputs`` over the last axis and apply the affine transform.

        :param inputs: Tensor to normalize
        :param kwargs: Unused; accepted for compatibility with the Keras ``call`` signature
        :return: Tensor with the same shape as ``inputs``
        """
        # Calculate mean and variance over the last axis; keepdims=True keeps
        # that axis with size 1 so both broadcast against ``inputs``
        mean, variance = tf.nn.moments(inputs, axes=[-1], keepdims=True)

        # Normalize; epsilon keeps the denominator away from zero
        normalized = (inputs - mean) / ((variance + self.epsilon) ** .5)

        # gamma and beta broadcast along the last axis
        return self.gamma * normalized + self.beta

    def get_config(self):
        """
        Return the layer configuration, including ``epsilon``.

        :return: Configuration dictionary of the base layer extended with ``epsilon``
        """
        base_config = super().get_config()
        base_config['epsilon'] = self.epsilon

        return base_config

    @classmethod
    def from_config(cls, config: dict):
        """
        Build a layer from a configuration dictionary.

        :param config: Dictionary whose items are passed as keyword arguments to the constructor
        :return: New instance of this layer class
        """
        return cls(**config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment