Last active
November 22, 2018 02:55
-
-
Save ThomasDelteil/4274c1946e958615df915d5d30908c19 to your computer and use it in GitHub Desktop.
TextDenoising
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
"""Encoder and decoder used in sequence-to-sequence learning."""
__all__ = ['TransformerEncoder', 'TransformerDecoder', | |
'get_transformer_encoder_decoder'] | |
import math | |
import os | |
import warnings | |
from functools import partial | |
from gluonnlp.model import AttentionCell, MLPAttentionCell, DotProductAttentionCell, MultiHeadAttentionCell | |
import gluonnlp as nlp | |
import mxnet as mx | |
from mxnet.gluon import rnn | |
from mxnet.gluon.block import Block | |
from mxnet import cpu, gluon | |
from mxnet.gluon import nn | |
from mxnet.gluon.block import HybridBlock | |
import numpy as np | |
def _list_bcast_where(F, mask, new_val_l, old_val_l): | |
"""Broadcast where. Implements out[i] = new_val[i] * mask + old_val[i] * (1 - mask) | |
Parameters | |
---------- | |
F : symbol or ndarray | |
mask : Symbol or NDArray | |
new_val_l : list of Symbols or list of NDArrays | |
old_val_l : list of Symbols or list of NDArrays | |
Returns | |
------- | |
out_l : list of Symbols or list of NDArrays | |
""" | |
return [F.broadcast_mul(new_val, mask) + F.broadcast_mul(old_val, 1 - mask) | |
for new_val, old_val in zip(new_val_l, old_val_l)] | |
def _get_cell_type(cell_type): | |
"""Get the object type of the cell by parsing the input | |
Parameters | |
---------- | |
cell_type : str or type | |
Returns | |
------- | |
cell_constructor: type | |
The constructor of the RNNCell | |
""" | |
if isinstance(cell_type, str): | |
if cell_type == 'lstm': | |
return rnn.LSTMCell | |
elif cell_type == 'gru': | |
return rnn.GRUCell | |
elif cell_type == 'relu_rnn': | |
return partial(rnn.RNNCell, activation='relu') | |
elif cell_type == 'tanh_rnn': | |
return partial(rnn.RNNCell, activation='tanh') | |
else: | |
raise NotImplementedError | |
else: | |
return cell_type | |
def _get_attention_cell(attention_cell, units=None, | |
scaled=True, num_heads=None, | |
use_bias=False, dropout=0.0): | |
""" | |
Parameters | |
---------- | |
attention_cell : AttentionCell or str | |
units : int or None | |
Returns | |
------- | |
attention_cell : AttentionCell | |
""" | |
if isinstance(attention_cell, str): | |
if attention_cell == 'scaled_luong': | |
return DotProductAttentionCell(units=units, scaled=True, normalized=False, | |
use_bias=use_bias, dropout=dropout, luong_style=True) | |
elif attention_cell == 'scaled_dot': | |
return DotProductAttentionCell(units=units, scaled=True, normalized=False, | |
use_bias=use_bias, dropout=dropout, luong_style=False) | |
elif attention_cell == 'dot': | |
return DotProductAttentionCell(units=units, scaled=False, normalized=False, | |
use_bias=use_bias, dropout=dropout, luong_style=False) | |
elif attention_cell == 'cosine': | |
return DotProductAttentionCell(units=units, scaled=False, use_bias=use_bias, | |
dropout=dropout, normalized=True) | |
elif attention_cell == 'mlp': | |
return MLPAttentionCell(units=units, normalized=False) | |
elif attention_cell == 'normed_mlp': | |
return MLPAttentionCell(units=units, normalized=True) | |
elif attention_cell == 'multi_head': | |
base_cell = DotProductAttentionCell(scaled=scaled, dropout=dropout) | |
return MultiHeadAttentionCell(base_cell=base_cell, query_units=units, use_bias=use_bias, | |
key_units=units, value_units=units, num_heads=num_heads) | |
else: | |
raise NotImplementedError | |
else: | |
assert isinstance(attention_cell, AttentionCell),\ | |
'attention_cell must be either string or AttentionCell. Received attention_cell={}'\ | |
.format(attention_cell) | |
return attention_cell | |
def _nested_sequence_last(data, valid_length): | |
""" | |
Parameters | |
---------- | |
data : nested container of NDArrays/Symbols | |
The input data. Each element will have shape (batch_size, ...) | |
valid_length : NDArray or Symbol | |
Valid length of the sequences. Shape (batch_size,) | |
Returns | |
------- | |
data_last: nested container of NDArrays/Symbols | |
The last valid element in the sequence. | |
""" | |
assert isinstance(data, list) | |
if isinstance(data[0], (mx.sym.Symbol, mx.nd.NDArray)): | |
F = mx.sym if isinstance(data[0], mx.sym.Symbol) else mx.ndarray | |
return F.SequenceLast(F.stack(*data, axis=0), | |
sequence_length=valid_length, | |
use_sequence_length=True) | |
elif isinstance(data[0], list): | |
ret = [] | |
for i in range(len(data[0])): | |
ret.append(_nested_sequence_last([ele[i] for ele in data], valid_length)) | |
return ret | |
else: | |
raise NotImplementedError | |
class Seq2SeqEncoder(Block):
    r"""Abstract base for encoders used in sequence-to-sequence models.

    Subclasses must override :meth:`forward`.
    """

    def __call__(self, inputs, valid_length=None, states=None):  #pylint: disable=arguments-differ
        """Encode the input sequence.

        Parameters
        ----------
        inputs : NDArray
            The input sequence, Shape (batch_size, length, C_in).
        valid_length : NDArray or None, default None
            The valid length of the input sequence, Shape (batch_size,). This is used when the
            input sequences are padded. If set to None, all elements in the sequence are used.
        states : list of NDArrays or None, default None
            List that contains the initial states of the encoder.

        Returns
        -------
        outputs : list
            Outputs of the encoder.
        """
        parent = super(Seq2SeqEncoder, self)
        return parent.__call__(inputs, valid_length, states)

    def forward(self, inputs, valid_length=None, states=None):  #pylint: disable=arguments-differ
        """Abstract hook; always raises NotImplementedError here."""
        raise NotImplementedError
class Seq2SeqDecoder(Block):
    r"""Abstract base for decoders used in sequence-to-sequence models.

    Calling the block performs one-step-ahead decoding. Every method except
    :meth:`__call__` raises NotImplementedError and must be overridden.
    """

    def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
        r"""Build the initial decoder states from the encoder outputs.

        Parameters
        ----------
        encoder_outputs : list of NDArrays
        encoder_valid_length : NDArray or None

        Returns
        -------
        decoder_states : list
        """
        raise NotImplementedError

    def decode_seq(self, inputs, states, valid_length=None):
        r"""Decode a whole sequence given the encoder context.

        This is usually used in the training phase where the inputs are set
        to be the target sequence.

        Parameters
        ----------
        inputs : NDArray
            The input embeddings. Shape (batch_size, length, C_in)
        states : list
            The initial states of the decoder.
        valid_length : NDArray or None
            valid length of the inputs. Shape (batch_size,)

        Returns
        -------
        output : NDArray
            The output of the decoder. Shape is (batch_size, length, C_out)
        states: list
            The new states of the decoder
        additional_outputs : list
            Additional outputs of the decoder, e.g, the attention weights
        """
        raise NotImplementedError

    def __call__(self, step_input, states):  #pylint: disable=arguments-differ
        r"""Decode a single step.

        Parameters
        ----------
        step_input : NDArray
            Shape (batch_size, C_in)
        states : list
            The previous states of the decoder

        Returns
        -------
        step_output : NDArray
            Shape (batch_size, C_out)
        states : list
        step_additional_outputs : list
            Additional outputs of the step, e.g, the attention weights
        """
        parent = super(Seq2SeqDecoder, self)
        return parent.__call__(step_input, states)

    def forward(self, step_input, states):  #pylint: disable=arguments-differ
        """Abstract hook; always raises NotImplementedError here."""
        raise NotImplementedError
def _position_encoding_init(max_length, dim): | |
""" Init the sinusoid position encoding table """ | |
position_enc = np.arange(max_length).reshape((-1, 1)) \ | |
/ (np.power(10000, (2. / dim) * np.arange(dim).reshape((1, -1)))) | |
# Apply the cosine to even columns and sin to odds. | |
position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i | |
position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 | |
return position_enc | |
class PositionwiseFFN(HybridBlock):
    """Position-wise feed-forward network: two dense layers, dropout,
    an optional residual connection and a final layer normalization.

    Parameters
    ----------
    units : int
        Output size of the second dense layer.
    hidden_size : int
        number of units in the hidden layer of position-wise feed-forward networks
    dropout : float
    use_residual : bool
        Whether to add the block input to the output before normalization.
    weight_initializer : str or Initializer
        Initializer for the weights of both dense layers.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    activation : str, default 'relu'
        Activation function of the first dense layer.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, units=512, hidden_size=2048, dropout=0.0, use_residual=True,
                 weight_initializer=None, bias_initializer='zeros', activation='relu',
                 prefix=None, params=None):
        super(PositionwiseFFN, self).__init__(prefix=prefix, params=params)
        self._hidden_size = hidden_size
        self._units = units
        self._use_residual = use_residual
        # Both dense layers share the same initializers.
        shared_init = dict(weight_initializer=weight_initializer,
                           bias_initializer=bias_initializer)
        with self.name_scope():
            self.ffn_1 = nn.Dense(units=hidden_size, flatten=False,
                                  activation=activation, prefix='ffn_1_',
                                  **shared_init)
            self.ffn_2 = nn.Dense(units=units, flatten=False, prefix='ffn_2_',
                                  **shared_init)
            self.dropout_layer = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm()

    def hybrid_forward(self, F, inputs):  # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """Apply the feed-forward network to every position.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)

        Returns
        -------
        outputs : Symbol or NDArray
            Shape (batch_size, length, C_out)
        """
        hidden = self.ffn_1(inputs)
        projected = self.dropout_layer(self.ffn_2(hidden))
        if self._use_residual:
            projected = projected + inputs
        return self.layer_norm(projected)
class TransformerEncoderCell(HybridBlock):
    """One Transformer encoder layer: self-attention followed by a
    position-wise feed-forward network.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    units : int
    hidden_size : int
        number of units in the hidden layer of position-wise feed-forward networks
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', units=128,
                 hidden_size=512, num_heads=4, scaled=True,
                 dropout=0.0, use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerEncoderCell, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._num_heads = num_heads
        self._dropout = dropout
        self._use_residual = use_residual
        self._output_attention = output_attention
        # The projection and the feed-forward network share the same initializers.
        shared_init = dict(weight_initializer=weight_initializer,
                           bias_initializer=bias_initializer)
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.attention_cell = _get_attention_cell(attention_cell, units=units,
                                                      num_heads=num_heads, scaled=scaled,
                                                      dropout=dropout)
            self.proj = nn.Dense(units=units, flatten=False, use_bias=False,
                                 prefix='proj_', **shared_init)
            self.ffn = PositionwiseFFN(hidden_size=hidden_size, units=units,
                                       use_residual=use_residual, dropout=dropout,
                                       **shared_init)
            self.layer_norm = nn.LayerNorm()

    def hybrid_forward(self, F, inputs, mask=None):  # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """Run self-attention and the feed-forward network on the inputs.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)
        mask : Symbol or NDArray or None
            Mask for inputs. Shape (batch_size, length, length)

        Returns
        -------
        encoder_cell_outputs: list
            Outputs of the encoder cell. Contains:
            - outputs of the transformer encoder cell. Shape (batch_size, length, C_out)
            - additional_outputs of all the transformer encoder cell
        """
        # Self-attention: the inputs act as query, key and value.
        attended, attention_weights = self.attention_cell(inputs, inputs, inputs, mask)
        attended = self.dropout_layer(self.proj(attended))
        if self._use_residual:
            attended = attended + inputs
        normalized = self.layer_norm(attended)
        outputs = self.ffn(normalized)
        additional_outputs = [attention_weights] if self._output_attention else []
        return outputs, additional_outputs
class TransformerDecoderCell(HybridBlock):
    """One Transformer decoder layer: self-attention, attention over the
    encoder memory, then a position-wise feed-forward network.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    units : int
    hidden_size : int
        number of units in the hidden layer of position-wise feed-forward networks
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', units=128,
                 hidden_size=512, num_heads=4, scaled=True,
                 dropout=0.0, use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerDecoderCell, self).__init__(prefix=prefix, params=params)
        self._units = units
        self._num_heads = num_heads
        self._dropout = dropout
        self._use_residual = use_residual
        self._output_attention = output_attention
        self._scaled = scaled
        # Keyword sets reused by the paired sub-blocks below.
        attention_kwargs = dict(units=units, num_heads=num_heads,
                                scaled=scaled, dropout=dropout)
        shared_init = dict(weight_initializer=weight_initializer,
                           bias_initializer=bias_initializer)
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.attention_cell_in = _get_attention_cell(attention_cell, **attention_kwargs)
            self.attention_cell_inter = _get_attention_cell(attention_cell, **attention_kwargs)
            self.proj_in = nn.Dense(units=units, flatten=False, use_bias=False,
                                    prefix='proj_in_', **shared_init)
            self.proj_inter = nn.Dense(units=units, flatten=False, use_bias=False,
                                       prefix='proj_inter_', **shared_init)
            self.ffn = PositionwiseFFN(hidden_size=hidden_size, units=units,
                                       use_residual=use_residual, dropout=dropout,
                                       **shared_init)
            self.layer_norm_in = nn.LayerNorm()
            self.layer_norm_inter = nn.LayerNorm()

    def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None):  #pylint: disable=unused-argument
        # pylint: disable=arguments-differ
        """Run self-attention, memory attention and the feed-forward network.

        Parameters
        ----------
        inputs : Symbol or NDArray
            Input sequence. Shape (batch_size, length, C_in)
        mem_value : Symbol or NDArrays
            Memory value, i.e. output of the encoder. Shape (batch_size, mem_length, C_in)
        mask : Symbol or NDArray or None
            Mask for inputs. Shape (batch_size, length, length)
        mem_mask : Symbol or NDArray or None
            Mask for mem_value. Shape (batch_size, length, mem_length)

        Returns
        -------
        decoder_cell_outputs: list
            Outputs of the decoder cell. Contains:
            - outputs of the transformer decoder cell. Shape (batch_size, length, C_out)
            - additional_outputs of all the transformer decoder cell
        """
        # Stage 1: self-attention over the decoder inputs.
        self_attended, attention_in_outputs = \
            self.attention_cell_in(inputs, inputs, inputs, mask)
        self_attended = self.dropout_layer(self.proj_in(self_attended))
        if self._use_residual:
            self_attended = self_attended + inputs
        query = self.layer_norm_in(self_attended)

        # Stage 2: attend over the encoder memory using the stage-1 result as query.
        mem_attended, attention_inter_outputs = \
            self.attention_cell_inter(query, mem_value, mem_value, mem_mask)
        mem_attended = self.dropout_layer(self.proj_inter(mem_attended))
        if self._use_residual:
            mem_attended = mem_attended + query
        normalized = self.layer_norm_inter(mem_attended)

        outputs = self.ffn(normalized)
        if self._output_attention:
            additional_outputs = [attention_in_outputs, attention_inter_outputs]
        else:
            additional_outputs = []
        return outputs, additional_outputs
class TransformerEncoder(HybridBlock, Seq2SeqEncoder):
    """Structure of the Transformer Encoder.

    The inputs are scaled by sqrt(C_in), combined with a constant sinusoid
    position table, passed through dropout and layer normalization, and then
    fed through ``num_layers`` stacked :class:`TransformerEncoderCell` blocks.

    Parameters
    ----------
    attention_cell : AttentionCell or str, default 'multi_head'
        Arguments of the attention cell.
        Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp'
    num_layers : int
    units : int
        Must be divisible by `num_heads`.
    hidden_size : int
        number of units in the hidden layer of position-wise feed-forward networks
    max_length : int
        Maximum length of the input sequence
    num_heads : int
        Number of heads in multi-head attention
    scaled : bool
        Whether to scale the softmax input by the sqrt of the input dimension
        in multi-head attention
    dropout : float
    use_residual : bool
    output_attention: bool
        Whether to output the attention weights
    weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str or None, default None
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """

    def __init__(self, attention_cell='multi_head', num_layers=2,
                 units=512, hidden_size=2048, max_length=50,
                 num_heads=4, scaled=True, dropout=0.0,
                 use_residual=True, output_attention=False,
                 weight_initializer=None, bias_initializer='zeros',
                 prefix=None, params=None):
        super(TransformerEncoder, self).__init__(prefix=prefix, params=params)
        assert units % num_heads == 0,\
            'In TransformerEncoder, The units should be divided exactly ' \
            'by the number of heads. Received units={}, num_heads={}' \
            .format(units, num_heads)
        self._num_layers = num_layers
        self._max_length = max_length
        self._num_heads = num_heads
        self._units = units
        self._hidden_size = hidden_size
        self._output_attention = output_attention
        self._dropout = dropout
        self._use_residual = use_residual
        self._scaled = scaled
        with self.name_scope():
            self.dropout_layer = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm()
            # Constant (non-trainable) sinusoid table of shape (max_length, units).
            self.position_weight = self.params.get_constant(
                'const', _position_encoding_init(max_length, units))
            self.transformer_cells = nn.HybridSequential()
            for i in range(num_layers):
                self.transformer_cells.add(
                    TransformerEncoderCell(
                        units=units,
                        hidden_size=hidden_size,
                        num_heads=num_heads,
                        attention_cell=attention_cell,
                        weight_initializer=weight_initializer,
                        bias_initializer=bias_initializer,
                        dropout=dropout,
                        use_residual=use_residual,
                        scaled=scaled,
                        output_attention=output_attention,
                        prefix='transformer%d_' % i))

    def __call__(self, inputs, states=None, valid_length=None):  #pylint: disable=arguments-differ
        """Encoder the inputs given the states and valid sequence length.

        Parameters
        ----------
        inputs : NDArray
            Input sequence. Shape (batch_size, length, C_in)
        states : list of NDArrays or None
            Initial states. The list of initial states and masks
        valid_length : NDArray or None
            Valid lengths of each sequence. This is usually used when part of sequence has
            been padded. Shape (batch_size,)

        Returns
        -------
        encoder_outputs: list
            Outputs of the encoder. Contains:
            - outputs of the transformer encoder. Shape (batch_size, length, C_out)
            - additional_outputs of all the transformer encoder
        """
        return super(TransformerEncoder, self).__call__(inputs, states, valid_length)

    def forward(self, inputs, states=None, valid_length=None, steps=None):  # pylint: disable=arguments-differ
        """Prepare the mask and position indices, then run the hybrid forward pass.

        Parameters
        ----------
        inputs : NDArray, Shape(batch_size, length, C_in)
        states : list of NDArray or None
            Optional leading states. The list passed in is not modified.
        valid_length : NDArray or None
        steps : NDArray or None
            Ignored: the position indices [0, ..., length - 1] are always
            recomputed from `inputs`. Kept in the signature for compatibility.

        Returns
        -------
        outputs : NDArray
            The output of the encoder. Shape is (batch_size, length, C_out)
        additional_outputs : list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, length) or
            (batch_size, num_heads, length, length)
        """
        length = inputs.shape[1]
        # Work on a private copy: appending to the caller's list would make it
        # grow by one or two entries on every call.
        states = [] if states is None else list(states)
        if valid_length is not None:
            # mask[b, i, j] is 1 where key position j lies inside sequence b.
            mask = mx.nd.broadcast_lesser(
                mx.nd.arange(length, ctx=valid_length.context).reshape((1, -1)),
                valid_length.reshape((-1, 1)))
            mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=1), axis=1, size=length)
            states.append(mask)
        inputs = inputs * math.sqrt(inputs.shape[-1])
        # hybrid_forward reads the position indices from states[-1]
        # and, when present, the mask from states[-2].
        steps = mx.nd.arange(length, ctx=inputs.context)
        states.append(steps)
        if valid_length is not None:
            step_output, additional_outputs =\
                super(TransformerEncoder, self).forward(inputs, states, valid_length)
        else:
            step_output, additional_outputs =\
                super(TransformerEncoder, self).forward(inputs, states)
        return step_output, additional_outputs

    def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_weight=None):  # pylint: disable=arguments-differ
        """Add position encodings and run the stacked encoder cells.

        Parameters
        ----------
        inputs : NDArray or Symbol, Shape(batch_size, length, C_in)
        states : list of NDArray or Symbol
            ``states[-1]`` holds the position indices; ``states[-2]`` holds the
            attention mask when `valid_length` is given.
        valid_length : NDArray or Symbol
        position_weight : NDArray or Symbol
            The constant position encoding table.

        Returns
        -------
        outputs : NDArray or Symbol
            The output of the encoder. Shape is (batch_size, length, C_out)
        additional_outputs : list
            Either be an empty list or contains the attention weights in this step.
            The attention weights will have shape (batch_size, length, length) or
            (batch_size, num_heads, length, length)
        """
        if states is not None:
            steps = states[-1]
            # Positional Encoding
            inputs = F.broadcast_add(inputs, F.expand_dims(F.Embedding(steps, position_weight,
                                                                       self._max_length,
                                                                       self._units), axis=0))
        inputs = self.dropout_layer(inputs)
        inputs = self.layer_norm(inputs)
        outputs = inputs
        mask = states[-2] if valid_length is not None else None
        additional_outputs = []
        for cell in self.transformer_cells:
            outputs, attention_weights = cell(inputs, mask)
            inputs = outputs
            if self._output_attention:
                additional_outputs.append(attention_weights)
        if valid_length is not None:
            # Zero the outputs at padded positions.
            outputs = F.SequenceMask(outputs, sequence_length=valid_length,
                                     use_sequence_length=True, axis=1)
        return outputs, additional_outputs
class TransformerDecoder(HybridBlock, Seq2SeqDecoder): | |
"""Structure of the Transformer Decoder. | |
Parameters | |
---------- | |
attention_cell : AttentionCell or str, default 'multi_head' | |
Arguments of the attention cell. | |
Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' | |
num_layers : int | |
units : int | |
hidden_size : int | |
number of units in the hidden layer of position-wise feed-forward networks | |
max_length : int | |
Maximum length of the input sequence. This is used for constructing position encoding | |
num_heads : int | |
Number of heads in multi-head attention | |
scaled : bool | |
Whether to scale the softmax input by the sqrt of the input dimension | |
in multi-head attention | |
dropout : float | |
use_residual : bool | |
output_attention: bool | |
Whether to output the attention weights | |
weight_initializer : str or Initializer | |
Initializer for the input weights matrix, used for the linear | |
transformation of the inputs. | |
bias_initializer : str or Initializer | |
Initializer for the bias vector. | |
prefix : str, default 'rnn_' | |
Prefix for name of `Block`s | |
(and name of weight if params is `None`). | |
params : Parameter or None | |
Container for weight sharing between cells. | |
Created if `None`. | |
""" | |
def __init__(self, attention_cell='multi_head', num_layers=2,
             units=128, hidden_size=2048, max_length=50,
             num_heads=4, scaled=True, dropout=0.0,
             use_residual=True, output_attention=False,
             weight_initializer=None, bias_initializer='zeros',
             prefix=None, params=None):
    """Create the position table, the input dropout/layer-norm and the
    ``num_layers`` stacked decoder cells.

    Raises
    ------
    AssertionError
        If `units` is not divisible by `num_heads`.
    """
    super(TransformerDecoder, self).__init__(prefix=prefix, params=params)
    assert units % num_heads == 0, 'In TransformerDecoder, the units should be divided ' \
                                   'exactly by the number of heads. Received units={}, ' \
                                   'num_heads={}'.format(units, num_heads)
    self._num_layers = num_layers
    self._units = units
    self._hidden_size = hidden_size
    # NOTE(review): num_heads is stored under the name _num_states - confirm this is intended.
    self._num_states = num_heads
    self._max_length = max_length
    self._dropout = dropout
    self._use_residual = use_residual
    self._output_attention = output_attention
    self._scaled = scaled
    # forward() reads this attribute during single-step decoding; give it a
    # defined value so decoding does not fail with AttributeError before
    # init_state_from_encoder() has assigned it.
    self._encoder_valid_length = None
    with self.name_scope():
        self.dropout_layer = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm()
        # Constant (non-trainable) sinusoid table of shape (max_length, units).
        self.position_weight = self.params.get_constant(
            'const', _position_encoding_init(max_length, units))
        self.transformer_cells = nn.HybridSequential()
        for i in range(num_layers):
            self.transformer_cells.add(
                TransformerDecoderCell(
                    units=units,
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    attention_cell=attention_cell,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer,
                    dropout=dropout,
                    scaled=scaled,
                    use_residual=use_residual,
                    output_attention=output_attention,
                    prefix='transformer%d_' % i))
def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
    """Initialize the state from the encoder outputs.

    Also records `encoder_valid_length` on the decoder instance.

    Parameters
    ----------
    encoder_outputs : NDArray
        Used as the memory value; its second axis is taken as the memory length.
    encoder_valid_length : NDArray or None

    Returns
    -------
    decoder_states : list
        The decoder states, includes:
        - mem_value : NDArray
        - mem_masks : NDArray, optional
    """
    decoder_states = [encoder_outputs]
    mem_length = encoder_outputs.shape[1]
    if encoder_valid_length is not None:
        # mem_masks[b, j] is 1 where memory position j lies inside sequence b.
        positions = mx.nd.arange(mem_length, ctx=encoder_valid_length.context)
        mem_masks = mx.nd.broadcast_lesser(positions.reshape((1, -1)),
                                           encoder_valid_length.reshape((-1, 1)))
        decoder_states.append(mem_masks)
    self._encoder_valid_length = encoder_valid_length
    return decoder_states
def decode_seq(self, inputs, states, valid_length=None):
    """Decode the decoder inputs. This function is only used for training.

    Parameters
    ----------
    inputs : NDArray, Shape (batch_size, length, C_in)
    states : list of NDArrays or None
        Initial states. The list of decoder states
    valid_length : NDArray or None
        Valid lengths of each sequence. This is usually used when part of sequence has
        been padded. Shape (batch_size,)

    Returns
    -------
    output : NDArray, Shape (batch_size, length, C_out)
    states : list
        The decoder states, includes:
        - mem_value : NDArray
        - mem_masks : NDArray, optional
    additional_outputs : list of list
        Either be an empty list or contains the attention weights in this step.
        The attention weights will have shape (batch_size, length, mem_length) or
        (batch_size, num_heads, length, mem_length)
    """
    batch_size, length = inputs.shape[0], inputs.shape[1]
    positions = mx.nd.arange(length, ctx=inputs.context)
    # causal[i, j] is 1 when j <= i, so a step only sees itself and earlier steps.
    causal = mx.nd.broadcast_lesser_equal(positions.reshape((1, -1)),
                                          positions.reshape((-1, 1)))
    if valid_length is None:
        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(causal, axis=0),
                                    axis=0, size=batch_size)
    else:
        batch_mask = mx.nd.broadcast_lesser(
            mx.nd.arange(length, ctx=valid_length.context).reshape((1, -1)),
            valid_length.reshape((-1, 1)))
        mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1),
                                   mx.nd.expand_dims(causal, 0))
    # A leading None tells forward() that no previous embeddings are cached.
    output, states, additional_outputs = self.forward(inputs, [None] + states, mask)
    # forward() puts its input in front of the returned states; drop it again.
    states = states[1:]
    if valid_length is not None:
        output = mx.nd.SequenceMask(output,
                                    sequence_length=valid_length,
                                    use_sequence_length=True,
                                    axis=1)
    return output, states, additional_outputs
def __call__(self, step_input, states):  #pylint: disable=arguments-differ
    """One-step-ahead decoding of the Transformer decoder.

    Parameters
    ----------
    step_input : NDArray
    states : list of NDArray

    Returns
    -------
    step_output : NDArray
        The output of the decoder.
        In the train mode, Shape is (batch_size, length, C_out)
        In the test mode, Shape is (batch_size, C_out)
    new_states: list
        Includes
        - last_embeds : NDArray or None
            It is only given during testing
        - mem_value : NDArray
        - mem_masks : NDArray, optional
    step_additional_outputs : list of list
        Either be an empty list or contains the attention weights in this step.
        The attention weights will have shape (batch_size, length, mem_length) or
        (batch_size, num_heads, length, mem_length)
    """
    parent = super(TransformerDecoder, self)
    return parent.__call__(step_input, states)
def forward(self, step_input, states, mask=None):  # pylint: disable=arguments-differ
    """Prepare the inputs and states, then run the decoder stack.

    A 2-D ``step_input`` is handled as one incremental decoding step: it is
    expanded to (batch_size, 1, C_in) and, when ``states`` starts with the
    embeddings of the previous steps, concatenated after them along axis 1.
    Any other ``step_input`` is passed on as a whole sequence, after a
    leading ``None`` placeholder in ``states`` (if present) is dropped.

    Parameters
    ----------
    step_input : NDArray
        Shape (batch_size, C_in) or (batch_size, length, C_in).
    states : list of NDArray
        Optionally the previous embeddings (or ``None``) first, then
        ``mem_value`` and, optionally, ``mem_mask``. The list passed in is
        never modified.
    mask : NDArray or None
        Self-attention mask. When ``None``, a mask that lets position ``i``
        see positions ``j <= i`` is built and repeated for every batch entry.

    Returns
    -------
    step_output : NDArray
        Output of the decoder stack; only the last position is returned
        when ``step_input`` was 2-D.
    new_states : list
        ``[step_input (unscaled, with length axis)] + remaining states``.
    step_additional_outputs : list
        Whatever the parent ``forward`` returns as additional outputs.
    """
    input_shape = step_input.shape
    is_single_step = len(input_shape) == 2
    # Work on a shallow copy: the original code wrote the broadcast memory
    # mask and the position indices into the caller's list whenever `states`
    # had not been re-sliced first (e.g. on the first incremental step).
    states = list(states)
    if is_single_step:
        # NOTE(review): `_encoder_valid_length` is set outside this method;
        # presumably a memory mask is part of the states whenever it is not
        # None, hence the different expected state counts -- confirm.
        if self._encoder_valid_length is not None:
            has_last_embeds = len(states) == 3
        else:
            has_last_embeds = len(states) == 2
        # Give the single step a length axis: (batch, C) -> (batch, 1, C).
        step_input = mx.nd.expand_dims(step_input, axis=1)
        if has_last_embeds:
            step_input = mx.nd.concat(states[0], step_input, dim=1)
            states = states[1:]
    elif states[0] is None:
        # Whole-sequence call: drop the `None` placeholder for last_embeds.
        states = states[1:]

    length = step_input.shape[1]
    has_mem_mask = len(states) == 2
    # States handed to the cells are kept separate from the ones returned,
    # so the returned list keeps the un-broadcast memory mask.
    cell_states = list(states)
    if has_mem_mask:
        cell_states[-1] = mx.nd.expand_dims(states[-1], axis=1) \
            .broadcast_axes(axis=1, size=length)
    if mask is None:
        length_array = mx.nd.arange(length, ctx=step_input.context)
        mask = mx.nd.broadcast_lesser_equal(
            length_array.reshape((1, -1)),
            length_array.reshape((-1, 1)))
        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
                                    axis=0, size=step_input.shape[0])
    # Position indices used for the positional-embedding lookup.
    cell_states.append(mx.nd.arange(length, ctx=step_input.context))
    step_output, step_additional_outputs = \
        super(TransformerDecoder, self).forward(  # pylint: disable=too-many-function-args
            step_input * math.sqrt(step_input.shape[-1]),
            cell_states, mask)
    new_states = [step_input] + states
    # For a single incremental step only the newest position is of interest.
    if is_single_step:
        step_output = step_output[:, -1, :]
    return step_output, new_states, step_additional_outputs
def hybrid_forward(self, F, step_input, states, mask=None, position_weight=None):  # pylint: disable=arguments-differ
    """Add positional embeddings and run the stacked decoder cells.

    Parameters
    ----------
    F
        The NDArray or Symbol API namespace to compute with.
    step_input : NDArray or Symbol, Shape (batch_size, length, C_in)
    states : list of NDArray or Symbol
        Either ``[mem_value, mem_mask, steps]`` or ``[mem_value, steps]``.
    mask : NDArray or Symbol
        Passed to every cell together with the memory mask.
    position_weight : NDArray or Symbol
        Embedding table indexed by ``steps``.

    Returns
    -------
    step_output : NDArray or Symbol
        The output of the decoder. Shape is (batch_size, length, C_out)
    step_additional_outputs : list
        Either an empty list or the attention weights of each cell.
        The attention weights will have shape (batch_size, length, mem_length) or
        (batch_size, num_heads, length, mem_length)
    """
    if len(states) == 3:
        mem_value, mem_mask, steps = states
    else:
        mem_value, steps = states
        mem_mask = None

    # Positional encoding: look up one vector per step and broadcast it
    # over the batch axis.
    positional = F.Embedding(steps, position_weight, self._max_length, self._units)
    step_input = F.broadcast_add(step_input, F.expand_dims(positional, axis=0))
    step_input = self.layer_norm(self.dropout_layer(step_input))

    outputs = step_input
    step_additional_outputs = []
    for cell in self.transformer_cells:
        outputs, attention_weights = cell(outputs, mem_value, mask, mem_mask)
        # Attention weights are only collected when requested.
        if self._output_attention:
            step_additional_outputs.append(attention_weights)
    return outputs, step_additional_outputs
def get_transformer_encoder_decoder(num_layers=2,
                                    num_heads=8, scaled=True,
                                    units=512, hidden_size=2048, dropout=0.0, use_residual=True,
                                    max_src_length=50, max_tgt_length=50,
                                    weight_initializer=None, bias_initializer='zeros',
                                    prefix='transformer_', params=None):
    """Build a matching Transformer encoder/decoder pair.

    Both blocks are constructed with the same hyper-parameters; only the
    maximum length and the name prefix differ between them.

    Parameters
    ----------
    num_layers : int
    num_heads : int
    scaled : bool
    units : int
    hidden_size : int
    dropout : float
    use_residual : bool
    max_src_length : int
        Passed to the encoder as ``max_length``.
    max_tgt_length : int
        Passed to the decoder as ``max_length``.
    weight_initializer : mx.init.Initializer or None
    bias_initializer : mx.init.Initializer or None
    prefix : str, default 'transformer_'
        Prefix for name of `Block`s; ``'enc_'`` / ``'dec_'`` is appended
        for the encoder / decoder respectively.
    params : Parameter or None
        Container for weight sharing between layers.
        Created if `None`.

    Returns
    -------
    encoder : TransformerEncoder
    decoder : TransformerDecoder
    """
    # Arguments that are identical for the encoder and the decoder.
    shared_kwargs = dict(num_layers=num_layers,
                         num_heads=num_heads,
                         units=units,
                         hidden_size=hidden_size,
                         dropout=dropout,
                         scaled=scaled,
                         use_residual=use_residual,
                         weight_initializer=weight_initializer,
                         bias_initializer=bias_initializer,
                         params=params)
    encoder = TransformerEncoder(max_length=max_src_length,
                                 prefix=prefix + 'enc_',
                                 **shared_kwargs)
    decoder = TransformerDecoder(max_length=max_tgt_length,
                                 prefix=prefix + 'dec_',
                                 **shared_kwargs)
    return encoder, decoder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Text Denoising\n", | |
"\n", | |
"Inspired by \"Neural Networks for Text Correction and Completion in Keyboard Decoding\" by Shaona Ghosh and Per Ola Kristensson. https://arxiv.org/pdf/1709.06429.pdf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 146, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from collections import defaultdict\n", | |
"import os\n", | |
"import random\n", | |
"import string" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from gluoncv.data.batchify import Tuple, Stack, Append\n", | |
"import mxnet as mx\n", | |
"from mxnet import gluon, autograd\n", | |
"from mxnet.gluon import HybridBlock\n", | |
"from mxnet.gluon.loss import SoftmaxCELoss\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"gpu(0)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()\n", | |
"ctx" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 148, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"if not os.path.isdir('dataset'):\n", | |
" os.makedirs('dataset')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'dataset/typo/alicewonder.txt'" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"typo_filepath = 'dataset/typo/typo-corpus-r1.txt'\n", | |
"text_filepath = 'dataset/typo/alicewonder.txt'\n", | |
"mx.test_utils.download('http://luululu.com/tweet/typo-corpus-r1.txt', dirname='dataset/typo')\n", | |
"mx.test_utils.download('http://textfiles.com/etext/FICTION/alicewonder.txt', dirname='dataset/typo')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ALPHABET = ['<PAD>', '<BOS>', '<EOS>']+list(' ' + string.ascii_letters + string.digits + string.punctuation)\n", | |
"ALPHABET_INDEX = {letter: index for index, letter in enumerate(ALPHABET)} # {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, ' ': 3, 'a': 4, ...}\n", | |
"FEATURE_LEN = 150 # max-length in characters for one document\n", | |
"NUM_WORKERS = 8 # number of workers used in the data loading\n", | |
"BATCH_SIZE = 128 # number of documents per batch\n", | |
"PAD = 0\n", | |
"BOS = 1\n", | |
"EOS = 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class NoisyTextDataset(mx.gluon.data.Dataset):\n", | |
" def __init__(self, text_filepath, typo_filepath, replace_proba=0.2, is_train=True):\n", | |
" self.replace_proba = replace_proba\n", | |
" self.typo_dict = self._process_typo(typo_filepath)\n", | |
" self.text = self._process_text(text_filepath, is_train)\n", | |
" \n", | |
" def _process_text(self, filename, is_train):\n", | |
" with open(text_filepath, 'r') as f:\n", | |
" text = []\n", | |
" for line in f.readlines():\n", | |
" line = line.replace('\\n', '').strip()\n", | |
" if line != '':\n", | |
" text.append(line)\n", | |
" \n", | |
" split_index = int(0.8*len(text))\n", | |
" if is_train:\n", | |
" text = text[:split_index]\n", | |
" else:\n", | |
" text = text[split_index:]\n", | |
" return text\n", | |
"\n", | |
" \n", | |
" def _process_typo(self, filename):\n", | |
" \"\"\"\n", | |
" Load the typo dataset and build, for each correct word, the\n", | |
" cumulative probability distribution over its observed typos\n", | |
" \"\"\"\n", | |
" typo_dict = defaultdict(lambda : defaultdict(float))\n", | |
" with open(filename, 'r') as f:\n", | |
" lines = f.readlines()\n", | |
" for line in lines:\n", | |
" typo, correct = line.split('\\t')[0:2]\n", | |
" typo_dict[correct][typo] += 1\n", | |
" for _, correct_word in typo_dict.items():\n", | |
" total = 0\n", | |
" for _, count in correct_word.items():\n", | |
" total += count\n", | |
" previous_value = 0.\n", | |
" for wrong_word in correct_word:\n", | |
" correct_word[wrong_word] = correct_word[wrong_word] / total + previous_value\n", | |
" previous_value = correct_word[wrong_word]\n", | |
" return typo_dict\n", | |
" \n", | |
" def _transform_line(self, line):\n", | |
" \"\"\"\n", | |
" replace words that are in the typo dataset with a typo\n", | |
" with a probability `self.replace_proba`\n", | |
" \"\"\"\n", | |
" output = []\n", | |
" proba_replace = 0.2\n", | |
" for word in self._pre_process_line(line):\n", | |
" if word.lower() in self.typo_dict and word != '':\n", | |
" if random.random() < self.replace_proba:\n", | |
" draw = random.random()\n", | |
" for typo, value in self.typo_dict[word].items():\n", | |
" if draw < value:\n", | |
" word = self._match_caps(word, typo)\n", | |
" break\n", | |
" output.append(word)\n", | |
" else:\n", | |
" output.append(word)\n", | |
"\n", | |
" return self._post_process_line(output)\n", | |
" \n", | |
" def _pre_process_line(self, line):\n", | |
" line = line.replace('\\n', '')\n", | |
" for char in string.punctuation:\n", | |
" if char in line:\n", | |
" line = line.replace(char, ' '+char+' ')\n", | |
" return line.split(' ')\n", | |
" \n", | |
" def _post_process_line(self, words):\n", | |
" output = ' '.join(words)\n", | |
" for char in string.punctuation:\n", | |
" output = output.replace(' '+char+' ', char)\n", | |
" return output\n", | |
" \n", | |
" def _match_caps(self, original, typo):\n", | |
" if original.isupper():\n", | |
" return typo.upper()\n", | |
" elif original.istitle():\n", | |
" return typo.capitalize()\n", | |
" else:\n", | |
" return typo\n", | |
" \n", | |
" def __getitem__(self, idx):\n", | |
" line = self.text[idx]\n", | |
" line_typo = self._transform_line(line)\n", | |
" return line_typo, line\n", | |
"\n", | |
" def __len__(self):\n", | |
" return len(self.text)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def encode(text, src=True):\n", | |
" encoded = np.ones(FEATURE_LEN, dtype='float32') * PAD\n", | |
" text = text[:FEATURE_LEN-2]\n", | |
" i = 0\n", | |
" if not src:\n", | |
" encoded[0] = BOS\n", | |
" i = 1\n", | |
" for letter in text:\n", | |
" if letter in ALPHABET_INDEX:\n", | |
" encoded[i] = ALPHABET_INDEX[letter]\n", | |
" i += 1\n", | |
" encoded[i] = EOS\n", | |
" return encoded, np.array([len(text)+2]).astype('float32')\n", | |
"\n", | |
"def transform(data, label):\n", | |
" src, src_valid_length = encode(data, src=True)\n", | |
" tgt, tgt_valid_length = encode(label, src=False)\n", | |
" return src, src_valid_length, tgt, tgt_valid_length, data, label" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset_train = NoisyTextDataset(text_filepath=text_filepath, typo_filepath=typo_filepath, replace_proba=0.4, is_train=True).transform(transform)\n", | |
"dataset_test = NoisyTextDataset(text_filepath=text_filepath, typo_filepath=typo_filepath, replace_proba=0.4, is_train=False).transform(transform)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([16., 11., 3., 7., 8., 4., 21., 77., 72., 3., 22., 4., 12.,\n", | |
" 7., 7., 3., 30., 15., 12., 6., 8., 77., 3., 22., 8., 21.,\n", | |
" 12., 18., 22., 24., 15., 28., 77., 3., 93., 38., 72., 15., 15.,\n", | |
" 3., 11., 4., 25., 8., 3., 17., 18., 23., 11., 12., 10., 17.,\n", | |
" 3., 16., 18., 21., 8., 3., 23., 18., 3., 7., 18., 12., 2.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
" array([66.], dtype=float32),\n", | |
" array([ 1., 16., 28., 3., 7., 8., 4., 21., 77., 72., 3., 22., 4.,\n", | |
" 12., 7., 3., 30., 15., 12., 6., 8., 77., 3., 22., 8., 21.,\n", | |
" 12., 18., 24., 22., 15., 28., 77., 3., 93., 38., 72., 15., 15.,\n", | |
" 3., 11., 4., 25., 8., 3., 17., 18., 23., 11., 12., 17., 10.,\n", | |
" 3., 16., 18., 21., 8., 3., 23., 18., 3., 7., 18., 2., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
" array([64.], dtype=float32),\n", | |
" \"mh dear,' saidd Alice, seriosuly, `I'll have nothign more to doi\",\n", | |
" \"my dear,' said Alice, seriously, `I'll have nothing more to do\")" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dataset_train[random.randint(0, len(dataset_train)-1)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def decode(text):\n", | |
" output = []\n", | |
" for val in text:\n", | |
" output.append(ALPHABET[int(val)])\n", | |
" if val == EOS:\n", | |
" break\n", | |
" return ''.join(output)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def batchify_list(elem):\n", | |
" output = []\n", | |
" for e in elem:\n", | |
" output.append(elem)\n", | |
" return output\n", | |
" \n", | |
"batchify = Tuple(Stack(), Stack(), Stack(), Stack(), batchify_list, batchify_list)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_data = gluon.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, last_batch='rollover', batchify_fn=batchify)\n", | |
"test_data = gluon.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True, last_batch='rollover', batchify_fn=batchify)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Network" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from utils.encoder_decoder import get_transformer_encoder_decoder" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SoftmaxCEMaskedLoss(SoftmaxCELoss):\n", | |
" \"\"\"Wrapper of the SoftmaxCELoss that supports valid_length as the input\n", | |
" \"\"\"\n", | |
" def hybrid_forward(self, F, pred, label, valid_length): # pylint: disable=arguments-differ\n", | |
" \"\"\"\n", | |
" Parameters\n", | |
" ----------\n", | |
" F\n", | |
" pred : Symbol or NDArray\n", | |
" Shape (batch_size, length, V)\n", | |
" label : Symbol or NDArray\n", | |
" Shape (batch_size, length)\n", | |
" valid_length : Symbol or NDArray\n", | |
" Shape (batch_size, )\n", | |
" Returns\n", | |
" -------\n", | |
" loss : Symbol or NDArray\n", | |
" Shape (batch_size,)\n", | |
" \"\"\"\n", | |
" if self._sparse_label:\n", | |
" sample_weight = F.cast(F.expand_dims(F.ones_like(label), axis=-1), dtype=np.float32)\n", | |
" else:\n", | |
" sample_weight = F.ones_like(label)\n", | |
" sample_weight = F.SequenceMask(sample_weight,\n", | |
" sequence_length=valid_length,\n", | |
" use_sequence_length=True,\n", | |
" axis=1)\n", | |
" return super(SoftmaxCEMaskedLoss, self).hybrid_forward(F, pred, label, sample_weight)\n", | |
"\n", | |
"# pylint: disable=unused-argument\n", | |
"class _SmoothingWithDim(mx.operator.CustomOp):\n", | |
" def __init__(self, epsilon=0.1, axis=-1):\n", | |
" super(_SmoothingWithDim, self).__init__(True)\n", | |
" self._epsilon = epsilon\n", | |
" self._axis = axis\n", | |
"\n", | |
" def forward(self, is_train, req, in_data, out_data, aux):\n", | |
" inputs = in_data[0]\n", | |
" outputs = ((1 - self._epsilon) * inputs) + (self._epsilon / float(inputs.shape[self._axis]))\n", | |
" self.assign(out_data[0], req[0], outputs)\n", | |
"\n", | |
" def backward(self, req, out_grad, in_data, out_data, in_grad, aux):\n", | |
" self.assign(in_grad[0], req[0], (1 - self._epsilon) * out_grad[0])\n", | |
"\n", | |
"\n", | |
"@mx.operator.register('_smoothing_with_dim')\n", | |
"class _SmoothingWithDimProp(mx.operator.CustomOpProp):\n", | |
" def __init__(self, epsilon=0.1, axis=-1):\n", | |
" super(_SmoothingWithDimProp, self).__init__(True)\n", | |
" self._epsilon = float(epsilon)\n", | |
" self._axis = int(axis)\n", | |
"\n", | |
" def list_arguments(self):\n", | |
" return ['data']\n", | |
"\n", | |
" def list_outputs(self):\n", | |
" return ['output']\n", | |
"\n", | |
" def infer_shape(self, in_shape):\n", | |
" data_shape = in_shape[0]\n", | |
" output_shape = data_shape\n", | |
" return (data_shape,), (output_shape,), ()\n", | |
"\n", | |
" def declare_backward_dependency(self, out_grad, in_data, out_data):\n", | |
" return out_grad\n", | |
"\n", | |
" def create_operator(self, ctx, in_shapes, in_dtypes):\n", | |
" # create and return the CustomOp class.\n", | |
" return _SmoothingWithDim(self._epsilon, self._axis)\n", | |
"# pylint: enable=unused-argument\n", | |
"\n", | |
"\n", | |
"class LabelSmoothing(HybridBlock):\n", | |
" \"\"\"Applies label smoothing. See https://arxiv.org/abs/1512.00567.\n", | |
" Parameters\n", | |
" ----------\n", | |
" axis : int, default -1\n", | |
" The axis to smooth.\n", | |
" epsilon : float, default 0.1\n", | |
" The epsilon parameter in label smoothing\n", | |
" sparse_label : bool, default True\n", | |
" Whether input is an integer array instead of one hot array.\n", | |
" units : int or None\n", | |
" Vocabulary size. If units is not given, it will be inferred from the input.\n", | |
" prefix : str, default 'rnn_'\n", | |
" Prefix for name of `Block`s\n", | |
" (and name of weight if params is `None`).\n", | |
" params : Parameter or None\n", | |
" Container for weight sharing between cells.\n", | |
" Created if `None`.\n", | |
" \"\"\"\n", | |
" def __init__(self, axis=-1, epsilon=0.1, units=None,\n", | |
" sparse_label=True, prefix=None, params=None):\n", | |
" super(LabelSmoothing, self).__init__(prefix=prefix, params=params)\n", | |
" self._axis = axis\n", | |
" self._epsilon = epsilon\n", | |
" self._sparse_label = sparse_label\n", | |
" self._units = units\n", | |
"\n", | |
" def hybrid_forward(self, F, inputs, units=None): # pylint: disable=arguments-differ\n", | |
" \"\"\"\n", | |
" Parameters\n", | |
" ----------\n", | |
" F\n", | |
" inputs : Symbol or NDArray\n", | |
" Shape (batch_size, length) or (batch_size, length, V)\n", | |
" units : int or None\n", | |
" Returns\n", | |
" -------\n", | |
" smoothed_label : Symbol or NDArray\n", | |
" Shape (batch_size, length, V)\n", | |
" \"\"\"\n", | |
" if self._sparse_label:\n", | |
" assert units is not None or self._units is not None, \\\n", | |
" 'units needs to be given in function call or ' \\\n", | |
" 'instance initialization when sparse_label is False'\n", | |
" if units is None:\n", | |
" units = self._units\n", | |
" inputs = F.one_hot(inputs, depth=units)\n", | |
" if units is None and self._units is None:\n", | |
" return F.Custom(inputs, epsilon=self._epsilon, axis=self._axis,\n", | |
" op_type='_smoothing_with_dim')\n", | |
" else:\n", | |
" if units is None:\n", | |
" units = self._units\n", | |
" return ((1 - self._epsilon) * inputs) + (self._epsilon / units)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 149, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Denoiser(gluon.nn.HybridBlock):\n", | |
" def __init__(self, alphabet_size, embed_size=64, max_src_length=FEATURE_LEN, max_tgt_length=FEATURE_LEN):\n", | |
" super(Denoiser, self).__init__()\n", | |
" encoder, decoder = get_transformer_encoder_decoder(max_src_length=max_src_length, max_tgt_length=max_tgt_length, units=embed_size, num_heads=8, num_layers=2)\n", | |
" self.encoder = encoder\n", | |
" self.decoder = decoder\n", | |
" self.src_embedding = gluon.nn.HybridSequential()\n", | |
" with self.src_embedding.name_scope():\n", | |
" self.src_embedding.add(gluon.nn.Embedding(input_dim=alphabet_size, output_dim=embed_size),\n", | |
" gluon.nn.Dropout(rate=0.))\n", | |
" self.tgt_embedding = gluon.nn.HybridSequential()\n", | |
" with self.tgt_embedding.name_scope():\n", | |
" self.tgt_embedding.add(gluon.nn.Embedding(input_dim=alphabet_size, output_dim=embed_size),\n", | |
" gluon.nn.Dropout(rate=0.))\n", | |
" self.tgt_proj = gluon.nn.Dense(units=alphabet_size, flatten=False,\n", | |
" prefix='tgt_proj_')\n", | |
" \n", | |
" def hybrid_forward(self, F, src_seq, tgt_seq, src_valid_length, tgt_valid_length):\n", | |
" encoder_outputs, encoder_additional_outputs = self.encode(src_seq, valid_length=src_valid_length)\n", | |
" decoder_states = self.decoder.init_state_from_encoder(encoder_outputs, encoder_valid_length=src_valid_length)\n", | |
" tgt_embed = self.tgt_embedding(tgt_seq)\n", | |
" outputs, _, _ = self.decoder.decode_seq(tgt_embed, decoder_states, tgt_valid_length)\n", | |
" outputs = self.tgt_proj(outputs)\n", | |
" return outputs\n", | |
" \n", | |
" def encode(self, inputs, states=None, valid_length=None):\n", | |
" return self.encoder(self.src_embedding(inputs), states, valid_length)\n", | |
" \n", | |
" def decode_step(self, step_input, states):\n", | |
" step_output, states, step_additional_outputs = self.decoder(self.tgt_embedding(step_input), states)\n", | |
" step_output = self.tgt_proj(step_output)\n", | |
" return step_output, states, step_additional_outputs\n", | |
" \n", | |
" def decode_logprob(self, step_input, states):\n", | |
" out, states, _ = self.decode_step(step_input, states)\n", | |
" return out, states" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 150, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"denoiser = Denoiser(alphabet_size=len(ALPHABET))\n", | |
"#denoiser.load_parameters('model_checkpoint/denoiser.params', ctx=ctx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 151, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"denoiser.initialize(ctx=ctx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 152, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"label_smoothing = LabelSmoothing(epsilon=0.1, units=len(ALPHABET))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 153, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loss_function = SoftmaxCEMaskedLoss(sparse_label=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 154, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainer = gluon.Trainer(denoiser.collect_params(), 'adam', {'learning_rate':0.001})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 155, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch [0], Loss 1.0904\n", | |
"[Typo ] inveted yet.'<EOS>\n", | |
"[Predicted] <BOS>insited yeta'<EOS>\n", | |
"[Correct ] <BOS>invited yet.'<EOS>\n", | |
"\n", | |
"Epoch [1], Loss 0.6172\n", | |
"[Typo ] whut vt miqht apear to others that what yu were or might have<EOS>\n", | |
"[Predicted] <BOS>what it might appear to others that what you were or might have<EOS>\n", | |
"[Correct ] <BOS>what it might appear to others that what you were or might have<EOS>\n", | |
"\n", | |
"Epoch [2], Loss 0.4213\n", | |
"[Typo ] `I'll fetch the executioner myself,' said tht King eagerly, and<EOS>\n", | |
"[Predicted] <BOS>`I'll fetch the executioner myself,' said the King eagerly, and<EOS>\n", | |
"[Correct ] <BOS>`I'll fetch the executioner myself,' said the King eagerly, and<EOS>\n", | |
"\n", | |
"Epoch [3], Loss 0.3381\n", | |
"[Typo ] fr instance, there's tha arch I've gt to go through next<EOS>\n", | |
"[Predicted] <BOS>for instance, there's the arch I've got to go through next<EOS>\n", | |
"[Correct ] <BOS>for instance, there's the arch I've got to go through next<EOS>\n", | |
"\n", | |
"Epoch [4], Loss 0.3070\n", | |
"[Typo ] `My uname is Alice, so please your Majesty,' said Alice very<EOS>\n", | |
"[Predicted] <BOS>`My name is Alice, so please your Majesty,' said Alice very<EOS>\n", | |
"[Correct ] <BOS>`My name is Alice, so please your Majesty,' said Alice very<EOS>\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"epochs = 5\n", | |
"for e in range(epochs):\n", | |
" loss = 0.\n", | |
" for i, (src, src_valid_length, tgt, tgt_valid_length, typo, label) in enumerate(train_data):\n", | |
" src = src.as_in_context(ctx)\n", | |
" tgt = tgt.as_in_context(ctx)\n", | |
" src_valid_length = src_valid_length.as_in_context(ctx).squeeze()\n", | |
" tgt_valid_length = tgt_valid_length.as_in_context(ctx).squeeze()\n", | |
" with autograd.record():\n", | |
" output = denoiser(src, tgt, src_valid_length, tgt_valid_length)\n", | |
" smoothed_label = label_smoothing(tgt)\n", | |
" ls = loss_function(output, smoothed_label, tgt_valid_length).mean()\n", | |
" ls.backward()\n", | |
" trainer.step(src.shape[0])\n", | |
" loss += ls.asscalar()\n", | |
" print(\"Epoch [{}], Loss {:.4f}\".format(e, loss/(i+1)))\n", | |
" print(\"[Typo ] {}\".format(decode(src[0].asnumpy())))\n", | |
" print(\"[Predicted] {}\".format(decode(output[0].asnumpy().argmax(axis=1))))\n", | |
" print(\"[Correct ] {}\".format(decode(tgt[0].asnumpy())))\n", | |
" print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 156, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"`My uname is Alice, so please your Majesty,' said Alice very\n" | |
] | |
} | |
], | |
"source": [ | |
"src_seq = src[0:1]\n", | |
"src_valid_length = src_valid_length[0]\n", | |
"print(typo[0][0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 157, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"encoder_outputs, _ = denoiser.encode(src_seq, valid_length=src_valid_length)\n", | |
"states = denoiser.decoder.init_state_from_encoder(encoder_outputs, encoder_valid_length=src_valid_length)\n", | |
"inputs = mx.nd.full(shape=(1,), ctx=src_seq.context, dtype=np.float32, val=BOS)\n", | |
"output = []" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 158, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS>\n" | |
] | |
} | |
], | |
"source": [ | |
"for i in range(0,20):\n", | |
" output.append(int(inputs.asscalar()))\n", | |
" next_input, states = denoiser.decode_logprob(inputs, states)\n", | |
" inputs = next_input.argmax(axis=1)\n", | |
"output.append(int(inputs.asscalar()))\n", | |
"print(''.join([ALPHABET[int(c)] for c in output]))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Manual Testing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 159, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
" scorer = nlp.model.BeamSearchScorer(alpha=0, K=5, from_logits=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 160, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"eos_id = EOS\n", | |
"beam_sampler = nlp.model.BeamSearchSampler(beam_size=5,\n", | |
" decoder=denoiser.decode_logprob,\n", | |
" eos_id=eos_id,\n", | |
" scorer=scorer,\n", | |
" max_length=20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 161, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def generate_sequences(sampler, inputs, begin_states, num_print_outcomes):\n", | |
" samples, scores, valid_lengths = sampler(inputs, begin_states)\n", | |
" samples = samples[0].asnumpy()\n", | |
" scores = scores[0].asnumpy()\n", | |
" valid_lengths = valid_lengths[0].asnumpy()\n", | |
" print('Generation Result:')\n", | |
" for i in range(num_print_outcomes):\n", | |
" print(decode(samples[i][:valid_lengths[i]]), scores[i])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 163, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Generation Result:\n", | |
"<BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><BOS><EOS> -5.713097\n", | |
"<BOS><BOS><BOS>eeeeeeeeeeeeeeeeee<EOS> -7.2956805\n", | |
"<BOS><BOS><BOS><BOS>eeeeeeeeeeeeeeeee<EOS> -7.3000755\n", | |
"<BOS><BOS><BOS><BOS><BOS>eeeeeeeeeeeeeeee<EOS> -7.317902\n", | |
"<BOS><BOS>eeeeeeeeeeeeeeeeeee<EOS> -7.362971\n" | |
] | |
} | |
], | |
"source": [ | |
"states = denoiser.decoder.init_state_from_encoder(encoder_outputs, encoder_valid_length=src_valid_length)\n", | |
"inputs = mx.nd.full(shape=(1,), ctx=src_seq.context, dtype=np.float32, val=BOS)\n", | |
"generate_sequences(beam_sampler, inputs, states, 5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment