Created
July 6, 2019 08:57
-
-
Save eliorc/6ac98485b0606045f2412a587165176a to your computer and use it in GitHub Desktop.
LayerNormalization test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tavolo.normalization import LayerNormalization | |
def test_shapes():
    """ Test that the layer's output shape equals its input shape (2D and 3D) """
    shape_2d = (56, 10)
    shape_3d = (56, 10, 30)

    # Random inputs, one per rank under test
    batch_2d = tf.random.normal(shape=shape_2d)
    batch_3d = tf.random.normal(shape=shape_3d)

    # A separate layer instance per input rank
    norm_2d = LayerNormalization(name='layer_norm_2d')
    norm_3d = LayerNormalization(name='layer_norm_3d')

    normalized_2d = norm_2d(batch_2d)
    normalized_3d = norm_3d(batch_3d)

    # Normalization must not change the shape of what it is given
    assert normalized_2d.shape == shape_2d
    assert normalized_3d.shape == shape_3d
def test_masking():
    """ Test that the layer accepts the output of a Masking layer """
    shape = (56, 10, 30)
    data = tf.random.normal(shape=shape)

    # Sum over the last axis (keeping dims) to get one total per row, giving a
    # (56, 10, 1) tensor; rows whose total is negative are the ones zeroed out
    inner_sum = tf.reduce_sum(data, axis=-1, keepdims=True)
    outer_sum = tf.reduce_sum(inner_sum, axis=-1, keepdims=True)
    negative_rows = tf.less(outer_sum, 0)
    full_mask = tf.broadcast_to(negative_rows, shape)
    zeroed = tf.where(full_mask, tf.zeros_like(data), data)

    # Masking layer feeding the layer under test
    masking = tf.keras.layers.Masking(mask_value=0., input_shape=shape[1:])
    norm = LayerNormalization(name='layer_norm_3d')

    output = norm(masking(zeroed))

    # Only the output shape is checked here
    assert output.shape == shape
def test_logic():
    """ Test the layer's output on a known, constant (all-ones) input """
    shape = (56, 10)
    ones = tf.ones(shape=shape)
    norm = LayerNormalization(name='layer_norm_2d')

    total = tf.reduce_sum(norm(ones)).numpy()

    # The normalized all-ones input is expected to sum to exactly zero
    assert total == 0
def test_serialization():
    """ Test that a layer rebuilt via from_config reports the same config """
    original = LayerNormalization()
    original_config = original.get_config()

    rebuilt = LayerNormalization.from_config(original_config)

    # Round trip through get_config / from_config must preserve the config
    assert rebuilt.get_config() == original.get_config()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment