@ThomasDelteil
Last active November 22, 2018 02:55
TextDenoising
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder usded in sequence-to-sequence learning."""
__all__ = ['TransformerEncoder', 'TransformerDecoder',
'get_transformer_encoder_decoder']
import math
from functools import partial

import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn, rnn
from mxnet.gluon.block import Block, HybridBlock
from gluonnlp.model import (AttentionCell, MLPAttentionCell,
                            DotProductAttentionCell, MultiHeadAttentionCell)
def _list_bcast_where(F, mask, new_val_l, old_val_l):
"""Broadcast where. Implements out[i] = new_val[i] * mask + old_val[i] * (1 - mask)
Parameters
----------
F : symbol or ndarray
mask : Symbol or NDArray
new_val_l : list of Symbols or list of NDArrays
old_val_l : list of Symbols or list of NDArrays
Returns
-------
out_l : list of Symbols or list of NDArrays
"""
return [F.broadcast_mul(new_val, mask) + F.broadcast_mul(old_val, 1 - mask)
for new_val, old_val in zip(new_val_l, old_val_l)]
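# Illustrative sketch (not part of the original gist): a minimal check of the
# broadcast-where helper above. The shapes and values are assumptions chosen only
# for the demo; `_demo_list_bcast_where` is a hypothetical helper.
def _demo_list_bcast_where():
    mask = mx.nd.array([[1.0], [0.0]])        # (batch_size=2, 1); 1 keeps new, 0 keeps old
    new_val_l = [mx.nd.ones((2, 3))]          # candidate new values
    old_val_l = [mx.nd.zeros((2, 3))]         # values to fall back to
    # Row 0 of the result is all ones (new), row 1 is all zeros (old).
    return _list_bcast_where(mx.nd, mask, new_val_l, old_val_l)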
def _get_cell_type(cell_type):
"""Get the object type of the cell by parsing the input
Parameters
----------
cell_type : str or type
Returns
-------
cell_constructor: type
The constructor of the RNNCell
"""
if isinstance(cell_type, str):
if cell_type == 'lstm':
return rnn.LSTMCell
elif cell_type == 'gru':
return rnn.GRUCell
elif cell_type == 'relu_rnn':
return partial(rnn.RNNCell, activation='relu')
elif cell_type == 'tanh_rnn':
return partial(rnn.RNNCell, activation='tanh')
else:
raise NotImplementedError
else:
return cell_type
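# Illustrative sketch (not part of the original gist): resolving a cell class by
# name and stepping it once. The hidden size, batch size, and input width are
# arbitrary demo values; `_demo_get_cell_type` is a hypothetical helper.
def _demo_get_cell_type():
    cell_cls = _get_cell_type('lstm')                 # -> rnn.LSTMCell
    cell = cell_cls(hidden_size=20)
    cell.initialize()
    states = cell.begin_state(batch_size=2)
    # One recurrent step on a (batch_size=2, input_dim=8) input.
    output, new_states = cell(mx.nd.ones((2, 8)), states)
    return output, new_states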
def _get_attention_cell(attention_cell, units=None,
scaled=True, num_heads=None,
use_bias=False, dropout=0.0):
"""
Parameters
----------
attention_cell : AttentionCell or str
units : int or None
Returns
-------
attention_cell : AttentionCell
"""
if isinstance(attention_cell, str):
if attention_cell == 'scaled_luong':
return DotProductAttentionCell(units=units, scaled=True, normalized=False,
use_bias=use_bias, dropout=dropout, luong_style=True)
elif attention_cell == 'scaled_dot':
return DotProductAttentionCell(units=units, scaled=True, normalized=False,
use_bias=use_bias, dropout=dropout, luong_style=False)
elif attention_cell == 'dot':
return DotProductAttentionCell(units=units, scaled=False, normalized=False,
use_bias=use_bias, dropout=dropout, luong_style=False)
elif attention_cell == 'cosine':
return DotProductAttentionCell(units=units, scaled=False, use_bias=use_bias,
dropout=dropout, normalized=True)
elif attention_cell == 'mlp':
return MLPAttentionCell(units=units, normalized=False)
elif attention_cell == 'normed_mlp':
return MLPAttentionCell(units=units, normalized=True)
elif attention_cell == 'multi_head':
base_cell = DotProductAttentionCell(scaled=scaled, dropout=dropout)
return MultiHeadAttentionCell(base_cell=base_cell, query_units=units, use_bias=use_bias,
key_units=units, value_units=units, num_heads=num_heads)
else:
raise NotImplementedError
else:
assert isinstance(attention_cell, AttentionCell),\
'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
.format(attention_cell)
return attention_cell
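# Illustrative sketch (not part of the original gist): building a multi-head
# attention cell from its string name and applying it to a toy query/key/value.
# All shapes and hyper-parameters here are demo assumptions.
def _demo_get_attention_cell():
    att_cell = _get_attention_cell('multi_head', units=16, num_heads=4,
                                   scaled=True, dropout=0.0)
    att_cell.initialize()
    query = mx.nd.random.uniform(shape=(2, 5, 16))    # (batch, query_len, dim)
    memory = mx.nd.random.uniform(shape=(2, 7, 16))   # (batch, mem_len, dim)
    # context_vec: (2, 5, 16), att_weights: (2, num_heads, 5, 7)
    context_vec, att_weights = att_cell(query, memory, memory, None)
    return context_vec, att_weights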
def _nested_sequence_last(data, valid_length):
"""
Parameters
----------
data : nested container of NDArrays/Symbols
The input data. Each element will have shape (batch_size, ...)
valid_length : NDArray or Symbol
Valid length of the sequences. Shape (batch_size,)
Returns
-------
data_last: nested container of NDArrays/Symbols
The last valid element in the sequence.
"""
assert isinstance(data, list)
if isinstance(data[0], (mx.sym.Symbol, mx.nd.NDArray)):
F = mx.sym if isinstance(data[0], mx.sym.Symbol) else mx.ndarray
return F.SequenceLast(F.stack(*data, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
elif isinstance(data[0], list):
ret = []
for i in range(len(data[0])):
ret.append(_nested_sequence_last([ele[i] for ele in data], valid_length))
return ret
else:
raise NotImplementedError
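# Illustrative sketch (not part of the original gist): picking the last valid
# step of padded sequences stored as a list of per-step NDArrays. Values are
# demo assumptions; `_demo_nested_sequence_last` is a hypothetical helper.
def _demo_nested_sequence_last():
    # 4 time steps, each of shape (batch_size=2, 3); step t holds the value t.
    data = [mx.nd.ones((2, 3)) * t for t in range(4)]
    valid_length = mx.nd.array([2, 4])
    # Returns shape (2, 3): row 0 comes from step 1 (all 1s), row 1 from step 3 (all 3s).
    return _nested_sequence_last(data, valid_length)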
class Seq2SeqEncoder(Block):
r"""Base class of the encoders in sequence to sequence learning models.
"""
def __call__(self, inputs, valid_length=None, states=None): #pylint: disable=arguments-differ
"""Encode the input sequence.
Parameters
----------
inputs : NDArray
The input sequence, Shape (batch_size, length, C_in).
valid_length : NDArray or None, default None
The valid length of the input sequence, Shape (batch_size,). This is used when the
input sequences are padded. If set to None, all elements in the sequence are used.
states : list of NDArrays or None, default None
List that contains the initial states of the encoder.
Returns
-------
outputs : list
Outputs of the encoder.
"""
return super(Seq2SeqEncoder, self).__call__(inputs, valid_length, states)
def forward(self, inputs, valid_length=None, states=None): #pylint: disable=arguments-differ
raise NotImplementedError
class Seq2SeqDecoder(Block):
r"""Base class of the decoders in sequence to sequence learning models.
In the forward function, it generates the one-step-ahead decoding output.
"""
def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
r"""Generates the initial decoder states based on the encoder outputs.
Parameters
----------
encoder_outputs : list of NDArrays
encoder_valid_length : NDArray or None
Returns
-------
decoder_states : list
"""
raise NotImplementedError
def decode_seq(self, inputs, states, valid_length=None):
r"""Given the inputs and the context computed by the encoder,
generate the new states. This is usually used in the training phase where we set the inputs
to be the target sequence.
Parameters
----------
inputs : NDArray
The input embeddings. Shape (batch_size, length, C_in)
states : list
The initial states of the decoder.
valid_length : NDArray or None
valid length of the inputs. Shape (batch_size,)
Returns
-------
output : NDArray
The output of the decoder. Shape is (batch_size, length, C_out)
states: list
The new states of the decoder
additional_outputs : list
Additional outputs of the decoder, e.g., the attention weights
"""
raise NotImplementedError
def __call__(self, step_input, states): #pylint: disable=arguments-differ
r"""One-step decoding of the input
Parameters
----------
step_input : NDArray
Shape (batch_size, C_in)
states : list
The previous states of the decoder
Returns
-------
step_output : NDArray
Shape (batch_size, C_out)
states : list
step_additional_outputs : list
Additional outputs of the step, e.g., the attention weights
"""
return super(Seq2SeqDecoder, self).__call__(step_input, states)
def forward(self, step_input, states): #pylint: disable=arguments-differ
raise NotImplementedError
def _position_encoding_init(max_length, dim):
""" Init the sinusoid position encoding table """
position_enc = np.arange(max_length).reshape((-1, 1)) \
/ (np.power(10000, (2. / dim) * np.arange(dim).reshape((1, -1))))
# Apply sin to the even columns and cos to the odd columns.
position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i
position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1
return position_enc
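# Illustrative sketch (not part of the original gist): the sinusoidal table has
# shape (max_length, dim); even columns hold sin terms and odd columns cos terms,
# so the first row is [0, 1, 0, 1, ...]. Sizes below are demo assumptions.
def _demo_position_encoding_init():
    table = _position_encoding_init(max_length=10, dim=8)
    assert table.shape == (10, 8)
    assert np.allclose(table[0, 0::2], 0.0)   # sin(0) = 0 on even columns
    assert np.allclose(table[0, 1::2], 1.0)   # cos(0) = 1 on odd columns
    return table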
class PositionwiseFFN(HybridBlock):
"""Structure of the Positionwise Feed-Forward Neural Network.
Parameters
----------
units : int
hidden_size : int
number of units in the hidden layer of position-wise feed-forward networks
dropout : float
use_residual : bool
weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
bias_initializer : str or Initializer
Initializer for the bias vector.
activation : str, default 'relu'
Activation function
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
def __init__(self, units=512, hidden_size=2048, dropout=0.0, use_residual=True,
weight_initializer=None, bias_initializer='zeros', activation='relu',
prefix=None, params=None):
super(PositionwiseFFN, self).__init__(prefix=prefix, params=params)
self._hidden_size = hidden_size
self._units = units
self._use_residual = use_residual
with self.name_scope():
self.ffn_1 = nn.Dense(units=hidden_size, flatten=False,
activation=activation,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix='ffn_1_')
self.ffn_2 = nn.Dense(units=units, flatten=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix='ffn_2_')
self.dropout_layer = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm()
def hybrid_forward(self, F, inputs): # pylint: disable=arguments-differ
# pylint: disable=unused-argument
"""Position-wise encoding of the inputs.
Parameters
----------
inputs : Symbol or NDArray
Input sequence. Shape (batch_size, length, C_in)
Returns
-------
outputs : Symbol or NDArray
Shape (batch_size, length, C_out)
"""
outputs = self.ffn_1(inputs)
outputs = self.ffn_2(outputs)
outputs = self.dropout_layer(outputs)
if self._use_residual:
outputs = outputs + inputs
outputs = self.layer_norm(outputs)
return outputs
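# Illustrative sketch (not part of the original gist): the position-wise FFN maps
# (batch, length, units) -> (batch, length, units), applied independently at each
# position. Hyper-parameters and shapes are demo assumptions.
def _demo_positionwise_ffn():
    ffn = PositionwiseFFN(units=32, hidden_size=128, dropout=0.1)
    ffn.initialize()
    inputs = mx.nd.random.uniform(shape=(2, 5, 32))
    outputs = ffn(inputs)                      # shape stays (2, 5, 32)
    return outputs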
class TransformerEncoderCell(HybridBlock):
"""Structure of the Transformer Encoder Cell.
Parameters
----------
attention_cell : AttentionCell or str, default 'multi_head'
Arguments of the attention cell.
Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
units : int
hidden_size : int
number of units in the hidden layer of position-wise feed-forward networks
num_heads : int
Number of heads in multi-head attention
scaled : bool
Whether to scale the softmax input by the sqrt of the input dimension
in multi-head attention
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
def __init__(self, attention_cell='multi_head', units=128,
hidden_size=512, num_heads=4, scaled=True,
dropout=0.0, use_residual=True, output_attention=False,
weight_initializer=None, bias_initializer='zeros',
prefix=None, params=None):
super(TransformerEncoderCell, self).__init__(prefix=prefix, params=params)
self._units = units
self._num_heads = num_heads
self._dropout = dropout
self._use_residual = use_residual
self._output_attention = output_attention
with self.name_scope():
self.dropout_layer = nn.Dropout(dropout)
self.attention_cell = _get_attention_cell(attention_cell,
units=units,
num_heads=num_heads,
scaled=scaled,
dropout=dropout)
self.proj = nn.Dense(units=units, flatten=False, use_bias=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix='proj_')
self.ffn = PositionwiseFFN(hidden_size=hidden_size, units=units,
use_residual=use_residual, dropout=dropout,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
self.layer_norm = nn.LayerNorm()
def hybrid_forward(self, F, inputs, mask=None): # pylint: disable=arguments-differ
# pylint: disable=unused-argument
"""Transformer Encoder Attention Cell.
Parameters
----------
inputs : Symbol or NDArray
Input sequence. Shape (batch_size, length, C_in)
mask : Symbol or NDArray or None
Mask for inputs. Shape (batch_size, length, length)
Returns
-------
encoder_cell_outputs: list
Outputs of the encoder cell. Contains:
- outputs of the transformer encoder cell. Shape (batch_size, length, C_out)
- the additional outputs of the transformer encoder cell
"""
outputs, attention_weights =\
self.attention_cell(inputs, inputs, inputs, mask)
outputs = self.proj(outputs)
outputs = self.dropout_layer(outputs)
if self._use_residual:
outputs = outputs + inputs
outputs = self.layer_norm(outputs)
outputs = self.ffn(outputs)
additional_outputs = []
if self._output_attention:
additional_outputs.append(attention_weights)
return outputs, additional_outputs
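# Illustrative sketch (not part of the original gist): one encoder cell applies
# self-attention followed by the position-wise FFN, preserving the input shape.
# units must be divisible by num_heads; all values here are demo assumptions.
def _demo_transformer_encoder_cell():
    cell = TransformerEncoderCell(units=16, hidden_size=64, num_heads=4)
    cell.initialize()
    inputs = mx.nd.random.uniform(shape=(2, 6, 16))
    # additional_outputs is empty unless output_attention=True.
    outputs, additional_outputs = cell(inputs)
    return outputs, additional_outputs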
class TransformerDecoderCell(HybridBlock):
"""Structure of the Transformer Decoder Cell.
Parameters
----------
attention_cell : AttentionCell or str, default 'multi_head'
Arguments of the attention cell.
Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
units : int
hidden_size : int
number of units in the hidden layer of position-wise feed-forward networks
num_heads : int
Number of heads in multi-head attention
scaled : bool
Whether to scale the softmax input by the sqrt of the input dimension
in multi-head attention
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
def __init__(self, attention_cell='multi_head', units=128,
hidden_size=512, num_heads=4, scaled=True,
dropout=0.0, use_residual=True, output_attention=False,
weight_initializer=None, bias_initializer='zeros',
prefix=None, params=None):
super(TransformerDecoderCell, self).__init__(prefix=prefix, params=params)
self._units = units
self._num_heads = num_heads
self._dropout = dropout
self._use_residual = use_residual
self._output_attention = output_attention
self._scaled = scaled
with self.name_scope():
self.dropout_layer = nn.Dropout(dropout)
self.attention_cell_in = _get_attention_cell(attention_cell,
units=units,
num_heads=num_heads,
scaled=scaled,
dropout=dropout)
self.attention_cell_inter = _get_attention_cell(attention_cell,
units=units,
num_heads=num_heads,
scaled=scaled,
dropout=dropout)
self.proj_in = nn.Dense(units=units, flatten=False,
use_bias=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix='proj_in_')
self.proj_inter = nn.Dense(units=units, flatten=False,
use_bias=False,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix='proj_inter_')
self.ffn = PositionwiseFFN(hidden_size=hidden_size,
units=units,
use_residual=use_residual,
dropout=dropout,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
self.layer_norm_in = nn.LayerNorm()
self.layer_norm_inter = nn.LayerNorm()
def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None): #pylint: disable=unused-argument
# pylint: disable=arguments-differ
"""Transformer Decoder Attention Cell.
Parameters
----------
inputs : Symbol or NDArray
Input sequence. Shape (batch_size, length, C_in)
mem_value : Symbol or NDArrays
Memory value, i.e. output of the encoder. Shape (batch_size, mem_length, C_in)
mask : Symbol or NDArray or None
Mask for inputs. Shape (batch_size, length, length)
mem_mask : Symbol or NDArray or None
Mask for mem_value. Shape (batch_size, length, mem_length)
Returns
-------
decoder_cell_outputs: list
Outputs of the decoder cell. Contains:
- outputs of the transformer decoder cell. Shape (batch_size, length, C_out)
- the additional outputs of the transformer decoder cell
"""
outputs, attention_in_outputs =\
self.attention_cell_in(inputs, inputs, inputs, mask)
outputs = self.proj_in(outputs)
outputs = self.dropout_layer(outputs)
if self._use_residual:
outputs = outputs + inputs
outputs = self.layer_norm_in(outputs)
inputs = outputs
outputs, attention_inter_outputs = \
self.attention_cell_inter(inputs, mem_value, mem_value, mem_mask)
outputs = self.proj_inter(outputs)
outputs = self.dropout_layer(outputs)
if self._use_residual:
outputs = outputs + inputs
outputs = self.layer_norm_inter(outputs)
outputs = self.ffn(outputs)
additional_outputs = []
if self._output_attention:
additional_outputs.append(attention_in_outputs)
additional_outputs.append(attention_inter_outputs)
return outputs, additional_outputs
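# Illustrative sketch (not part of the original gist): one decoder cell runs
# self-attention over the decoder inputs (optionally causally masked) and then
# attention over the encoder memory. Shapes and hyper-parameters are demo
# assumptions; no masks are passed here for brevity.
def _demo_transformer_decoder_cell():
    cell = TransformerDecoderCell(units=16, hidden_size=64, num_heads=4)
    cell.initialize()
    dec_inputs = mx.nd.random.uniform(shape=(2, 5, 16))   # (batch, tgt_len, units)
    mem_value = mx.nd.random.uniform(shape=(2, 7, 16))    # stand-in encoder outputs
    # Outputs keep the (2, 5, 16) shape.
    outputs, additional_outputs = cell(dec_inputs, mem_value)
    return outputs, additional_outputs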
class TransformerEncoder(HybridBlock, Seq2SeqEncoder):
"""Structure of the Transformer Encoder.
Parameters
----------
attention_cell : AttentionCell or str, default 'multi_head'
Arguments of the attention cell.
Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
num_layers : int
units : int
hidden_size : int
number of units in the hidden layer of position-wise feed-forward networks
max_length : int
Maximum length of the input sequence
num_heads : int
Number of heads in multi-head attention
scaled : bool
Whether to scale the softmax input by the sqrt of the input dimension
in multi-head attention
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
def __init__(self, attention_cell='multi_head', num_layers=2,
units=512, hidden_size=2048, max_length=50,
num_heads=4, scaled=True, dropout=0.0,
use_residual=True, output_attention=False,
weight_initializer=None, bias_initializer='zeros',
prefix=None, params=None):
super(TransformerEncoder, self).__init__(prefix=prefix, params=params)
assert units % num_heads == 0,\
'In TransformerEncoder, units must be divisible ' \
'by the number of heads. Received units={}, num_heads={}' \
.format(units, num_heads)
self._num_layers = num_layers
self._max_length = max_length
self._num_heads = num_heads
self._units = units
self._hidden_size = hidden_size
self._output_attention = output_attention
self._dropout = dropout
self._use_residual = use_residual
self._scaled = scaled
with self.name_scope():
self.dropout_layer = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm()
self.position_weight = self.params.get_constant('const',
_position_encoding_init(max_length,
units))
self.transformer_cells = nn.HybridSequential()
for i in range(num_layers):
self.transformer_cells.add(
TransformerEncoderCell(
units=units,
hidden_size=hidden_size,
num_heads=num_heads,
attention_cell=attention_cell,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dropout=dropout,
use_residual=use_residual,
scaled=scaled,
output_attention=output_attention,
prefix='transformer%d_' % i))
def __call__(self, inputs, states=None, valid_length=None): #pylint: disable=arguments-differ
"""Encoder the inputs given the states and valid sequence length.
Parameters
----------
inputs : NDArray
Input sequence. Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of initial states and masks
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of the sequence has
been padded. Shape (batch_size,)
Returns
-------
encoder_outputs: list
Outputs of the encoder. Contains:
- outputs of the transformer encoder. Shape (batch_size, length, C_out)
- the additional outputs of all transformer encoder cells
"""
return super(TransformerEncoder, self).__call__(inputs, states, valid_length)
def forward(self, inputs, states=None, valid_length=None, steps=None): # pylint: disable=arguments-differ
"""
Parameters
----------
inputs : NDArray, Shape(batch_size, length, C_in)
states : list of NDArray
valid_length : NDArray
steps : NDArray
Stores the values [0, 1, ..., length - 1].
It is used for lookup in positional encoding matrix
Returns
-------
outputs : NDArray
The output of the encoder. Shape is (batch_size, length, C_out)
additional_outputs : list
Either an empty list, or a list that contains the attention weights of this step.
The attention weights will have shape (batch_size, length, length) or
(batch_size, num_heads, length, length)
"""
length = inputs.shape[1]
if valid_length is not None:
mask = mx.nd.broadcast_lesser(
mx.nd.arange(length, ctx=valid_length.context).reshape((1, -1)),
valid_length.reshape((-1, 1)))
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=1), axis=1, size=length)
if states is None:
states = [mask]
else:
states.append(mask)
inputs = inputs * math.sqrt(inputs.shape[-1])
steps = mx.nd.arange(length, ctx=inputs.context)
if states is None:
states = [steps]
else:
states.append(steps)
if valid_length is not None:
step_output, additional_outputs =\
super(TransformerEncoder, self).forward(inputs, states, valid_length)
else:
step_output, additional_outputs =\
super(TransformerEncoder, self).forward(inputs, states)
return step_output, additional_outputs
def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_weight=None): # pylint: disable=arguments-differ
"""
Parameters
----------
inputs : NDArray or Symbol, Shape(batch_size, length, C_in)
states : list of NDArray or Symbol
valid_length : NDArray or Symbol
position_weight : NDArray or Symbol
Returns
-------
outputs : NDArray or Symbol
The output of the encoder. Shape is (batch_size, length, C_out)
additional_outputs : list
Either an empty list, or a list that contains the attention weights of this step.
The attention weights will have shape (batch_size, length, length) or
(batch_size, num_heads, length, length)
"""
if states is not None:
steps = states[-1]
# Positional Encoding
inputs = F.broadcast_add(inputs, F.expand_dims(F.Embedding(steps, position_weight,
self._max_length,
self._units), axis=0))
inputs = self.dropout_layer(inputs)
inputs = self.layer_norm(inputs)
outputs = inputs
if valid_length is not None:
mask = states[-2]
else:
mask = None
additional_outputs = []
for cell in self.transformer_cells:
outputs, attention_weights = cell(inputs, mask)
inputs = outputs
if self._output_attention:
additional_outputs.append(attention_weights)
if valid_length is not None:
outputs = F.SequenceMask(outputs, sequence_length=valid_length,
use_sequence_length=True, axis=1)
return outputs, additional_outputs
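# Illustrative sketch (not part of the original gist): running the full encoder on a
# padded toy batch. Passing valid_length builds the attention mask and applies the
# final SequenceMask; all sizes are demo assumptions.
def _demo_transformer_encoder():
    encoder = TransformerEncoder(num_layers=2, units=16, hidden_size=64,
                                 num_heads=4, max_length=20)
    encoder.initialize()
    inputs = mx.nd.random.uniform(shape=(2, 7, 16))       # (batch, src_len, units)
    valid_length = mx.nd.array([5, 7])                    # true lengths before padding
    outputs, additional_outputs = encoder(inputs, valid_length=valid_length)
    return outputs, additional_outputs                    # outputs: (2, 7, 16)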
class TransformerDecoder(HybridBlock, Seq2SeqDecoder):
"""Structure of the Transformer Decoder.
Parameters
----------
attention_cell : AttentionCell or str, default 'multi_head'
Arguments of the attention cell.
Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
num_layers : int
units : int
hidden_size : int
number of units in the hidden layer of position-wise feed-forward networks
max_length : int
Maximum length of the input sequence. This is used for constructing position encoding
num_heads : int
Number of heads in multi-head attention
scaled : bool
Whether to scale the softmax input by the sqrt of the input dimension
in multi-head attention
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
def __init__(self, attention_cell='multi_head', num_layers=2,
units=128, hidden_size=2048, max_length=50,
num_heads=4, scaled=True, dropout=0.0,
use_residual=True, output_attention=False,
weight_initializer=None, bias_initializer='zeros',
prefix=None, params=None):
super(TransformerDecoder, self).__init__(prefix=prefix, params=params)
assert units % num_heads == 0, 'In TransformerDecoder, units must be divisible ' \
'by the number of heads. Received units={}, ' \
'num_heads={}'.format(units, num_heads)
self._num_layers = num_layers
self._units = units
self._hidden_size = hidden_size
self._num_states = num_heads
self._max_length = max_length
self._dropout = dropout
self._use_residual = use_residual
self._output_attention = output_attention
self._scaled = scaled
with self.name_scope():
self.dropout_layer = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm()
self.position_weight = self.params.get_constant('const',
_position_encoding_init(max_length,
units))
self.transformer_cells = nn.HybridSequential()
for i in range(num_layers):
self.transformer_cells.add(
TransformerDecoderCell(
units=units,
hidden_size=hidden_size,
num_heads=num_heads,
attention_cell=attention_cell,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
dropout=dropout,
scaled=scaled,
use_residual=use_residual,
output_attention=output_attention,
prefix='transformer%d_' % i))
def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
"""Initialize the state from the encoder outputs.
Parameters
----------
encoder_outputs : list
encoder_valid_length : NDArray or None
Returns
-------
decoder_states : list
The decoder states, includes:
- mem_value : NDArray
- mem_masks : NDArray, optional
"""
mem_value = encoder_outputs
decoder_states = [mem_value]
mem_length = mem_value.shape[1]
if encoder_valid_length is not None:
mem_masks = mx.nd.broadcast_lesser(
mx.nd.arange(mem_length, ctx=encoder_valid_length.context).reshape((1, -1)),
encoder_valid_length.reshape((-1, 1)))
decoder_states.append(mem_masks)
self._encoder_valid_length = encoder_valid_length
return decoder_states
def decode_seq(self, inputs, states, valid_length=None):
"""Decode the decoder inputs. This function is only used for training.
Parameters
----------
inputs : NDArray, Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of decoder states
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of the sequence has
been padded. Shape (batch_size,)
Returns
-------
output : NDArray, Shape (batch_size, length, C_out)
states : list
The decoder states, includes:
- mem_value : NDArray
- mem_masks : NDArray, optional
additional_outputs : list of list
Either an empty list, or a list that contains the attention weights of this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
batch_size = inputs.shape[0]
length = inputs.shape[1]
length_array = mx.nd.arange(length, ctx=inputs.context)
mask = mx.nd.broadcast_lesser_equal(
length_array.reshape((1, -1)),
length_array.reshape((-1, 1)))
if valid_length is not None:
batch_mask = mx.nd.broadcast_lesser(
mx.nd.arange(length, ctx=valid_length.context).reshape((1, -1)),
valid_length.reshape((-1, 1)))
mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1),
mx.nd.expand_dims(mask, 0))
else:
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size)
states = [None] + states
output, states, additional_outputs = self.forward(inputs, states, mask)
states = states[1:]
if valid_length is not None:
output = mx.nd.SequenceMask(output,
sequence_length=valid_length,
use_sequence_length=True,
axis=1)
return output, states, additional_outputs
def __call__(self, step_input, states): #pylint: disable=arguments-differ
"""One-step-ahead decoding of the Transformer decoder.
Parameters
----------
step_input : NDArray
states : list of NDArray
Returns
-------
step_output : NDArray
The output of the decoder.
In training mode, the shape is (batch_size, length, C_out).
In testing mode, the shape is (batch_size, C_out).
new_states: list
Includes
- last_embeds : NDArray or None
It is only given during testing
- mem_value : NDArray
- mem_masks : NDArray, optional
step_additional_outputs : list of list
Either an empty list, or a list that contains the attention weights of this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
return super(TransformerDecoder, self).__call__(step_input, states)
def forward(self, step_input, states, mask=None): #pylint: disable=arguments-differ, missing-docstring
input_shape = step_input.shape
mem_mask = None
# If it is in testing, transform input tensor to a tensor with shape NTC
# Otherwise remove the None in states.
if len(input_shape) == 2:
if self._encoder_valid_length is not None:
has_last_embeds = len(states) == 3
else:
has_last_embeds = len(states) == 2
if has_last_embeds:
last_embeds = states[0]
step_input = mx.nd.concat(last_embeds,
mx.nd.expand_dims(step_input, axis=1),
dim=1)
states = states[1:]
else:
step_input = mx.nd.expand_dims(step_input, axis=1)
elif states[0] is None:
states = states[1:]
has_mem_mask = (len(states) == 2)
if has_mem_mask:
_, mem_mask = states
augmented_mem_mask = mx.nd.expand_dims(mem_mask, axis=1)\
.broadcast_axes(axis=1, size=step_input.shape[1])
states[-1] = augmented_mem_mask
if mask is None:
length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
mask = mx.nd.broadcast_lesser_equal(
length_array.reshape((1, -1)),
length_array.reshape((-1, 1)))
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
axis=0, size=step_input.shape[0])
steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
states.append(steps)
step_output, step_additional_outputs = \
super(TransformerDecoder, self).forward(step_input * math.sqrt(step_input.shape[-1]), #pylint: disable=too-many-function-args
states, mask)
states = states[:-1]
if has_mem_mask:
states[-1] = mem_mask
new_states = [step_input] + states
# If it is in testing, only output the last one
if len(input_shape) == 2:
step_output = step_output[:, -1, :]
return step_output, new_states, step_additional_outputs
def hybrid_forward(self, F, step_input, states, mask=None, position_weight=None): #pylint: disable=arguments-differ
"""
Parameters
----------
step_input : NDArray or Symbol, Shape (batch_size, length, C_in)
states : list of NDArray or Symbol
mask : NDArray or Symbol
position_weight : NDArray or Symbol
Returns
-------
step_output : NDArray or Symbol
The output of the decoder. Shape is (batch_size, length, C_out)
step_additional_outputs : list
Either an empty list, or a list that contains the attention weights of this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
has_mem_mask = (len(states) == 3)
if has_mem_mask:
mem_value, mem_mask, steps = states
else:
mem_value, steps = states
mem_mask = None
# Positional Encoding
step_input = F.broadcast_add(step_input,
F.expand_dims(F.Embedding(steps,
position_weight,
self._max_length,
self._units),
axis=0))
step_input = self.dropout_layer(step_input)
step_input = self.layer_norm(step_input)
inputs = step_input
outputs = inputs
step_additional_outputs = []
attention_weights_l = []
for cell in self.transformer_cells:
outputs, attention_weights = cell(inputs, mem_value, mask, mem_mask)
if self._output_attention:
attention_weights_l.append(attention_weights)
inputs = outputs
if self._output_attention:
step_additional_outputs.extend(attention_weights_l)
return outputs, step_additional_outputs
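# Illustrative sketch (not part of the original gist): teacher-forced decoding with
# decode_seq, using randomly generated "encoder outputs" as memory. All shapes are
# demo assumptions; in practice mem_value comes from a TransformerEncoder.
def _demo_transformer_decoder():
    decoder = TransformerDecoder(num_layers=2, units=16, hidden_size=64,
                                 num_heads=4, max_length=20)
    decoder.initialize()
    mem_value = mx.nd.random.uniform(shape=(2, 7, 16))    # stand-in encoder outputs
    mem_valid_length = mx.nd.array([5, 7])
    states = decoder.init_state_from_encoder(mem_value, mem_valid_length)
    tgt_embeddings = mx.nd.random.uniform(shape=(2, 6, 16))
    tgt_valid_length = mx.nd.array([4, 6])
    output, states, additional_outputs = decoder.decode_seq(
        tgt_embeddings, states, tgt_valid_length)
    return output, states, additional_outputs             # output: (2, 6, 16)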
def get_transformer_encoder_decoder(num_layers=2,
num_heads=8, scaled=True,
units=512, hidden_size=2048, dropout=0.0, use_residual=True,
max_src_length=50, max_tgt_length=50,
weight_initializer=None, bias_initializer='zeros',
prefix='transformer_', params=None):
"""Build a pair of Parallel GNMT encoder/decoder
Parameters
----------
num_layers : int
num_heads : int
scaled : bool
units : int
hidden_size : int
dropout : float
use_residual : bool
max_src_length : int
max_tgt_length : int
weight_initializer : mx.init.Initializer or None
bias_initializer : mx.init.Initializer or None
prefix : str, default 'transformer_'
Prefix for name of `Block`s.
params : Parameter or None
Container for weight sharing between layers.
Created if `None`.
Returns
-------
encoder : TransformerEncoder
decoder : TransformerDecoder
"""
encoder = TransformerEncoder(num_layers=num_layers,
num_heads=num_heads,
max_length=max_src_length,
units=units,
hidden_size=hidden_size,
dropout=dropout,
scaled=scaled,
use_residual=use_residual,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix=prefix + 'enc_', params=params)
decoder = TransformerDecoder(num_layers=num_layers,
num_heads=num_heads,
max_length=max_tgt_length,
units=units,
hidden_size=hidden_size,
dropout=dropout,
scaled=scaled,
use_residual=use_residual,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix=prefix + 'dec_', params=params)
return encoder, decoder
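# Illustrative sketch (not part of the original gist): building the encoder/decoder
# pair and running one greedy-style decoding step. Embedding and output projection
# layers are omitted; all hyper-parameters and shapes are demo assumptions.
def _demo_get_transformer_encoder_decoder():
    encoder, decoder = get_transformer_encoder_decoder(
        num_layers=2, num_heads=4, units=16, hidden_size=64,
        max_src_length=20, max_tgt_length=20)
    encoder.initialize()
    decoder.initialize()
    src_embeddings = mx.nd.random.uniform(shape=(2, 7, 16))
    src_valid_length = mx.nd.array([5, 7])
    encoder_outputs, _ = encoder(src_embeddings, valid_length=src_valid_length)
    states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
    # One-step decoding: feed the embedding of the current target token.
    step_input = mx.nd.random.uniform(shape=(2, 16))
    step_output, states, _ = decoder(step_input, states)
    return step_output                                     # (2, 16)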