LayerNormalization
import tensorflow as tf


class LayerNormalization(tf.keras.layers.Layer):
    """
    Apply layer normalization

    Arguments
    ---------

    - `epsilon` (``float``): Small number to avoid division by zero
    - `name` (``str``): Layer name

    Input shape
    -----------

    Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

    Output shape
    ------------

    Same shape as input.

    Examples
    --------

    .. code-block:: python3

        import tensorflow as tf
        import tavolo as tvl

        model = tf.keras.Sequential([SomeLayer(),
                                     tvl.normalization.LayerNormalization()])  # Apply layer normalization on SomeLayer's output

    References
    ----------
    `Layer Normalization`_

    .. _Layer Normalization:
        https://arxiv.org/pdf/1607.06450
    """

    def __init__(self, epsilon: float = 1e-8,
                 name: str = 'layer_normalization',
                 **kwargs):
        """
        :param epsilon: Small number to avoid division by zero
        :param name: Layer name
        """

        super().__init__(name=name, **kwargs)
        self.epsilon = epsilon
        self.beta, self.gamma = None, None

    def build(self, input_shape):
        # One shift (beta) and one scale (gamma) parameter per channel (last axis)
        params_shape = input_shape[-1:]

        # Initialize beta and gamma (add_weight replaces the deprecated add_variable)
        self.beta = self.add_weight('beta',
                                    shape=params_shape,
                                    initializer=tf.keras.initializers.Zeros(),
                                    dtype=self.dtype)
        self.gamma = self.add_weight('gamma',
                                     shape=params_shape,
                                     initializer=tf.keras.initializers.Ones(),
                                     dtype=self.dtype)

        super().build(input_shape)

    def compute_mask(self, inputs, mask=None):
        # Propagate the incoming mask unchanged
        return mask

    def call(self, inputs,
             **kwargs) -> tf.Tensor:
        # Calculate mean and variance over the last axis
        mean, variance = tf.nn.moments(inputs, axes=[-1], keepdims=True)  # shape=(..., 1)

        # Normalize
        normalized = (inputs - mean) / ((variance + self.epsilon) ** .5)  # same shape as inputs

        return self.gamma * normalized + self.beta  # same shape as inputs

    def get_config(self):
        base_config = super().get_config()
        base_config['epsilon'] = self.epsilon

        return base_config

    @classmethod
    def from_config(cls, config: dict):
        return cls(**config)
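
A quick sanity check (not part of the original gist, shapes are illustrative): with the freshly initialized gamma = 1 and beta = 0, every position of the output should have roughly zero mean and unit variance over the channel axis, and the shape is unchanged.

    import tensorflow as tf

    x = tf.random.normal((4, 10, 8))  # illustrative batch: 4 sequences, 10 time steps, 8 channels

    layer_norm = LayerNormalization(epsilon=1e-8)
    y = layer_norm(x)

    # Per-position statistics over the last axis of the normalized output
    mean, variance = tf.nn.moments(y, axes=[-1], keepdims=True)
    print(tf.reduce_max(tf.abs(mean)).numpy())          # ~0
    print(tf.reduce_max(tf.abs(variance - 1)).numpy())  # ~0
    print(y.shape)                                      # (4, 10, 8), same as the input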
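
Because epsilon is written into the base config by get_config, the layer survives a config round trip; a minimal sketch (the epsilon and name values here are made up for illustration):

    original = LayerNormalization(epsilon=1e-6, name='ln')
    config = original.get_config()  # includes 'epsilon' alongside the base Layer config

    restored = LayerNormalization.from_config(config)
    print(restored.epsilon)  # 1e-06
    print(restored.name)     # ln

The same idea applies when loading a saved model that uses this layer: pass the class through custom_objects, e.g. tf.keras.models.load_model(path, custom_objects={'LayerNormalization': LayerNormalization}).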