LayernormSimpleRNN for TensorFlow
# requirements.txt (referenced by the run script below)
# syntax check, unit test, profiling
flake8>=3.7.8
pytest>=5.3.1
# public packages
numpy>=1.18.0
tensorflow>=2.0.0
# private packages
#-e git+git@github.com:ulf1/pkgname.git#egg=pkgname
# Jupyter
jupyterlab>=1.2.4
#!/bin/bash
python3 -m venv .venv
source .venv/bin/activate
pip3 install --upgrade pip
pip3 install -r requirements.txt
python3 script5.py
flake8 --ignore=E111,E114 script5.py
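# Note (editor's sketch, not in the original gist): E111 and E114 are the
# pycodestyle checks "indentation is not a multiple of four" (for code and
# comments, respectively); they are ignored because this gist's code uses a
# smaller indent width.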
# script5.py (the file run by the script above)
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin, RNN
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers.recurrent import _generate_zero_filled_state_for_cell
# _maybe_reset_cell_dropout_mask, _caching_device
from tensorflow.keras.layers import LayerNormalization # NEW(!)
from tensorflow.keras.models import Sequential
from tensorflow.python.platform import tf_logging as logging  # needed by the deprecation warning in LayernormSimpleRNN
import numpy as np
#@keras_export('keras.experimental.LayernormSimpleRNNCell')
class LayernormSimpleRNNCell(DropoutRNNCellMixin, Layer):
"""Cell class for LayernormSimpleRNN.
Motivation:
- Drop-in replacement for keras.layers.SimpleRNNCell
- demonstrate how to add LayerNormalization to all RNNs as an option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of the RNN API.
This class processes one step within the whole time sequence input, whereas
`tf.keras.layers.LayernormSimpleRNN` processes the whole sequence.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer uses a bias vector.
use_layernorm: Boolean (default `False`), whether the layer uses layer
normalization instead of a bias vector.
layernorm_epsilon: Float (default `1e-5`), a small float added to the
variance to avoid division by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the linear
transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor with shape `[batch, feature]`.
states: A 2D tensor with shape `[batch, units]`, which is the state from
the previous time step. For timestep 0, the initial state provided by the
user will be fed to the cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
**kwargs):
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
super(LayernormSimpleRNNCell, self).__init__(**kwargs) # TMP(!)
self.units = units
self.activation = activations.get(activation)
self.use_bias = False if use_layernorm else use_bias # NEW(!)
self.use_layernorm = use_layernorm # NEW(!)
self.layernorm_epsilon = layernorm_epsilon # NEW(!)
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.gamma_initializer = initializers.get(gamma_initializer) # NEW(!)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer) # NEW(!)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.gamma_constraint = constraints.get(gamma_constraint) # NEW(!)
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_size = self.units
self.output_size = self.units
#@tf_utils.shape_type_conversion
def build(self, input_shape):
#default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
#caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
#caching_device=default_caching_device)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
#caching_device=default_caching_device)
else:
self.bias = None
if self.use_layernorm: # vvv NEW(!)
self.layernorm = LayerNormalization(
axis=-1, center=True, scale=True, trainable=True,
name='layernorm',
epsilon=self.layernorm_epsilon,
beta_initializer=self.bias_initializer,
gamma_initializer=self.gamma_initializer,
beta_regularizer=self.bias_regularizer,
gamma_regularizer=self.gamma_regularizer,
beta_constraint=self.bias_constraint,
gamma_constraint=self.gamma_constraint)
else:
self.layernorm = None # ^^^ NEW(!)
self.built = True
def call(self, inputs, states, training=None):
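# Editor's note: one step of the layer-normalized Elman recurrence,
#   output_t = activation(LN(inputs_t @ kernel + bias
#                            + prev_output @ recurrent_kernel))
# where LN is applied only if use_layernorm=True, the bias is added only
# if use_bias=True, and output_t doubles as the next hidden state.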
prev_output = states[0]
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)
if dp_mask is not None:
h = K.dot(inputs * dp_mask, self.kernel)
else:
h = K.dot(inputs, self.kernel)
if self.bias is not None:
h = K.bias_add(h, self.bias)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.layernorm is not None: # NEW(!)
output = self.layernorm(output) # NEW(!)
if self.activation is not None:
output = self.activation(output)
return output, [output]
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'use_layernorm':
self.use_layernorm, # NEW(!)
'layernorm_epsilon':
self.layernorm_epsilon, # NEW(!)
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'gamma_initializer':
initializers.serialize(self.gamma_initializer), # NEW(!)
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'gamma_regularizer':
regularizers.serialize(self.gamma_regularizer), # NEW(!)
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'gamma_constraint':
constraints.serialize(self.gamma_constraint), # NEW(!)
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout
}
base_config = super(LayernormSimpleRNNCell, self).get_config() # TMP(!)
return dict(list(base_config.items()) + list(config.items()))
#@keras_export('keras.experimental.LayernormSimpleRNN')
class LayernormSimpleRNN(RNN):
"""Fully-connected RNN where the output is to be fed back to input.
Motivation:
- Drop-in replacement for keras.layers.SimpleRNN
- demonstrate how to add LayerNormalization to all RNNs as an option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of the RNN API.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer uses a bias vector.
use_layernorm: Boolean (default `False`), whether the layer uses layer
normalization instead of a bias vector.
layernorm_epsilon: Float (default `1e-5`), a small float added to the
variance to avoid division by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the inputs.
Default: 0.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state
in addition to the output. Default: `False`.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed up an RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
model = tf.keras.layers.LayernormSimpleRNN(4, use_layernorm=True)
output = model(inputs) # The output has shape `[32, 4]`.
model = tf.keras.layers.LayernormSimpleRNN(
4, use_layernorm=True, return_sequences=True, return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = model(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
if 'implementation' in kwargs:
kwargs.pop('implementation')
logging.warning('The `implementation` argument '
'in `LayernormSimpleRNN` has been deprecated. ' # TMP(!)
'Please remove it from your layer call.')
cell = LayernormSimpleRNNCell( # TMP(!)
units,
activation=activation,
use_bias=use_bias,
use_layernorm=use_layernorm, # NEW(!)
layernorm_epsilon=layernorm_epsilon, # NEW(!)
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
gamma_initializer=gamma_initializer, # NEW(!)
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer, # NEW(!)
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
gamma_constraint=gamma_constraint, # NEW(!)
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
super(LayernormSimpleRNN, self).__init__( # TMP(!)
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
def call(self, inputs, mask=None, training=None, initial_state=None):
#self._maybe_reset_cell_dropout_mask(self.cell)
return super(LayernormSimpleRNN, self).call( # TMP(!)
inputs, mask=mask, training=training, initial_state=initial_state)
@property
def units(self):
return self.cell.units
@property
def activation(self):
return self.cell.activation
@property
def use_bias(self):
return self.cell.use_bias
@property
def use_layernorm(self):
return self.cell.use_layernorm # NEW(!)
@property
def layernorm_epsilon(self):
return self.cell.layernorm_epsilon # NEW(!)
@property
def kernel_initializer(self):
return self.cell.kernel_initializer
@property
def recurrent_initializer(self):
return self.cell.recurrent_initializer
@property
def bias_initializer(self):
return self.cell.bias_initializer
@property
def gamma_initializer(self):
return self.cell.gamma_initializer # NEW(!)
@property
def kernel_regularizer(self):
return self.cell.kernel_regularizer
@property
def recurrent_regularizer(self):
return self.cell.recurrent_regularizer
@property
def bias_regularizer(self):
return self.cell.bias_regularizer
@property
def gamma_regularizer(self):
return self.cell.gamma_regularizer # NEW(!)
@property
def kernel_constraint(self):
return self.cell.kernel_constraint
@property
def recurrent_constraint(self):
return self.cell.recurrent_constraint
@property
def bias_constraint(self):
return self.cell.bias_constraint
@property
def gamma_constraint(self):
return self.cell.gamma_constraint # NEW(!)
@property
def dropout(self):
return self.cell.dropout
@property
def recurrent_dropout(self):
return self.cell.recurrent_dropout
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'use_layernorm':
self.use_layernorm, # NEW(!)
'layernorm_epsilon':
self.layernorm_epsilon, # NEW(!)
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'gamma_initializer':
initializers.serialize(self.gamma_initializer), # NEW(!)
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'gamma_regularizer':
regularizers.serialize(self.gamma_regularizer), # NEW(!)
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'gamma_constraint':
constraints.serialize(self.gamma_constraint), # NEW(!)
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout
}
base_config = super(LayernormSimpleRNN, self).get_config() # TMP(!)
del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
if 'implementation' in config:
config.pop('implementation')
return cls(**config)
# set model parameters
num_samples = 2
timesteps = 3
embedding_dim = 4
units = 2
# create datasets
x = np.random.random((num_samples, timesteps, embedding_dim))
y = np.random.random((num_samples, units))
# modeling
model = Sequential([
LayernormSimpleRNN(
units, use_layernorm=True,
input_shape=(None, embedding_dim))
])
model.compile('rmsprop', 'mse')
# training
model.fit(x, y, verbose=1)
model.summary()  # summary() prints itself and returns None
cfg = model.get_config()
print(cfg)
#print(model.from_config(cfg))
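# Editor's sketch (not part of the original gist): the cell class can also be
# wrapped in the generic tf.keras.layers.RNN layer, mirroring the docstring
# example above; `tf` is imported here only for this check.
import tensorflow as tf
rnn = tf.keras.layers.RNN(LayernormSimpleRNNCell(units, use_layernorm=True))
cell_output = rnn(x.astype(np.float32))
print(cell_output.shape)  # (num_samples, units)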
# https://github.com/tensorflow/tensorflow/pull/35469#issuecomment-570977586
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin, SimpleRNN
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers.recurrent import _generate_zero_filled_state_for_cell
# _maybe_reset_cell_dropout_mask, _caching_device
from tensorflow.keras.layers import LayerNormalization # NEW(!)
from tensorflow.keras.models import Sequential
from tensorflow.python.platform import tf_logging as logging  # needed by the deprecation warning in LayernormSimpleRNN
import numpy as np
#@keras_export('keras.experimental.LayernormSimpleRNNCell')
class LayernormSimpleRNNCell(DropoutRNNCellMixin, Layer):
"""Cell class for LayernormSimpleRNN.
Motivation:
- Drop-in replacement for keras.layers.SimpleRNNCell
- demonstrate how to add LayerNormalization to all RNNs as an option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of the RNN API.
This class processes one step within the whole time sequence input, whereas
`tf.keras.layers.LayernormSimpleRNN` processes the whole sequence.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer uses a bias vector.
use_layernorm: Boolean (default `False`), whether the layer uses layer
normalization instead of a bias vector.
layernorm_epsilon: Float (default `1e-5`), a small float added to the
variance to avoid division by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the linear
transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor with shape `[batch, feature]`.
states: A 2D tensor with shape `[batch, units]`, which is the state from
the previous time step. For timestep 0, the initial state provided by the
user will be fed to the cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
**kwargs):
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
super(LayernormSimpleRNNCell, self).__init__(**kwargs) # TMP(!)
self.units = units
self.activation = activations.get(activation)
self.use_bias = False if use_layernorm else use_bias # NEW(!)
self.use_layernorm = use_layernorm # NEW(!)
self.layernorm_epsilon = layernorm_epsilon # NEW(!)
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.gamma_initializer = initializers.get(gamma_initializer) # NEW(!)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer) # NEW(!)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.gamma_constraint = constraints.get(gamma_constraint) # NEW(!)
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_size = self.units
self.output_size = self.units
#@tf_utils.shape_type_conversion
def build(self, input_shape):
#default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
#caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
#caching_device=default_caching_device)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
#caching_device=default_caching_device)
else:
self.bias = None
if self.use_layernorm: # vvv NEW(!)
self.layernorm = LayerNormalization(
axis=-1, center=True, scale=True, trainable=True,
name='layernorm',
epsilon=self.layernorm_epsilon,
beta_initializer=self.bias_initializer,
gamma_initializer=self.gamma_initializer,
beta_regularizer=self.bias_regularizer,
gamma_regularizer=self.gamma_regularizer,
beta_constraint=self.bias_constraint,
gamma_constraint=self.gamma_constraint)
else:
self.layernorm = None # ^^^ NEW(!)
self.built = True
def call(self, inputs, states, training=None):
prev_output = states[0]
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)
if dp_mask is not None:
h = K.dot(inputs * dp_mask, self.kernel)
else:
h = K.dot(inputs, self.kernel)
if self.bias is not None:
h = K.bias_add(h, self.bias)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.layernorm is not None: # NEW(!)
output = self.layernorm(output) # NEW(!)
if self.activation is not None:
output = self.activation(output)
return output, [output]
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'use_layernorm':
self.use_layernorm, # NEW(!)
'layernorm_epsilon':
self.layernorm_epsilon, # NEW(!)
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'gamma_initializer':
initializers.serialize(self.gamma_initializer), # NEW(!)
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'gamma_regularizer':
regularizers.serialize(self.gamma_regularizer), # NEW(!)
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'gamma_constraint':
constraints.serialize(self.gamma_constraint), # NEW(!)
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout
}
base_config = super(LayernormSimpleRNNCell, self).get_config() # TMP(!)
return dict(list(base_config.items()) + list(config.items()))
#@keras_export('keras.experimental.LayernormSimpleRNN')
class LayernormSimpleRNN(SimpleRNN):
"""Fully-connected RNN where the output is to be fed back to input.
Motivation:
- Drop-in replacement for keras.layers.SimpleRNN
- demonstrate how to add LayerNormalization to all RNNs as an option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of the RNN API.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer uses a bias vector.
use_layernorm: Boolean (default `False`), whether the layer uses layer
normalization instead of a bias vector.
layernorm_epsilon: Float (default `1e-5`), a small float added to the
variance to avoid division by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the inputs.
Default: 0.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state
in addition to the output. Default: `False`.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed up an RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
model = tf.keras.layers.LayernormSimpleRNN(4, use_layernorm=True)
output = model(inputs) # The output has shape `[32, 4]`.
model = tf.keras.layers.LayernormSimpleRNN(
4, use_layernorm=True, return_sequences=True, return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = model(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
if 'implementation' in kwargs:
kwargs.pop('implementation')
logging.warning('The `implementation` argument '
'in `LayernormSimpleRNN` has been deprecated. ' # TMP(!)
'Please remove it from your layer call.')
cell = LayernormSimpleRNNCell( # TMP(!)
units,
activation=activation,
use_bias=use_bias,
use_layernorm=use_layernorm, # NEW(!)
layernorm_epsilon=layernorm_epsilon, # NEW(!)
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
gamma_initializer=gamma_initializer, # NEW(!)
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer, # NEW(!)
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
gamma_constraint=gamma_constraint, # NEW(!)
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
super(SimpleRNN, self).__init__( # TMP(!)
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
#def call(self, inputs, mask=None, training=None, initial_state=None):
#self._maybe_reset_cell_dropout_mask(self.cell)
# return super(LayernormSimpleRNN, self).call( # TMP(!)
# inputs, mask=mask, training=training, initial_state=initial_state)
@property
def use_layernorm(self):
return self.cell.use_layernorm # NEW(!)
@property
def layernorm_epsilon(self):
return self.cell.layernorm_epsilon # NEW(!)
@property
def gamma_initializer(self):
return self.cell.gamma_initializer # NEW(!)
@property
def gamma_regularizer(self):
return self.cell.gamma_regularizer # NEW(!)
@property
def gamma_constraint(self):
return self.cell.gamma_constraint # NEW(!)
def get_config(self):
base_config = super(SimpleRNN, self).get_config()
del base_config['cell']
cell_config = self.cell.get_config()
return dict(list(base_config.items()) + list(cell_config.items()))
# set model parameters
num_samples = 2
timesteps = 3
embedding_dim = 4
units = 2
# create datasets
x = np.random.random((num_samples, timesteps, embedding_dim))
y = np.random.random((num_samples, units))
# modeling
model = Sequential([
LayernormSimpleRNN(
units, use_layernorm=True,
input_shape=(None, embedding_dim))
])
model.compile('rmsprop', 'mse')
# training
model.fit(x, y, verbose=1)
model.summary()  # summary() prints itself and returns None
cfg = model.get_config()
print(cfg)
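# Editor's sketch (not part of the original gist): with use_layernorm=True the
# cell drops its bias vector, and the layernorm beta/gamma weights appear
# among the trainable weights instead.
for w in model.trainable_weights:
  print(w.name, w.shape)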
# https://github.com/tensorflow/tensorflow/pull/35469#issuecomment-570977586
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin, SimpleRNN, SimpleRNNCell
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers.recurrent import _generate_zero_filled_state_for_cell
# _maybe_reset_cell_dropout_mask, _caching_device
from tensorflow.keras.layers import LayerNormalization # NEW(!)
from tensorflow.keras.models import Sequential
from tensorflow.python.platform import tf_logging as logging  # needed by the deprecation warning in LayernormSimpleRNN
import numpy as np
#@keras_export('keras.experimental.LayernormSimpleRNNCell')
class LayernormSimpleRNNCell(SimpleRNNCell): # Simple inheritance
"""Cell class for LayernormSimpleRNN.
Motivation:
- Drop-in replacement for keras.layers.SimpleRNNCell
- demonstrate how to add LayerNormalization to all RNNs as an option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of the RNN API.
This class processes one step within the whole time sequence input, whereas
`tf.keras.layers.LayernormSimpleRNN` processes the whole sequence.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer uses a bias vector.
use_layernorm: Boolean (default `False`), whether the layer uses layer
normalization instead of a bias vector.
layernorm_epsilon: Float (default `1e-5`), a small float added to the
variance to avoid division by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the linear
transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor with shape `[batch, feature]`.
states: A 2D tensor with shape `[batch, units]`, which is the state from
the previous time step. For timestep 0, the initial state provided by the
user will be fed to the cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
**kwargs):
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
self.use_layernorm = use_layernorm # NEW(!)
super(LayernormSimpleRNNCell, self).__init__(
units,
activation=activation,
use_bias=False if use_layernorm else use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=None if use_layernorm else bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=None if use_layernorm else bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=None if use_layernorm else bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
if self.use_layernorm: # vvv NEW(!)
self.layernorm = LayerNormalization(
axis=-1,
epsilon=layernorm_epsilon,
center=True,
scale=True,
beta_initializer=bias_initializer,
gamma_initializer=gamma_initializer,
beta_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=bias_constraint,
gamma_constraint=gamma_constraint,
trainable=kwargs.get('trainable', True),
name='layernorm')
else:
self.layernorm = None # ^^^ NEW(!)
#@tf_utils.shape_type_conversion
def build(self, input_shape):
#default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
#caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
#caching_device=default_caching_device)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
#caching_device=default_caching_device)
else:
self.bias = None
self.built = True
def call(self, inputs, states, training=None):
prev_output = states[0]
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)
if dp_mask is not None:
h = K.dot(inputs * dp_mask, self.kernel)
else:
h = K.dot(inputs, self.kernel)
if self.bias is not None:
h = K.bias_add(h, self.bias)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.layernorm is not None: # NEW(!)
output = self.layernorm(output) # NEW(!)
if self.activation is not None:
output = self.activation(output)
return output, [output]
# def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
# return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
def get_config(self):
config = {
'use_layernorm':
self.use_layernorm, # NEW(!)
}
cell_config = super(LayernormSimpleRNNCell, self).get_config() # from SimpleRNNCell
if self.use_layernorm:
ln_config = self.layernorm.get_config()
ln_config['bias_initializer'] = ln_config.pop("beta_initializer")
ln_config['bias_regularizer'] = ln_config.pop("beta_regularizer")
ln_config['bias_constraint'] = ln_config.pop("beta_constraint")
ln_config['layernorm_epsilon'] = ln_config.pop("epsilon")
del ln_config['axis']
del ln_config['center']
del ln_config['scale']
# also drop the sublayer's base-Layer keys so they do not overwrite the
# cell's own 'name', 'trainable' and 'dtype' entries in the merge below
ln_config.pop('name', None)
ln_config.pop('trainable', None)
ln_config.pop('dtype', None)
else:
ln_config = {}
return dict(list(config.items()) + list(cell_config.items()) + list(ln_config.items()))
#@keras_export('keras.experimental.LayernormSimpleRNN')
class LayernormSimpleRNN(SimpleRNN):
"""Fully-connected RNN where the output is to be fed back to input.
Motivation:
- Drop-in replacement for keras.layers.SimpleRNN
- demonstrate how to add LayerNormalization to all RNNs as an option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of the RNN API.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(i.e. "linear" activation: `a(x) = x`).
use_bias: Boolean (default `True`), whether the layer uses a bias vector.
use_layernorm: Boolean (default `False`), whether the layer uses layer
normalization instead of a bias vector.
layernorm_epsilon: Float (default `1e-5`), a small float added to the
variance to avoid division by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the inputs.
Default: 0.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state
in addition to the output. Default: `False`.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed up an RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
model = tf.keras.layers.LayernormSimpleRNN(4, use_layernorm=True)
output = model(inputs) # The output has shape `[32, 4]`.
model = tf.keras.layers.LayernormSimpleRNN(
4, use_layernorm=True, return_sequences=True, return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = model(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
if 'implementation' in kwargs:
kwargs.pop('implementation')
logging.warning('The `implementation` argument '
'in `LayernormSimpleRNN` has been deprecated. ' # TMP(!)
'Please remove it from your layer call.')
cell = LayernormSimpleRNNCell( # TMP(!)
units,
activation=activation,
use_bias=use_bias,
use_layernorm=use_layernorm, # NEW(!)
layernorm_epsilon=layernorm_epsilon, # NEW(!)
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
gamma_initializer=gamma_initializer, # NEW(!)
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer, # NEW(!)
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
gamma_constraint=gamma_constraint, # NEW(!)
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
super(SimpleRNN, self).__init__( # init RNN class
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
#def call(self, inputs, mask=None, training=None, initial_state=None):
#self._maybe_reset_cell_dropout_mask(self.cell)
# return super(LayernormSimpleRNN, self).call( # TMP(!)
# inputs, mask=mask, training=training, initial_state=initial_state)
@property
def use_layernorm(self):
return self.cell.use_layernorm # NEW(!)
@property
def layernorm_epsilon(self):
# this cell variant no longer stores these values itself; read them off
# the layernorm sublayer (only available when use_layernorm=True)
return self.cell.layernorm.epsilon # NEW(!)
@property
def gamma_initializer(self):
return self.cell.layernorm.gamma_initializer # NEW(!)
@property
def gamma_regularizer(self):
return self.cell.layernorm.gamma_regularizer # NEW(!)
@property
def gamma_constraint(self):
return self.cell.layernorm.gamma_constraint # NEW(!)
def get_config(self):
rnn_config = super(SimpleRNN, self).get_config() # from RNN class
del rnn_config['cell']
cell_config = self.cell.get_config()
return dict(list(rnn_config.items()) + list(cell_config.items()))
# set model parameters
num_samples = 2
timesteps = 3
embedding_dim = 4
units = 2
# create datasets
x = np.random.random((num_samples, timesteps, embedding_dim))
y = np.random.random((num_samples, units))
# modeling
model = Sequential([
LayernormSimpleRNN(
units, use_layernorm=True,
input_shape=(None, embedding_dim))
])
model.compile('rmsprop', 'mse')
# training
model.fit(x, y, verbose=1)
model.summary()  # summary() prints itself and returns None
cfg = model.get_config()
print(cfg)
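# Editor's sketch (not part of the original gist): with use_layernorm=False
# this cell falls back to the plain SimpleRNNCell behaviour.
plain_cell = LayernormSimpleRNNCell(units, use_layernorm=False)
print(plain_cell.layernorm)  # None -> the standard bias path is used
print(plain_cell.use_bias)   # True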
# https://github.com/tensorflow/tensorflow/pull/35469#issuecomment-570977586
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin, SimpleRNN, SimpleRNNCell
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers.recurrent import _generate_zero_filled_state_for_cell
# _maybe_reset_cell_dropout_mask, _caching_device
from tensorflow.keras.layers import LayerNormalization # NEW(!)
from tensorflow.keras.models import Sequential
import numpy as np
#@keras_export('keras.experimental.LayernormSimpleRNNCell')
class LayernormSimpleRNNCell(SimpleRNNCell): # Simple inheritance
"""Cell class for LayernormSimpleRNN.
Motivation:
- Drop-In Replacement for keras.layers.SimpleRNNCell
- demonstrate how to add LayerNormalization to all RNNs as option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
This class processes one step within the whole time sequence input, whereas
  `tf.keras.layers.LayernormSimpleRNN` processes the whole sequence.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
    use_layernorm: Boolean, (default `False`), whether the layer uses layer
      normalization instead of a bias vector.
layernorm_epsilon: Float, (default `1e-5`), Small float added to variance
to avoid dividing by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the linear
transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor, with shape of `[batch, feature]`.
    states: A 2D tensor with shape of `[batch, units]`, which is the state
      from the previous time step. For timestep 0, the initial state
      provided by the user will be fed to the cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
**kwargs):
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
self.use_layernorm = use_layernorm # NEW(!)
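    # NOTE: when layer normalization is enabled, its beta vector takes over
    # the role of the bias term, so the cell's own bias is switched off and
    # the bias_* arguments are rerouted to the layernorm sublayer below.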
SimpleRNNCell.__init__(
self,
units,
activation=activation,
use_bias=False if use_layernorm else use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=None if use_layernorm else bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=None if use_layernorm else bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=None if use_layernorm else bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
if self.use_layernorm: # vvv NEW(!)
self.layernorm = LayerNormalization(
axis=-1,
epsilon=layernorm_epsilon,
center=True,
scale=True,
beta_initializer=bias_initializer,
gamma_initializer=gamma_initializer,
beta_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=bias_constraint,
gamma_constraint=gamma_constraint,
trainable=kwargs.get('trainable', True),
name='layernorm')
else:
self.layernorm = None # ^^^ NEW(!)
#@tf_utils.shape_type_conversion
def build(self, input_shape):
#default_caching_device = _caching_device(self)
SimpleRNNCell.build(self, input_shape)
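    # The layernorm sublayer (if any) builds itself lazily on its first
    # call; it is tracked by Keras because it was assigned as an attribute.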
def call(self, inputs, states, training=None):
prev_output = states[0]
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)
if dp_mask is not None:
h = K.dot(inputs * dp_mask, self.kernel)
else:
h = K.dot(inputs, self.kernel)
if self.bias is not None:
h = K.bias_add(h, self.bias)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + K.dot(prev_output, self.recurrent_kernel)
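    # Layer normalization (when enabled) is applied to this pre-activation
    # sum, as in Ba et al. (2016); the nonlinearity follows afterwards.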
if self.layernorm is not None: # NEW(!)
output = self.layernorm(output) # NEW(!)
if self.activation is not None:
output = self.activation(output)
return output, [output]
def get_config(self):
config = {
'use_layernorm':
self.use_layernorm, # NEW(!)
}
cell_config = SimpleRNNCell.get_config(self)
if self.use_layernorm:
ln_config = self.layernorm.get_config()
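      # Rename the layernorm's beta_* keys to bias_* so that the merged
      # config matches this cell's constructor argument names.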
      ln_config['bias_initializer'] = ln_config.pop('beta_initializer')
      ln_config['bias_regularizer'] = ln_config.pop('beta_regularizer')
      ln_config['bias_constraint'] = ln_config.pop('beta_constraint')
      ln_config['layernorm_epsilon'] = ln_config.pop('epsilon')
      del ln_config['axis']
      del ln_config['center']
      del ln_config['scale']
    else:
      ln_config = {}
    return {**config, **cell_config, **ln_config}
#@keras_export('keras.experimental.LayernormSimpleRNN')
class LayernormSimpleRNN(SimpleRNN):
"""Fully-connected RNN where the output is to be fed back to input.
Motivation:
- Drop-In Replacement for keras.layers.SimpleRNN
- demonstrate how to add LayerNormalization to all RNNs as option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
    use_layernorm: Boolean, (default `False`), whether the layer uses layer
      normalization instead of a bias vector.
layernorm_epsilon: Float, (default `1e-5`), Small float added to variance
to avoid dividing by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the inputs.
Default: 0.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
    return_state: Boolean. Whether to return the last state
      in addition to the output. Default: `False`.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
      Unrolling can speed up an RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
model = tf.keras.layers.LayernormSimpleRNN(4, use_layernorm=True)
output = model(inputs) # The output has shape `[32, 4]`.
model = tf.keras.layers.LayernormSimpleRNN(
4, use_layernorm=True, return_sequences=True, return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = model(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
if 'implementation' in kwargs:
kwargs.pop('implementation')
logging.warning('The `implementation` argument '
'in `LayernormSimpleRNN` has been deprecated. ' # TMP(!)
'Please remove it from your layer call.')
cell = LayernormSimpleRNNCell( # TMP(!)
units,
activation=activation,
use_bias=use_bias,
use_layernorm=use_layernorm, # NEW(!)
layernorm_epsilon=layernorm_epsilon, # NEW(!)
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
gamma_initializer=gamma_initializer, # NEW(!)
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer, # NEW(!)
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
gamma_constraint=gamma_constraint, # NEW(!)
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
super(SimpleRNN, self).__init__( # init RNN class
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
#def call(self, inputs, mask=None, training=None, initial_state=None):
#self._maybe_reset_cell_dropout_mask(self.cell)
# return super(LayernormSimpleRNN, self).call( # TMP(!)
# inputs, mask=mask, training=training, initial_state=initial_state)
@property
def use_layernorm(self):
return self.cell.use_layernorm # NEW(!)
@property
def layernorm_epsilon(self):
return self.cell.layernorm_epsilon # NEW(!)
@property
def gamma_initializer(self):
return self.cell.gamma_initializer # NEW(!)
@property
def gamma_regularizer(self):
return self.cell.gamma_regularizer # NEW(!)
@property
def gamma_constraint(self):
return self.cell.gamma_constraint # NEW(!)
def get_config(self):
rnn_config = super(SimpleRNN, self).get_config() # from RNN class
del rnn_config['cell']
cell_config = self.cell.get_config()
return dict(list(rnn_config.items()) + list(cell_config.items()))
# set model parameters
num_samples = 2
timesteps = 3
embedding_dim = 4
units = 2
# create datasets
x = np.random.random((num_samples, timesteps, embedding_dim))
y = np.random.random((num_samples, units))
# modeling
model = Sequential([
LayernormSimpleRNN(
units, use_layernorm=True,
input_shape=(None, embedding_dim))
])
model.compile('rmsprop', 'mse')
# training
model.fit(x, y, verbose=1)
model.summary()  # summary() already prints; wrapping it in print() adds a stray 'None'
cfg = model.get_config()
print(cfg)
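# Optional sanity check (not part of the original gist): with the default
# return_sequences=False the model maps inputs of shape
# (num_samples, timesteps, embedding_dim) to outputs of shape
# (num_samples, units).
y_hat = model.predict(x)
assert y_hat.shape == (num_samples, units)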
# https://github.com/tensorflow/tensorflow/pull/35469#issuecomment-570977586
from tensorflow.python.keras.layers.recurrent import (
    DropoutRNNCellMixin, SimpleRNN, SimpleRNNCell)
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.layers.recurrent import _generate_zero_filled_state_for_cell
# _maybe_reset_cell_dropout_mask, _caching_device
from tensorflow.keras.layers import LayerNormalization # NEW(!)
from tensorflow.keras.models import Sequential
import numpy as np
#@keras_export('keras.experimental.LayernormSimpleRNNCell')
class LayernormSimpleRNNCell(SimpleRNNCell, LayerNormalization):
"""Cell class for LayernormSimpleRNN.
Motivation:
- Drop-In Replacement for keras.layers.SimpleRNNCell
- demonstrate how to add LayerNormalization to all RNNs as option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
This class processes one step within the whole time sequence input, whereas
  `tf.keras.layers.LayernormSimpleRNN` processes the whole sequence.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
    use_layernorm: Boolean, (default `False`), whether the layer uses layer
      normalization instead of a bias vector.
layernorm_epsilon: Float, (default `1e-5`), Small float added to variance
to avoid dividing by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent
state. Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the
linear transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor, with shape of `[batch, feature]`.
    states: A 2D tensor with shape of `[batch, units]`, which is the state
      from the previous time step. For timestep 0, the initial state
      provided by the user will be fed to the cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = tf.keras.layers.RNN(
tf.keras.layers.LayernormSimpleRNNCell(4, use_layernorm=True),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
**kwargs):
self.use_layernorm = use_layernorm
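    # CAVEAT (not part of the original gist): inheriting from two Layer
    # subclasses means LayerNormalization.__init__ runs Layer.__init__ a
    # second time, overwriting attributes (e.g. name, dtype) that
    # SimpleRNNCell.__init__ already set. This variant is an experiment,
    # not a recommended design; prefer the composition approach above.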
SimpleRNNCell.__init__(
self,
units,
activation=activation,
use_bias=False if use_layernorm else use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=None if use_layernorm else bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=None if use_layernorm else bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=None if use_layernorm else bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
if use_layernorm:
LayerNormalization.__init__(
self,
axis=-1,
epsilon=layernorm_epsilon,
center=True,
scale=True,
beta_initializer=bias_initializer,
gamma_initializer=gamma_initializer,
beta_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=bias_constraint,
gamma_constraint=gamma_constraint,
trainable=kwargs.get('trainable', True))
#@tf_utils.shape_type_conversion
def build(self, input_shape):
#default_caching_device = _caching_device(self)
SimpleRNNCell.build(self, input_shape)
if self.use_layernorm:
LayerNormalization.build(self, (None, self.units))
def call(self, inputs, states, training=None):
prev_output = states[0]
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)
if dp_mask is not None:
h = K.dot(inputs * dp_mask, self.kernel)
else:
h = K.dot(inputs, self.kernel)
if self.bias is not None:
h = K.bias_add(h, self.bias)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.use_layernorm: # NEW(!)
output = LayerNormalization.call(self, output) # NEW(!)
if self.activation is not None:
output = self.activation(output)
return output, [output]
def get_config(self):
config = {
'use_layernorm':
self.use_layernorm
}
cell_config = SimpleRNNCell.get_config(self)
if self.use_layernorm:
ln_config = LayerNormalization.get_config(self)
      ln_config['bias_initializer'] = ln_config.pop('beta_initializer')
      ln_config['bias_regularizer'] = ln_config.pop('beta_regularizer')
      ln_config['bias_constraint'] = ln_config.pop('beta_constraint')
      ln_config['layernorm_epsilon'] = ln_config.pop('epsilon')
del ln_config['axis']
del ln_config['center']
del ln_config['scale']
else:
ln_config = {}
return {**config, **cell_config, **ln_config}
#@keras_export('keras.experimental.LayernormSimpleRNN')
class LayernormSimpleRNN(SimpleRNN):
"""Fully-connected RNN where the output is to be fed back to input.
Motivation:
- Drop-In Replacement for keras.layers.SimpleRNN
- demonstrate how to add LayerNormalization to all RNNs as option
- see Ba et al. (2016), and tf.keras.layers.LayerNormalization
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
Arguments:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
    use_layernorm: Boolean, (default `False`), whether the layer uses layer
      normalization instead of a bias vector.
layernorm_epsilon: Float, (default `1e-5`), Small float added to variance
to avoid dividing by zero.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent
state. Default: `orthogonal`.
bias_initializer: Initializer for the bias vector (`use_bias=True`) or
for the beta vector in layer normalization (`use_layernorm=True`).
Default: `zeros`.
gamma_initializer: Initializer for the gamma vector of the layer
normalization layer (`use_layernorm=True`). Default: `ones`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_regularizer: Regularizer function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector
(`use_bias=True`) or for the beta vector of the layer normalization
layer (`use_layernorm=True`). Default: `None`.
gamma_constraint: Constraint function applied to the gamma vector
of the layer normalization layer (`use_layernorm=True`).
Default: `None`.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
inputs. Default: 0.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
    return_state: Boolean. Whether to return the last state
      in addition to the output. Default: `False`.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
      Unrolling can speed up an RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
model = tf.keras.layers.LayernormSimpleRNN(4, use_layernorm=True)
output = model(inputs) # The output has shape `[32, 4]`.
model = tf.keras.layers.LayernormSimpleRNN(
4, use_layernorm=True, return_sequences=True, return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = model(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
use_layernorm=False, # NEW(!)
layernorm_epsilon=1e-05, # NEW(!)
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
gamma_initializer='ones', # NEW(!)
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
gamma_regularizer=None, # NEW(!)
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
gamma_constraint=None, # NEW(!)
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
# 'implementation' warning was never relevant for LayernormSimpleRNN
cell = LayernormSimpleRNNCell( # TMP(!)
units,
activation=activation,
use_bias=use_bias,
use_layernorm=use_layernorm, # NEW(!)
layernorm_epsilon=layernorm_epsilon, # NEW(!)
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
gamma_initializer=gamma_initializer, # NEW(!)
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
gamma_regularizer=gamma_regularizer, # NEW(!)
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
gamma_constraint=gamma_constraint, # NEW(!)
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True))
super(SimpleRNN, self).__init__( # init RNN class
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
# use SimpleRNN's call() method
@property
def use_layernorm(self):
return self.cell.use_layernorm # NEW(!)
@property
def layernorm_epsilon(self):
return self.cell.layernorm_epsilon # NEW(!)
@property
def gamma_initializer(self):
return self.cell.gamma_initializer # NEW(!)
@property
def gamma_regularizer(self):
return self.cell.gamma_regularizer # NEW(!)
@property
def gamma_constraint(self):
return self.cell.gamma_constraint # NEW(!)
def get_config(self):
base_config = super(SimpleRNN, self).get_config() # from RNN class
del base_config['cell']
cell_config = self.cell.get_config()
return dict(list(base_config.items()) + list(cell_config.items()))
# set model parameters
num_samples = 2
timesteps = 3
embedding_dim = 4
units = 2
# create datasets
x = np.random.random((num_samples, timesteps, embedding_dim))
y = np.random.random((num_samples, units))
# modeling
model = Sequential([
LayernormSimpleRNN(
units, use_layernorm=True,
input_shape=(None, embedding_dim))
])
model.compile('rmsprop', 'mse')
# training
model.fit(x, y, verbose=1)
model.summary()  # summary() already prints; wrapping it in print() adds a stray 'None'
cfg = model.get_config()
print(cfg)
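# Optional comparison (not part of the original gist): with
# use_layernorm=False the layer should degenerate to a plain SimpleRNN,
# so both models should have the same number of trainable parameters
# (embedding_dim*units + units*units + units).
baseline = Sequential([
  SimpleRNN(units, input_shape=(None, embedding_dim))
])
variant = Sequential([
  LayernormSimpleRNN(
    units, use_layernorm=False,
    input_shape=(None, embedding_dim))
])
assert baseline.count_params() == variant.count_params()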