Skip to content

Instantly share code, notes, and snippets.

@ThomasDelteil
Last active November 22, 2018 02:55
Show Gist options
  • Save ThomasDelteil/4274c1946e958615df915d5d30908c19 to your computer and use it in GitHub Desktop.
Save ThomasDelteil/4274c1946e958615df915d5d30908c19 to your computer and use it in GitHub Desktop.
TextDenoising
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder used in sequence-to-sequence learning."""
__all__ = ['TransformerEncoder', 'TransformerDecoder',
'get_transformer_encoder_decoder']
import math
import os
import warnings
from functools import partial
from gluonnlp.model import AttentionCell, MLPAttentionCell, DotProductAttentionCell, MultiHeadAttentionCell
import gluonnlp as nlp
import mxnet as mx
from mxnet.gluon import rnn
from mxnet.gluon.block import Block
from mxnet import cpu, gluon
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock
import numpy as np
def _list_bcast_where(F, mask, new_val_l, old_val_l):
    """Blend two lists of tensors pairwise under a broadcast mask.

    Each output element is ``new_val * mask + old_val * (1 - mask)``, with
    both products computed through ``F.broadcast_mul``.

    Parameters
    ----------
    F : symbol or ndarray
        Operator namespace that provides ``broadcast_mul``.
    mask : Symbol or NDArray
        Weight given to the new values; ``1 - mask`` weights the old ones.
    new_val_l : list of Symbols or list of NDArrays
    old_val_l : list of Symbols or list of NDArrays

    Returns
    -------
    out_l : list of Symbols or list of NDArrays
        One blended entry per (new, old) pair; pairing stops at the shorter
        of the two input lists.
    """
    blended = []
    for fresh, stale in zip(new_val_l, old_val_l):
        from_new = F.broadcast_mul(fresh, mask)
        from_old = F.broadcast_mul(stale, 1 - mask)
        blended.append(from_new + from_old)
    return blended
def _get_cell_type(cell_type):
    """Resolve a recurrent-cell specification to a cell constructor.

    Parameters
    ----------
    cell_type : str or type
        Either one of the names ``'lstm'``, ``'gru'``, ``'relu_rnn'`` or
        ``'tanh_rnn'``, or any non-string object, which is returned as-is.

    Returns
    -------
    cell_constructor: type
        The constructor of the RNNCell. For ``'relu_rnn'`` / ``'tanh_rnn'``
        this is a ``functools.partial`` of ``rnn.RNNCell`` with the
        activation pre-bound.

    Raises
    ------
    NotImplementedError
        If ``cell_type`` is a string that is not one of the supported names.
    """
    # Anything that is not a string is assumed to already be a constructor.
    if not isinstance(cell_type, str):
        return cell_type
    if cell_type == 'lstm':
        return rnn.LSTMCell
    if cell_type == 'gru':
        return rnn.GRUCell
    if cell_type == 'relu_rnn':
        return partial(rnn.RNNCell, activation='relu')
    if cell_type == 'tanh_rnn':
        return partial(rnn.RNNCell, activation='tanh')
    # Name the offending value so a typo is easy to spot in the traceback.
    raise NotImplementedError(
        'Unsupported cell_type={!r}. Supported names are '
        "'lstm', 'gru', 'relu_rnn' and 'tanh_rnn'.".format(cell_type))
def _get_attention_cell(attention_cell, units=None,
                        scaled=True, num_heads=None,
                        use_bias=False, dropout=0.0):
    """Resolve an attention-cell specification to an AttentionCell instance.

    Parameters
    ----------
    attention_cell : AttentionCell or str
        An existing cell (returned unchanged) or one of the names
        'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'mlp', 'normed_mlp',
        'multi_head'.
    units : int or None
        Forwarded as ``units`` to the dot-product / MLP cells, and as the
        query, key and value units of the multi-head cell.
    scaled : bool
        Only used by 'multi_head' (passed to its base dot-product cell).
        The other dot-product variants fix their own scaling flag.
    num_heads : int or None
        Only used by 'multi_head'.
    use_bias : bool
        Forwarded to the dot-product variants and to the multi-head cell;
        not forwarded to the MLP variants.
    dropout : float
        Forwarded to the dot-product cells (including the base cell of
        'multi_head'); not forwarded to the MLP variants.

    Returns
    -------
    attention_cell : AttentionCell

    Raises
    ------
    NotImplementedError
        If ``attention_cell`` is a string that is not a supported name.
    AssertionError
        If ``attention_cell`` is neither a string nor an AttentionCell.
    """
    if not isinstance(attention_cell, str):
        assert isinstance(attention_cell, AttentionCell),\
            'attention_cell must be either string or AttentionCell. Received attention_cell={}'\
            .format(attention_cell)
        return attention_cell

    if attention_cell == 'scaled_luong':
        return DotProductAttentionCell(units=units, scaled=True, normalized=False,
                                       use_bias=use_bias, dropout=dropout, luong_style=True)
    if attention_cell == 'scaled_dot':
        return DotProductAttentionCell(units=units, scaled=True, normalized=False,
                                       use_bias=use_bias, dropout=dropout, luong_style=False)
    if attention_cell == 'dot':
        return DotProductAttentionCell(units=units, scaled=False, normalized=False,
                                       use_bias=use_bias, dropout=dropout, luong_style=False)
    if attention_cell == 'cosine':
        return DotProductAttentionCell(units=units, scaled=False, use_bias=use_bias,
                                       dropout=dropout, normalized=True)
    if attention_cell == 'mlp':
        return MLPAttentionCell(units=units, normalized=False)
    if attention_cell == 'normed_mlp':
        return MLPAttentionCell(units=units, normalized=True)
    if attention_cell == 'multi_head':
        base_cell = DotProductAttentionCell(scaled=scaled, dropout=dropout)
        return MultiHeadAttentionCell(base_cell=base_cell, query_units=units, use_bias=use_bias,
                                      key_units=units, value_units=units, num_heads=num_heads)
    # Name the offending value so a typo is easy to spot in the traceback.
    raise NotImplementedError(
        'Unsupported attention_cell={!r}. Supported names are '
        "'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'mlp', 'normed_mlp' "
        "and 'multi_head'.".format(attention_cell))
def _nested_sequence_last(data, valid_length):
    """Pick the last valid step out of a (possibly nested) list of steps.

    Parameters
    ----------
    data : nested container of NDArrays/Symbols
        The input data. Each element will have shape (batch_size, ...).
        The outer list is stacked along a new leading axis before the
        last valid entry is selected; inner lists are handled recursively,
        position by position.
    valid_length : NDArray or Symbol
        Valid length of the sequences. Shape (batch_size,)

    Returns
    -------
    data_last: nested container of NDArrays/Symbols
        The last valid element in the sequence, with the same nesting as
        the elements of ``data``.
    """
    assert isinstance(data, list)
    head = data[0]
    if isinstance(head, (mx.sym.Symbol, mx.nd.NDArray)):
        # Choose the operator namespace matching the element type.
        if isinstance(head, mx.sym.Symbol):
            F = mx.sym
        else:
            F = mx.ndarray
        stacked = F.stack(*data, axis=0)
        return F.SequenceLast(stacked,
                              sequence_length=valid_length,
                              use_sequence_length=True)
    if isinstance(head, list):
        # Recurse once per slot of the inner lists, gathering that slot
        # from every step.
        return [_nested_sequence_last([step[slot] for step in data], valid_length)
                for slot in range(len(head))]
    raise NotImplementedError
class Seq2SeqEncoder(Block):
    r"""Abstract base for encoders in sequence-to-sequence models.

    ``__call__`` only fixes the argument order and delegates to the parent
    class; concrete encoders must override :meth:`forward`.
    """

    def __call__(self, inputs, valid_length=None, states=None):  #pylint: disable=arguments-differ
        """Encode the input sequence.

        Parameters
        ----------
        inputs : NDArray
            The input sequence, Shape (batch_size, length, C_in).
        valid_length : NDArray or None, default None
            The valid length of the input sequence, Shape (batch_size,).
            Used when the input sequences are padded. If None, all elements
            in the sequence are used.
        states : list of NDArrays or None, default None
            List that contains the initial states of the encoder.

        Returns
        -------
        outputs : list
            Outputs of the encoder.
        """
        return super(Seq2SeqEncoder, self).__call__(inputs, valid_length, states)

    def forward(self, inputs, valid_length=None, states=None):  #pylint: disable=arguments-differ
        """Placeholder; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError
class Seq2SeqDecoder(Block):
    r"""Abstract base for decoders in sequence-to-sequence models.

    Calling the block performs one-step-ahead decoding. Every method other
    than ``__call__`` raises ``NotImplementedError`` here and is meant to be
    overridden.
    """

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        r"""Build the initial decoder states from the encoder outputs.

        Parameters
        ----------
        encoder_outputs : list of NDArrays
        encoder_valid_length : NDArray or None

        Returns
        -------
        decoder_states : list
        """
        raise NotImplementedError

    def decode_seq(self, inputs, states, valid_length=None):
        r"""Decode a whole sequence given the encoder context.

        This is usually used in the training phase, where the inputs are
        set to the target sequence.

        Parameters
        ----------
        inputs : NDArray
            The input embeddings. Shape (batch_size, length, C_in)
        states : list
            The initial states of the decoder.
        valid_length : NDArray or None
            Valid length of the inputs. Shape (batch_size,)

        Returns
        -------
        output : NDArray
            The output of the decoder. Shape is (batch_size, length, C_out)
        states: list
            The new states of the decoder
        additional_outputs : list
            Additional outputs of the decoder, e.g, the attention weights
        """
        raise NotImplementedError

    def __call__(self, step_input, states):  #pylint: disable=arguments-differ
        r"""Decode a single step.

        Parameters
        ----------
        step_input : NDArray
            Shape (batch_size, C_in)
        states : list
            The previous states of the decoder

        Returns
        -------
        step_output : NDArray
            Shape (batch_size, C_out)
        states : list
        step_additional_outputs : list
            Additional outputs of the step, e.g, the attention weights
        """
        return super(Seq2SeqDecoder, self).__call__(step_input, states)

    def forward(self, step_input, states):  #pylint: disable=arguments-differ
        """Placeholder; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError
def _position_encoding_init(max_length, dim):
    """Build the sinusoid position encoding table.

    Parameters
    ----------
    max_length : int
        Number of positions (rows of the table).
    dim : int
        Number of channels (columns of the table).

    Returns
    -------
    position_enc : numpy.ndarray
        Array of shape (max_length, dim). Column ``j`` holds
        ``pos / 10000 ** (2 * j / dim)`` passed through ``sin`` when ``j`` is
        even and through ``cos`` when ``j`` is odd.
    """
    positions = np.arange(max_length).reshape((-1, 1))
    channels = np.arange(dim).reshape((1, -1))
    # One wavelength per column; the exponent uses the column index itself.
    table = positions / np.power(10000, (2. / dim) * channels)
    # Sine on the even columns, cosine on the odd ones.
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
    return table
class PositionwiseFFN(HybridBlock):
    """Position-wise feed-forward sub-layer.

    Two dense layers applied to the last axis (``flatten=False``), followed
    by dropout, an optional residual connection and layer normalization.

    Parameters
    ----------
    units : int
        Output size of the second dense layer.
    hidden_size : int
        Number of units in the hidden layer of position-wise feed-forward networks
    dropout : float
    use_residual : bool
        Whether to add the block input to the dropped-out output before
        layer normalization.
    weight_initializer : str or Initializer
        Initializer for the weights of both dense layers.
    bias_initializer : str or Initializer
        Initializer for the bias vectors of both dense layers.
    activation : str, default 'relu'
        Activation function of the first dense layer.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, units=512, hidden_size=2048, dropout=0.0, use_residual=True,
                 weight_initializer=None, bias_initializer='zeros', activation='relu',
                 prefix=None, params=None):
        super(PositionwiseFFN, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._units = units
        self._use_residual = use_residual
        # Both dense layers share the same initializers.
        init_kwargs = dict(weight_initializer=weight_initializer,
                           bias_initializer=bias_initializer)
        with self.name_scope():
            self.ffn_1 = nn.Dense(units=hidden_size, flatten=False,
                                  activation=activation,
                                  prefix='ffn_1_', **init_kwargs)
            self.ffn_2 = nn.Dense(units=units, flatten=False,
                                  prefix='ffn_2_', **init_kwargs)
            self.dropout_layer = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm()

    def hybrid_forward(self, F, inputs):  # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """Apply the feed-forward sub-layer to every position.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)

        Returns
        -------
        outputs : Symbol or NDArray
            Shape (batch_size, length, C_out)
        """
        hidden = self.ffn_1(inputs)
        projected = self.dropout_layer(self.ffn_2(hidden))
        if self._use_residual:
            projected = projected + inputs
        # Layer normalization is applied whether or not the residual is used.
        return self.layer_norm(projected)
class TransformerEncoderCell(HybridBlock):
    """One Transformer encoder layer: self-attention followed by a
    position-wise feed-forward network.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    units : int
    hidden_size : int
        Number of units in the hidden layer of position-wise feed-forward networks
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the weights of the projection and feed-forward layers.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', units=128,
                 hidden_size=512, num_heads=4, scaled=True,
                 dropout=0.0, use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerEncoderCell, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._num_heads = num_heads
        self._dropout = dropout
        self._use_residual = use_residual
        self._output_attention = output_attention
        init_kwargs = dict(weight_initializer=weight_initializer,
                           bias_initializer=bias_initializer)
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.attention_cell = _get_attention_cell(attention_cell,
                                                      units=units,
                                                      num_heads=num_heads,
                                                      scaled=scaled,
                                                      dropout=dropout)
            self.proj = nn.Dense(units=units, flatten=False, use_bias=False,
                                 prefix='proj_', **init_kwargs)
            self.ffn = PositionwiseFFN(hidden_size=hidden_size, units=units,
                                       use_residual=use_residual, dropout=dropout,
                                       **init_kwargs)
            self.layer_norm = nn.LayerNorm()

    def hybrid_forward(self, F, inputs, mask=None):  # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """Run self-attention and the feed-forward network on ``inputs``.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)
        mask : Symbol or NDArray or None
            Mask for inputs. Shape (batch_size, length, length)

        Returns
        -------
        encoder_cell_outputs: list
            Outputs of the encoder cell. Contains:
            - outputs of the transformer encoder cell. Shape (batch_size, length, C_out)
            - additional_outputs: a list holding the attention weights when
              ``output_attention`` is set, otherwise an empty list
        """
        # Self-attention: the inputs act as query, key and value.
        attended, attention_weights = self.attention_cell(inputs, inputs, inputs, mask)
        attended = self.dropout_layer(self.proj(attended))
        if self._use_residual:
            attended = attended + inputs
        outputs = self.ffn(self.layer_norm(attended))
        additional_outputs = [attention_weights] if self._output_attention else []
        return outputs, additional_outputs
class TransformerDecoderCell(HybridBlock):
    """One Transformer decoder layer: self-attention, attention over the
    encoder memory, then a position-wise feed-forward network.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    units : int
    hidden_size : int
        Number of units in the hidden layer of position-wise feed-forward networks
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the weights of the projection and feed-forward layers.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', units=128,
                 hidden_size=512, num_heads=4, scaled=True,
                 dropout=0.0, use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerDecoderCell, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._num_heads = num_heads
        self._dropout = dropout
        self._use_residual = use_residual
        self._output_attention = output_attention
        self._scaled = scaled
        # Settings shared by the two attention cells and by the dense layers.
        attention_kwargs = dict(units=units, num_heads=num_heads,
                                scaled=scaled, dropout=dropout)
        init_kwargs = dict(weight_initializer=weight_initializer,
                           bias_initializer=bias_initializer)
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.attention_cell_in = _get_attention_cell(attention_cell, **attention_kwargs)
            self.attention_cell_inter = _get_attention_cell(attention_cell, **attention_kwargs)
            self.proj_in = nn.Dense(units=units, flatten=False, use_bias=False,
                                    prefix='proj_in_', **init_kwargs)
            self.proj_inter = nn.Dense(units=units, flatten=False, use_bias=False,
                                       prefix='proj_inter_', **init_kwargs)
            self.ffn = PositionwiseFFN(hidden_size=hidden_size,
                                       units=units,
                                       use_residual=use_residual,
                                       dropout=dropout,
                                       **init_kwargs)
            self.layer_norm_in = nn.LayerNorm()
            self.layer_norm_inter = nn.LayerNorm()

    def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None):  #pylint: disable=unused-argument
        # pylint: disable=arguments-differ
        """Run the decoder layer on ``inputs`` against the encoder memory.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)
        mem_value : Symbol or NDArrays
            Memory value, i.e. output of the encoder. Shape (batch_size, mem_length, C_in)
        mask : Symbol or NDArray or None
            Mask for inputs. Shape (batch_size, length, length)
        mem_mask : Symbol or NDArray or None
            Mask for mem_value. Shape (batch_size, length, mem_length)

        Returns
        -------
        decoder_cell_outputs: list
            Outputs of the decoder cell. Contains:
            - outputs of the transformer decoder cell. Shape (batch_size, length, C_out)
            - additional_outputs: the self-attention weights followed by the
              memory-attention weights when ``output_attention`` is set,
              otherwise an empty list
        """
        # Stage 1: self-attention over the decoder inputs.
        self_out, attention_in_outputs = \
            self.attention_cell_in(inputs, inputs, inputs, mask)
        self_out = self.dropout_layer(self.proj_in(self_out))
        if self._use_residual:
            self_out = self_out + inputs
        self_out = self.layer_norm_in(self_out)

        # Stage 2: attention over the encoder memory; the residual here is
        # taken from the stage-1 result.
        cross_out, attention_inter_outputs = \
            self.attention_cell_inter(self_out, mem_value, mem_value, mem_mask)
        cross_out = self.dropout_layer(self.proj_inter(cross_out))
        if self._use_residual:
            cross_out = cross_out + self_out
        cross_out = self.layer_norm_inter(cross_out)

        # Stage 3: position-wise feed-forward network.
        outputs = self.ffn(cross_out)
        additional_outputs = []
        if self._output_attention:
            additional_outputs.extend([attention_in_outputs, attention_inter_outputs])
        return outputs, additional_outputs
class TransformerEncoder(HybridBlock, Seq2SeqEncoder):
    """Stack of Transformer encoder cells with a sinusoid position encoding.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    num_layers : int
        Number of stacked encoder cells.
    units : int
        Must be divisible by ``num_heads``.
    hidden_size : int
        Number of units in the hidden layer of position-wise feed-forward networks
    max_length : int
        Maximum length of the input sequence; this is the number of rows of
        the position encoding table.
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', num_layers=2,
                 units=512, hidden_size=2048, max_length=50,
                 num_heads=4, scaled=True, dropout=0.0,
                 use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerEncoder, self).__init__(prefix=prefix, params=params)
        assert units % num_heads == 0,\
            'In TransformerEncoder, The units should be divided exactly ' \
            'by the number of heads. Received units={}, num_heads={}' \
            .format(units, num_heads)
        self._num_layers = num_layers
        self._max_length = max_length
        self._num_heads = num_heads
        self._units = units
        self._hidden_size = hidden_size
        self._output_attention = output_attention
        self._dropout = dropout
        self._use_residual = use_residual
        self._scaled = scaled
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm()
            # Fixed (non-trainable) position encoding table, looked up by
            # step index in hybrid_forward.
            self.position_weight = self.params.get_constant('const',
                                                            _position_encoding_init(max_length,
                                                                                    units))
            self.transformer_cells = nn.HybridSequential()
            for i in range(num_layers):
                self.transformer_cells.add(
                    TransformerEncoderCell(
                        units=units,
                        hidden_size=hidden_size,
                        num_heads=num_heads,
                        attention_cell=attention_cell,
                        weight_initializer=weight_initializer,
                        bias_initializer=bias_initializer,
                        dropout=dropout,
                        use_residual=use_residual,
                        scaled=scaled,
                        output_attention=output_attention,
                        prefix='transformer%d_' % i))

    def __call__(self, inputs, states=None, valid_length=None):  #pylint: disable=arguments-differ
        """Encode the inputs given the states and valid sequence length.

        Parameters
        ----------
        inputs : NDArray
            Input sequence. Shape (batch_size, length, C_in)
        states : list of NDArrays or None
            Initial states. The list of initial states and masks
        valid_length : NDArray or None
            Valid lengths of each sequence. This is usually used when part of sequence has
            been padded. Shape (batch_size,)

        Returns
        -------
        encoder_outputs: list
            Outputs of the encoder. Contains:
            - outputs of the transformer encoder. Shape (batch_size, length, C_out)
            - additional_outputs of all the transformer encoder
        """
        return super(TransformerEncoder, self).__call__(inputs, states, valid_length)

    def forward(self, inputs, states=None, valid_length=None, steps=None):  # pylint: disable=arguments-differ
        """Prepare the mask and step indices, then run the hybrid forward pass.

        The inputs are scaled by ``sqrt(C_in)``. The padding mask (only when
        ``valid_length`` is given) and the step indices are appended, in that
        order, to a copy of ``states`` before delegating to the parent
        ``forward``. The caller's ``states`` list is left untouched.

        Parameters
        ----------
        inputs : NDArray, Shape(batch_size, length, C_in)
        states : list of NDArray or None
        valid_length : NDArray or None
        steps : NDArray or None
            Ignored; the step indices ``[0, 1, ..., length - 1]`` used for the
            position encoding lookup are always recomputed from ``inputs``.

        Returns
        -------
        outputs : NDArray
            The output of the encoder. Shape is (batch_size, length, C_out)
        additional_outputs : list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, length) or
            (batch_size, num_heads, length, length)
        """
        length = inputs.shape[1]
        # Copy so that appending the mask / steps below never mutates the
        # list object owned by the caller.
        states = [] if states is None else list(states)
        if valid_length is not None:
            # mask[b, j] is 1 where position j is inside sequence b.
            mask = mx.nd.broadcast_lesser(
                mx.nd.arange(length, ctx=valid_length.context).reshape((1, -1)),
                valid_length.reshape((-1, 1)))
            # Repeat the key mask for every query position.
            mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=1), axis=1, size=length)
            states.append(mask)
        inputs = inputs * math.sqrt(inputs.shape[-1])
        steps = mx.nd.arange(length, ctx=inputs.context)
        states.append(steps)
        if valid_length is not None:
            step_output, additional_outputs =\
                super(TransformerEncoder, self).forward(inputs, states, valid_length)
        else:
            step_output, additional_outputs =\
                super(TransformerEncoder, self).forward(inputs, states)
        return step_output, additional_outputs

    def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_weight=None):  # pylint: disable=arguments-differ
        """Add position encodings and run the stack of encoder cells.

        Parameters
        ----------
        inputs : NDArray or Symbol, Shape(batch_size, length, C_in)
        states : list of NDArray or Symbol
            The last entry holds the step indices; when ``valid_length`` is
            given, the second to last entry holds the attention mask.
        valid_length : NDArray or Symbol
        position_weight : NDArray or Symbol
            The position encoding table.

        Returns
        -------
        outputs : NDArray or Symbol
            The output of the encoder. Shape is (batch_size, length, C_out)
        additional_outputs : list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, length) or
            (batch_size, num_heads, length, length)
        """
        # ``forward`` always appends the step indices as the last state.
        if states is not None:
            steps = states[-1]
            # Positional Encoding
            inputs = F.broadcast_add(inputs, F.expand_dims(F.Embedding(steps, position_weight,
                                                                       self._max_length,
                                                                       self._units), axis=0))
        inputs = self.dropout_layer(inputs)
        inputs = self.layer_norm(inputs)
        outputs = inputs
        if valid_length is not None:
            mask = states[-2]
        else:
            mask = None
        additional_outputs = []
        for cell in self.transformer_cells:
            outputs, attention_weights = cell(inputs, mask)
            inputs = outputs
            if self._output_attention:
                additional_outputs.append(attention_weights)
        if valid_length is not None:
            # Zero out the outputs at padded positions.
            outputs = F.SequenceMask(outputs, sequence_length=valid_length,
                                     use_sequence_length=True, axis=1)
        return outputs, additional_outputs
class TransformerDecoder(HybridBlock, Seq2SeqDecoder):
    """Stack of Transformer decoder cells with a sinusoid position encoding.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    num_layers : int
        Number of stacked decoder cells.
    units : int
        Must be divisible by ``num_heads``.
    hidden_size : int
        Number of units in the hidden layer of position-wise feed-forward networks
    max_length : int
        Maximum length of the input sequence. This is used for constructing position encoding
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', num_layers=2,
                 units=128, hidden_size=2048, max_length=50,
                 num_heads=4, scaled=True, dropout=0.0,
                 use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerDecoder, self).__init__(prefix=prefix, params=params)
        assert units % num_heads == 0, 'In TransformerDecoder, the units should be divided ' \
                                       'exactly by the number of heads. Received units={}, ' \
                                       'num_heads={}'.format(units, num_heads)
        self._num_layers = num_layers
        self._units = units
        self._hidden_size = hidden_size
        self._num_states = num_heads
        self._max_length = max_length
        self._dropout = dropout
        self._use_residual = use_residual
        self._output_attention = output_attention
        self._scaled = scaled
        # Read by ``forward`` during step-wise decoding; defined here so that
        # a step call made before ``init_state_from_encoder`` does not fail
        # with AttributeError.
        self._encoder_valid_length = None
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm()
            # Fixed (non-trainable) position encoding table.
            self.position_weight = self.params.get_constant('const',
                                                            _position_encoding_init(max_length,
                                                                                    units))
            self.transformer_cells = nn.HybridSequential()
            for i in range(num_layers):
                self.transformer_cells.add(
                    TransformerDecoderCell(
                        units=units,
                        hidden_size=hidden_size,
                        num_heads=num_heads,
                        attention_cell=attention_cell,
                        weight_initializer=weight_initializer,
                        bias_initializer=bias_initializer,
                        dropout=dropout,
                        scaled=scaled,
                        use_residual=use_residual,
                        output_attention=output_attention,
                        prefix='transformer%d_' % i))

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        """Initialize the state from the encoder outputs.

        Also remembers ``encoder_valid_length`` on the instance; ``forward``
        uses it to interpret the state list during step-wise decoding.

        Parameters
        ----------
        encoder_outputs : NDArray
            Used directly as the memory value; its ``shape[1]`` is taken as
            the memory length.
        encoder_valid_length : NDArray or None

        Returns
        -------
        decoder_states : list
            The decoder states, includes:
            - mem_value : NDArray
            - mem_masks : NDArray, optional
        """
        mem_value = encoder_outputs
        decoder_states = [mem_value]
        mem_length = mem_value.shape[1]
        if encoder_valid_length is not None:
            # mem_masks[b, j] is 1 where memory position j is valid for b.
            mem_masks = mx.nd.broadcast_lesser(
                mx.nd.arange(mem_length, ctx=encoder_valid_length.context).reshape((1, -1)),
                encoder_valid_length.reshape((-1, 1)))
            decoder_states.append(mem_masks)
        self._encoder_valid_length = encoder_valid_length
        return decoder_states

    def decode_seq(self, inputs, states, valid_length=None):
        """Decode the decoder inputs. This function is only used for training.

        Parameters
        ----------
        inputs : NDArray, Shape (batch_size, length, C_in)
        states : list of NDArrays or None
            Initial states. The list of decoder states
        valid_length : NDArray or None
            Valid lengths of each sequence. This is usually used when part of sequence has
            been padded. Shape (batch_size,)

        Returns
        -------
        output : NDArray, Shape (batch_size, length, C_out)
        states : list
            The decoder states, includes:
            - mem_value : NDArray
            - mem_masks : NDArray, optional
        additional_outputs : list of list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, mem_length) or
            (batch_size, num_heads, length, mem_length)
        """
        batch_size = inputs.shape[0]
        length = inputs.shape[1]
        length_array = mx.nd.arange(length, ctx=inputs.context)
        # Causal mask: position i may attend to positions j <= i.
        mask = mx.nd.broadcast_lesser_equal(
            length_array.reshape((1, -1)),
            length_array.reshape((-1, 1)))
        if valid_length is not None:
            batch_mask = mx.nd.broadcast_lesser(
                mx.nd.arange(length, ctx=valid_length.context).reshape((1, -1)),
                valid_length.reshape((-1, 1)))
            # Combine the per-sequence validity with the causal mask.
            mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1),
                                       mx.nd.expand_dims(mask, 0))
        else:
            mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size)
        # A leading None tells ``forward`` there are no cached embeddings.
        states = [None] + states
        output, states, additional_outputs = self.forward(inputs, states, mask)
        # Drop the embeddings that ``forward`` put in front of the states.
        states = states[1:]
        if valid_length is not None:
            output = mx.nd.SequenceMask(output,
                                        sequence_length=valid_length,
                                        use_sequence_length=True,
                                        axis=1)
        return output, states, additional_outputs

    def __call__(self, step_input, states):  #pylint: disable=arguments-differ
        """One-step-ahead decoding of the Transformer decoder.

        Parameters
        ----------
        step_input : NDArray
        states : list of NDArray

        Returns
        -------
        step_output : NDArray
            The output of the decoder.
            In the train mode, Shape is (batch_size, length, C_out)
            In the test mode, Shape is (batch_size, C_out)
        new_states: list
            Includes
            - last_embeds : NDArray or None
                It is only given during testing
            - mem_value : NDArray
            - mem_masks : NDArray, optional
        step_additional_outputs : list of list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, mem_length) or
            (batch_size, num_heads, length, mem_length)
        """
        return super(TransformerDecoder, self).__call__(step_input, states)

    def forward(self, step_input, states, mask=None):  #pylint: disable=arguments-differ, missing-docstring
        """Normalize the inputs and states, then run the hybrid forward pass.

        A 2-D ``step_input`` is treated as a single decoding step: it is
        appended to the cached embeddings found in ``states[0]`` (if any) and
        only the output of the last position is returned. A 3-D input is
        decoded as a whole sequence. The caller's ``states`` list is never
        modified.

        Parameters
        ----------
        step_input : NDArray
            Shape (batch_size, C_in) or (batch_size, length, C_in)
        states : list
            ``[mem_value]`` or ``[mem_value, mem_mask]``, optionally preceded
            by the cached embeddings (or ``None``).
        mask : NDArray or None
            Self-attention mask; a causal mask is built when it is None.

        Returns
        -------
        step_output : NDArray
        new_states : list
            The embeddings fed to the cells, followed by the memory states.
        step_additional_outputs : list
        """
        input_shape = step_input.shape
        mem_mask = None
        # Work on a private copy: the code below overwrites ``states[-1]``
        # and appends to the list, which must not leak back to the caller.
        states = list(states)
        # If it is in testing, transform input tensor to a tensor with shape NTC
        # Otherwise remove the None in states.
        if len(input_shape) == 2:
            if self._encoder_valid_length is not None:
                has_last_embeds = len(states) == 3
            else:
                has_last_embeds = len(states) == 2
            if has_last_embeds:
                last_embeds = states[0]
                step_input = mx.nd.concat(last_embeds,
                                          mx.nd.expand_dims(step_input, axis=1),
                                          dim=1)
                states = states[1:]
            else:
                step_input = mx.nd.expand_dims(step_input, axis=1)
        elif states[0] is None:
            states = states[1:]
        has_mem_mask = (len(states) == 2)
        if has_mem_mask:
            _, mem_mask = states
            # Repeat the memory mask for every query position.
            augmented_mem_mask = mx.nd.expand_dims(mem_mask, axis=1)\
                .broadcast_axes(axis=1, size=step_input.shape[1])
            states[-1] = augmented_mem_mask
        if mask is None:
            length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
            mask = mx.nd.broadcast_lesser_equal(
                length_array.reshape((1, -1)),
                length_array.reshape((-1, 1)))
            mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
                                        axis=0, size=step_input.shape[0])
        steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
        states.append(steps)
        step_output, step_additional_outputs = \
            super(TransformerDecoder, self).forward(step_input * math.sqrt(step_input.shape[-1]),  #pylint: disable=too-many-function-args
                                                    states, mask)
        # Remove the step indices and restore the un-broadcast memory mask
        # before handing the states back.
        states = states[:-1]
        if has_mem_mask:
            states[-1] = mem_mask
        new_states = [step_input] + states
        # If it is in testing, only output the last one
        if len(input_shape) == 2:
            step_output = step_output[:, -1, :]
        return step_output, new_states, step_additional_outputs

    def hybrid_forward(self, F, step_input, states, mask=None, position_weight=None):  #pylint: disable=arguments-differ
        """Add position encodings and run the stack of decoder cells.

        Parameters
        ----------
        step_input : NDArray or Symbol, Shape (batch_size, length, C_in)
        states : list of NDArray or Symbol
            ``[mem_value, steps]`` or ``[mem_value, mem_mask, steps]``.
        mask : NDArray or Symbol
        position_weight : NDArray or Symbol
            The position encoding table.

        Returns
        -------
        step_output : NDArray or Symbol
            The output of the decoder. Shape is (batch_size, length, C_out)
        step_additional_outputs : list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, mem_length) or
            (batch_size, num_heads, length, mem_length)
        """
        has_mem_mask = (len(states) == 3)
        if has_mem_mask:
            mem_value, mem_mask, steps = states
        else:
            mem_value, steps = states
            mem_mask = None
        # Positional Encoding
        step_input = F.broadcast_add(step_input,
                                     F.expand_dims(F.Embedding(steps,
                                                               position_weight,
                                                               self._max_length,
                                                               self._units),
                                                   axis=0))
        step_input = self.dropout_layer(step_input)
        step_input = self.layer_norm(step_input)
        inputs = step_input
        outputs = inputs
        step_additional_outputs = []
        for cell in self.transformer_cells:
            outputs, attention_weights = cell(inputs, mem_value, mask, mem_mask)
            if self._output_attention:
                step_additional_outputs.append(attention_weights)
            inputs = outputs
        return outputs, step_additional_outputs
def get_transformer_encoder_decoder(num_layers=2,
                                    num_heads=8, scaled=True,
                                    units=512, hidden_size=2048, dropout=0.0, use_residual=True,
                                    max_src_length=50, max_tgt_length=50,
                                    weight_initializer=None, bias_initializer='zeros',
                                    prefix='transformer_', params=None):
    """Build a Transformer encoder/decoder pair that share one configuration.

    Every argument except ``max_src_length``, ``max_tgt_length`` and ``prefix``
    is forwarded unchanged to both the ``TransformerEncoder`` and the
    ``TransformerDecoder`` constructor.

    Parameters
    ----------
    num_layers : int
    num_heads : int
    scaled : bool
    units : int
    hidden_size : int
    dropout : float
    use_residual : bool
    max_src_length : int
        Passed to the encoder as its ``max_length``.
    max_tgt_length : int
        Passed to the decoder as its ``max_length``.
    weight_initializer : mx.init.Initializer or None
    bias_initializer : mx.init.Initializer or None
    prefix : str, default 'transformer_'
        Prefix for name of `Block`s. The encoder receives ``prefix + 'enc_'``
        and the decoder ``prefix + 'dec_'``.
    params : Parameter or None
        Container for weight sharing between layers.
        Created if `None`. The same object is handed to both blocks.

    Returns
    -------
    encoder : TransformerEncoder
    decoder : TransformerDecoder
    """
    # Settings that are identical for the encoder and the decoder.
    shared_kwargs = {
        'num_layers': num_layers,
        'num_heads': num_heads,
        'units': units,
        'hidden_size': hidden_size,
        'dropout': dropout,
        'scaled': scaled,
        'use_residual': use_residual,
        'weight_initializer': weight_initializer,
        'bias_initializer': bias_initializer,
        'params': params,
    }
    encoder = TransformerEncoder(max_length=max_src_length,
                                 prefix=prefix + 'enc_',
                                 **shared_kwargs)
    decoder = TransformerDecoder(max_length=max_tgt_length,
                                 prefix=prefix + 'dec_',
                                 **shared_kwargs)
    return encoder, decoder
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text Denoising\n",
"\n",
"Inspired by \"Neural Networks for Text Correction and Completion in Keyboard Decoding\" by Shaona Ghosh and Per Ola Kristensson. https://arxiv.org/pdf/1709.06429.pdf"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"import os\n",
"import random\n",
"import string"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from gluoncv.data.batchify import Tuple, Stack, Append\n",
"import mxnet as mx\n",
"from mxnet import gluon, autograd\n",
"from mxnet.gluon import HybridBlock\n",
"from mxnet.gluon.loss import SoftmaxCELoss\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"gpu(0)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()\n",
"ctx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.isdir('dataset'):\n",
" os.makedirs('dataset')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'dataset/typo/alicewonder.txt'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"typo_filepath = 'dataset/typo/typo-corpus-r1.txt'\n",
"text_filepath = 'dataset/typo/alicewonder.txt'\n",
"mx.test_utils.download('http://luululu.com/tweet/typo-corpus-r1.txt', dirname='dataset/typo')\n",
"mx.test_utils.download('http://textfiles.com/etext/FICTION/alicewonder.txt', dirname='dataset/typo')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"ALPHABET = ['<PAD>', '<BOS>', '<EOS>']+list(' ' + string.ascii_letters + string.digits + string.punctuation)\n",
"ALPHABET_INDEX = {letter: index for index, letter in enumerate(ALPHABET)} # { a: 0, b: 1, etc}\n",
"FEATURE_LEN = 150 # max-length in characters for one document\n",
"NUM_WORKERS = 8 # number of workers used in the data loading\n",
"BATCH_SIZE = 128 # number of documents per batch\n",
"PAD = 0\n",
"BOS = 1\n",
"EOS = 2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class NoisyTextDataset(mx.gluon.data.Dataset):\n",
" def __init__(self, text_filepath, typo_filepath, replace_proba=0.2, is_train=True):\n",
" self.replace_proba = replace_proba\n",
" self.typo_dict = self._process_typo(typo_filepath)\n",
" self.text = self._process_text(text_filepath, is_train)\n",
" \n",
" def _process_text(self, filename, is_train):\n",
" with open(text_filepath, 'r') as f:\n",
" text = []\n",
" for line in f.readlines():\n",
" line = line.replace('\\n', '').strip()\n",
" if line != '':\n",
" text.append(line)\n",
" \n",
" split_index = int(0.8*len(text))\n",
" if is_train:\n",
" text = text[:split_index]\n",
" else:\n",
" text = text[split_index:]\n",
" return text\n",
"\n",
" \n",
" def _process_typo(self, filename):\n",
" \"\"\"\n",
" This function loads the typo dataset and generate the \n",
" probability distribution of typos for each valid word\n",
" \"\"\"\n",
" typo_dict = defaultdict(lambda : defaultdict(float))\n",
" with open(filename, 'r') as f:\n",
" lines = f.readlines()\n",
" for line in lines:\n",
" typo, correct = line.split('\\t')[0:2]\n",
" typo_dict[correct][typo] += 1\n",
" for _, correct_word in typo_dict.items():\n",
" total = 0\n",
" for _, count in correct_word.items():\n",
" total += count\n",
" previous_value = 0.\n",
" for wrong_word in correct_word:\n",
" correct_word[wrong_word] = correct_word[wrong_word] / total + previous_value\n",
" previous_value = correct_word[wrong_word]\n",
" return typo_dict\n",
" \n",
" def _transform_line(self, line):\n",
" \"\"\"\n",
" replace words that are in the typo dataset with a typo\n",
" with a probability `self.replace_proba`\n",
" \"\"\"\n",
" output = []\n",
" proba_replace = 0.2\n",
" for word in self._pre_process_line(line):\n",
" if word.lower() in self.typo_dict and word != '':\n",
" if random.random() < self.replace_proba:\n",
" draw = random.random()\n",
" for typo, value in self.typo_dict[word].items():\n",
" if draw < value:\n",
" word = self._match_caps(word, typo)\n",
" break\n",
" output.append(word)\n",
" else:\n",
" output.append(word)\n",
"\n",
" return self._post_process_line(output)\n",
" \n",
" def _pre_process_line(self, line):\n",
" line = line.replace('\\n', '')\n",
" for char in string.punctuation:\n",
" if char in line:\n",
" line = line.replace(char, ' '+char+' ')\n",
" return line.split(' ')\n",
" \n",
" def _post_process_line(self, words):\n",
" output = ' '.join(words)\n",
" for char in string.punctuation:\n",
" output = output.replace(' '+char+' ', char)\n",
" return output\n",
" \n",
" def _match_caps(self, original, typo):\n",
" if original.isupper():\n",
" return typo.upper()\n",
" elif original.istitle():\n",
" return typo.capitalize()\n",
" else:\n",
" return typo\n",
" \n",
" def __getitem__(self, idx):\n",
" line = self.text[idx]\n",
" line_typo = self._transform_line(line)\n",
" return line_typo, line\n",
"\n",
" def __len__(self):\n",
" return len(self.text)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def encode(text, src=True):\n",
" encoded = np.ones(FEATURE_LEN, dtype='float32') * PAD\n",
" text = text[:FEATURE_LEN-2]\n",
" i = 0\n",
" if not src:\n",
" encoded[0] = BOS\n",
" i = 1\n",
" for letter in text:\n",
" if letter in ALPHABET_INDEX:\n",
" encoded[i] = ALPHABET_INDEX[letter]\n",
" i += 1\n",
" encoded[i] = EOS\n",
" return encoded, np.array([len(text)+2]).astype('float32')\n",
"\n",
"def transform(data, label):\n",
" src, src_valid_length = encode(data, src=True)\n",
" tgt, tgt_valid_length = encode(label, src=False)\n",
" return src, src_valid_length, tgt, tgt_valid_length, data, label"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"dataset_train = NoisyTextDataset(text_filepath=text_filepath, typo_filepath=typo_filepath, replace_proba=0.4, is_train=True).transform(transform)\n",
"dataset_test = NoisyTextDataset(text_filepath=text_filepath, typo_filepath=typo_filepath, replace_proba=0.4, is_train=False).transform(transform)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([16., 11., 3., 7., 8., 4., 21., 77., 72., 3., 22., 4., 12.,\n",
" 7., 7., 3., 30., 15., 12., 6., 8., 77., 3., 22., 8., 21.,\n",
" 12., 18., 22., 24., 15., 28., 77., 3., 93., 38., 72., 15., 15.,\n",
" 3., 11., 4., 25., 8., 3., 17., 18., 23., 11., 12., 10., 17.,\n",
" 3., 16., 18., 21., 8., 3., 23., 18., 3., 7., 18., 12., 2.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" array([66.], dtype=float32),\n",
" array([ 1., 16., 28., 3., 7., 8., 4., 21., 77., 72., 3., 22., 4.,\n",
" 12., 7., 3., 30., 15., 12., 6., 8., 77., 3., 22., 8., 21.,\n",
" 12., 18., 24., 22., 15., 28., 77., 3., 93., 38., 72., 15., 15.,\n",
" 3., 11., 4., 25., 8., 3., 17., 18., 23., 11., 12., 17., 10.,\n",
" 3., 16., 18., 21., 8., 3., 23., 18., 3., 7., 18., 2., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" array([64.], dtype=float32),\n",
" \"mh dear,' saidd Alice, seriosuly, `I'll have nothign more to doi\",\n",
" \"my dear,' said Alice, seriously, `I'll have nothing more to do\")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_train[random.randint(0, len(dataset_train)-1)]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"def decode(text):\n",
" output = []\n",
" for val in text:\n",
" output.append(ALPHABET[int(val)])\n",
" if val == EOS:\n",
" break\n",
" return ''.join(output)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def batchify_list(elem):\n",
" output = []\n",
" for e in elem:\n",
" output.append(elem)\n",
" return output\n",
" \n",
"batchify = Tuple(Stack(), Stack(), Stack(), Stack(), batchify_list, batchify_list)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"train_data = gluon.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, last_batch='rollover', batchify_fn=batchify)\n",
"test_data = gluon.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True, last_batch='rollover', batchify_fn=batchify)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from utils.encoder_decoder import get_transformer_encoder_decoder"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"class SoftmaxCEMaskedLoss(SoftmaxCELoss):\n",
" \"\"\"Wrapper of the SoftmaxCELoss that supports valid_length as the input\n",
" \"\"\"\n",
" def hybrid_forward(self, F, pred, label, valid_length): # pylint: disable=arguments-differ\n",
" \"\"\"\n",
" Parameters\n",
" ----------\n",
" F\n",
" pred : Symbol or NDArray\n",
" Shape (batch_size, length, V)\n",
" label : Symbol or NDArray\n",
" Shape (batch_size, length)\n",
" valid_length : Symbol or NDArray\n",
" Shape (batch_size, )\n",
" Returns\n",
" -------\n",
" loss : Symbol or NDArray\n",
" Shape (batch_size,)\n",
" \"\"\"\n",
" if self._sparse_label:\n",
" sample_weight = F.cast(F.expand_dims(F.ones_like(label), axis=-1), dtype=np.float32)\n",
" else:\n",
" sample_weight = F.ones_like(label)\n",
" sample_weight = F.SequenceMask(sample_weight,\n",
" sequence_length=valid_length,\n",
" use_sequence_length=True,\n",
" axis=1)\n",
" return super(SoftmaxCEMaskedLoss, self).hybrid_forward(F, pred, label, sample_weight)\n",
"\n",
"# pylint: disable=unused-argument\n",
"class _SmoothingWithDim(mx.operator.CustomOp):\n",
" def __init__(self, epsilon=0.1, axis=-1):\n",
" super(_SmoothingWithDim, self).__init__(True)\n",
" self._epsilon = epsilon\n",
" self._axis = axis\n",
"\n",
" def forward(self, is_train, req, in_data, out_data, aux):\n",
" inputs = in_data[0]\n",
" outputs = ((1 - self._epsilon) * inputs) + (self._epsilon / float(inputs.shape[self._axis]))\n",
" self.assign(out_data[0], req[0], outputs)\n",
"\n",
" def backward(self, req, out_grad, in_data, out_data, in_grad, aux):\n",
" self.assign(in_grad[0], req[0], (1 - self._epsilon) * out_grad[0])\n",
"\n",
"\n",
"@mx.operator.register('_smoothing_with_dim')\n",
"class _SmoothingWithDimProp(mx.operator.CustomOpProp):\n",
" def __init__(self, epsilon=0.1, axis=-1):\n",
" super(_SmoothingWithDimProp, self).__init__(True)\n",
" self._epsilon = float(epsilon)\n",
" self._axis = int(axis)\n",
"\n",
" def list_arguments(self):\n",
" return ['data']\n",
"\n",
" def list_outputs(self):\n",
" return ['output']\n",
"\n",
" def infer_shape(self, in_shape):\n",
" data_shape = in_shape[0]\n",
" output_shape = data_shape\n",
" return (data_shape,), (output_shape,), ()\n",
"\n",
" def declare_backward_dependency(self, out_grad, in_data, out_data):\n",
" return out_grad\n",
"\n",
" def create_operator(self, ctx, in_shapes, in_dtypes):\n",
" # create and return the CustomOp class.\n",
" return _SmoothingWithDim(self._epsilon, self._axis)\n",
"# pylint: enable=unused-argument\n",
"\n",
"\n",
"class LabelSmoothing(HybridBlock):\n",
" \"\"\"Applies label smoothing. See https://arxiv.org/abs/1512.00567.\n",
" Parameters\n",
" ----------\n",
" axis : int, default -1\n",
" The axis to smooth.\n",
" epsilon : float, default 0.1\n",
" The epsilon parameter in label smoothing\n",
" sparse_label : bool, default True\n",
" Whether input is an integer array instead of one hot array.\n",
" units : int or None\n",
" Vocabulary size. If units is not given, it will be inferred from the input.\n",
" prefix : str, default 'rnn_'\n",
" Prefix for name of `Block`s\n",
" (and name of weight if params is `None`).\n",
" params : Parameter or None\n",
" Container for weight sharing between cells.\n",
" Created if `None`.\n",
" \"\"\"\n",
" def __init__(self, axis=-1, epsilon=0.1, units=None,\n",
" sparse_label=True, prefix=None, params=None):\n",
" super(LabelSmoothing, self).__init__(prefix=prefix, params=params)\n",
" self._axis = axis\n",
" self._epsilon = epsilon\n",
" self._sparse_label = sparse_label\n",
" self._units = units\n",
"\n",
" def hybrid_forward(self, F, inputs, units=None): # pylint: disable=arguments-differ\n",
" \"\"\"\n",
" Parameters\n",
" ----------\n",
" F\n",
" inputs : Symbol or NDArray\n",
" Shape (batch_size, length) or (batch_size, length, V)\n",
" units : int or None\n",
" Returns\n",
" -------\n",
" smoothed_label : Symbol or NDArray\n",
" Shape (batch_size, length, V)\n",
" \"\"\"\n",
" if self._sparse_label:\n",
" assert units is not None or self._units is not None, \\\n",
" 'units needs to be given in function call or ' \\\n",
" 'instance initialization when sparse_label is False'\n",
" if units is None:\n",
" units = self._units\n",
" inputs = F.one_hot(inputs, depth=units)\n",
" if units is None and self._units is None:\n",
" return F.Custom(inputs, epsilon=self._epsilon, axis=self._axis,\n",
" op_type='_smoothing_with_dim')\n",
" else:\n",
" if units is None:\n",
" units = self._units\n",
" return ((1 - self._epsilon) * inputs) + (self._epsilon / units)"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [],
"source": [
"class Denoiser(gluon.nn.HybridBlock):\n",
" def __init__(self, alphabet_size, embed_size=64, max_src_length=FEATURE_LEN, max_tgt_length=FEATURE_LEN):\n",
" super(Denoiser, self).__init__()\n",
" encoder, decoder = get_transformer_encoder_decoder(max_src_length=max_src_length, max_tgt_length=max_tgt_length, units=embed_size, num_heads=8, num_layers=2)\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.src_embedding = gluon.nn.HybridSequential()\n",
" with self.src_embedding.name_scope():\n",
" self.src_embedding.add(gluon.nn.Embedding(input_dim=alphabet_size, output_dim=embed_size),\n",
" gluon.nn.Dropout(rate=0.))\n",
" self.tgt_embedding = gluon.nn.HybridSequential()\n",
" with self.tgt_embedding.name_scope():\n",
" self.tgt_embedding.add(gluon.nn.Embedding(input_dim=alphabet_size, output_dim=embed_size),\n",
" gluon.nn.Dropout(rate=0.))\n",
" self.tgt_proj = gluon.nn.Dense(units=alphabet_size, flatten=False,\n",
" prefix='tgt_proj_')\n",
" \n",
" def hybrid_forward(self, F, src_seq, tgt_seq, src_valid_length, tgt_valid_length):\n",
" encoder_outputs, encoder_additional_outputs = self.encode(src_seq, valid_length=src_valid_length)\n",
" decoder_states = self.decoder.init_state_from_encoder(encoder_outputs, encoder_valid_length=src_valid_length)\n",
" tgt_embed = self.tgt_embedding(tgt_seq)\n",
" outputs, _, _ = self.decoder.decode_seq(tgt_embed, decoder_states, tgt_valid_length)\n",
" outputs = self.tgt_proj(outputs)\n",
" return outputs\n",
" \n",
" def encode(self, inputs, states=None, valid_length=None):\n",
" return self.encoder(self.src_embedding(inputs), states, valid_length)\n",
" \n",
" def decode_step(self, step_input, states):\n",
" step_output, states, step_additional_outputs = self.decoder(self.tgt_embedding(step_input), states)\n",
" step_output = self.tgt_proj(step_output)\n",
" return step_output, states, step_additional_outputs\n",
" \n",
" def decode_logprob(self, step_input, states):\n",
" out, states, _ = self.decode_step(step_input, states)\n",
" return out, states"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [],
"source": [
"denoiser = Denoiser(alphabet_size=len(ALPHABET))\n",
"#denoiser.load_parameters('model_checkpoint/denoiser.params', ctx=ctx)"
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [],
"source": [
"denoiser.initialize(ctx=ctx)"
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {},
"outputs": [],
"source": [
"label_smoothing = LabelSmoothing(epsilon=0.1, units=len(ALPHABET))"
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [],
"source": [
"loss_function = SoftmaxCEMaskedLoss(sparse_label=False)"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {},
"outputs": [],
"source": [
"trainer = gluon.Trainer(denoiser.collect_params(), 'adam', {'learning_rate':0.001})"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [0], Loss 1.0904\n",
"[Typo ] inveted yet.'<EOS>\n",
"[Predicted] <BOS>insited yeta'<EOS>\n",
"[Correct ] <BOS>invited yet.'<EOS>\n",
"\n",
"Epoch [1], Loss 0.6172\n",
"[Typo ] whut vt miqht apear to others that what yu were or might have<EOS>\n",
"[Predicted] <BOS>what it might appear to others that what you were or might have<EOS>\n",
"[Correct ] <BOS>what it might appear to others that what you were or might have<EOS>\n",
"\n",
"Epoch [2], Loss 0.4213\n",
"[Typo ] `I'll fetch the executioner myself,' said tht King eagerly, and<EOS>\n",
"[Predicted] <BOS>`I'll fetch the executioner myself,' said the King eagerly, and<EOS>\n",
"[Correct ] <BOS>`I'll fetch the executioner myself,' said the King eagerly, and<EOS>\n",
"\n",
"Epoch [3], Loss 0.3381\n",
"[Typo ] fr instance, there's tha arch I've gt to go through next<EOS>\n",
"[Predicted] <BOS>for instance, there's the arch I've got to go through next<EOS>\n",
"[Correct ] <BOS>for instance, there's the arch I've got to go through next<EOS>\n",
"\n",
"Epoch [4], Loss 0.3070\n",
"[Typo ] `My uname is Alice, so please your Majesty,' said Alice very<EOS>\n",
"[Predicted] <BOS>`My name is Alice, so please your Majesty,' said Alice very<EOS>\n",
"[Correct ] <BOS>`My name is Alice, so please your Majesty,' said Alice very<EOS>\n",
"\n"
]
}
],
"source": [
"epochs = 5\n",
"for e in range(epochs):\n",
" loss = 0.\n",
" for i, (src, src_valid_length, tgt, tgt_valid_length, typo, label) in enumerate(train_data):\n",
" src = src.as_in_context(ctx)\n",
" tgt = tgt.as_in_context(ctx)\n",
" src_valid_length = src_valid_length.as_in_context(ctx).squeeze()\n",
" tgt_valid_length = tgt_valid_length.as_in_context(ctx).squeeze()\n",
" with autograd.record():\n",
" output = denoiser(src, tgt, src_valid_length, tgt_valid_length)\n",
" smoothed_label = label_smoothing(tgt)\n",
" ls = loss_function(output, smoothed_label, tgt_valid_length).mean()\n",
" ls.backward()\n",
" trainer.step(src.shape[0])\n",
" loss += ls.asscalar()\n",
" print(\"Epoch [{}], Loss {:.4f}\".format(e, loss/(i+1)))\n",
" print(\"[Typo ] {}\".format(decode(src[0].asnumpy())))\n",
" print(\"[Predicted] {}\".format(decode(output[0].asnumpy().argmax(axis=1))))\n",
" print(\"[Correct ] {}\".format(decode(tgt[0].asnumpy())))\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`My uname is Alice, so please your Majesty,' said Alice very\n"
]
}
],
"source": [
"src_seq = src[0:1]\n",
"src_valid_length = src_valid_length[0]\n",
"print(typo[0][0])"
]
},
{
"cell_type": "code",
"execution_count": 157,
"metadata": {},
"outputs": [],
"source": [
"encoder_outputs, _ = denoiser.encode(src_seq, valid_length=src_valid_length)\n",
"states = denoiser.decoder.init_state_from_encoder(encoder_outputs, encoder_valid_length=src_valid_length)\n",
"inputs = mx.nd.full(shape=(1,), ctx=src_seq.context, dtype=np.float32, val=BOS)\n",
"output = []"
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS>\n"
]
}
],
"source": [
"for i in range(0,20):\n",
" output.append(int(inputs.asscalar()))\n",
" next_input, states = denoiser.decode_logprob(inputs, states)\n",
" inputs = next_input.argmax(axis=1)\n",
"output.append(int(inputs.asscalar()))\n",
"print(''.join([ALPHABET[int(c)] for c in output]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Manual Testing"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {},
"outputs": [],
"source": [
" scorer = nlp.model.BeamSearchScorer(alpha=0, K=5, from_logits=False)"
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {},
"outputs": [],
"source": [
"eos_id = EOS\n",
"beam_sampler = nlp.model.BeamSearchSampler(beam_size=5,\n",
" decoder=denoiser.decode_logprob,\n",
" eos_id=eos_id,\n",
" scorer=scorer,\n",
" max_length=20)"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {},
"outputs": [],
"source": [
"def generate_sequences(sampler, inputs, begin_states, num_print_outcomes):\n",
" samples, scores, valid_lengths = sampler(inputs, begin_states)\n",
" samples = samples[0].asnumpy()\n",
" scores = scores[0].asnumpy()\n",
" valid_lengths = valid_lengths[0].asnumpy()\n",
" print('Generation Result:')\n",
" for i in range(num_print_outcomes):\n",
" print(decode(samples[i][:valid_lengths[i]]), scores[i])"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation Result:\n",
"<BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><EOS> -5.713097\n",
"<BOS><BOS><BOS>eeeeeeeeeeeeeeeeee<EOS> -7.2956805\n",
"<BOS><BOS><BOS><BOS>eeeeeeeeeeeeeeeee<EOS> -7.3000755\n",
"<BOS><BOS><BOS><BOS><BOS>eeeeeeeeeeeeeeee<EOS> -7.317902\n",
"<BOS><BOS>eeeeeeeeeeeeeeeeeee<EOS> -7.362971\n"
]
}
],
"source": [
"states = denoiser.decoder.init_state_from_encoder(encoder_outputs, encoder_valid_length=src_valid_length)\n",
"inputs = mx.nd.full(shape=(1,), ctx=src_seq.context, dtype=np.float32, val=BOS)\n",
"generate_sequences(beam_sampler, inputs, states, 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment