Implementation of the Minimum Word Error Rate (MWER) training loss based on a negative sampling strategy, from <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>.
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2022 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""Minimum Word Error Rate Training loss
<Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
https://arxiv.org/abs/2206.08317
<Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models>
https://arxiv.org/abs/1712.01818
"""
from typing import Optional, Tuple
import torch
MIN_LOG_VAL = torch.tensor(-float('inf'))
IGNORE_ID = -1
def make_pad_mask(lengths: torch.Tensor, max_len: int = None) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = int(lengths.size(0))
    if max_len is None:
        max_len = int(lengths.max().item())
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
def create_sampling_mask(log_softmax, n):
    """Generate the sampling mask.

    Ref: <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    https://arxiv.org/abs/2206.08317

    Args:
        log_softmax: log softmax inputs, float32 (batch, maxlen_out, vocab_size)
        n: number of candidate paths, int32
    Return:
        sampling_mask: the sampling mask (nbest, batch, maxlen_out, vocab_size)
    """
    b, s, v = log_softmax.size()

    # Generate a random 0/1 mask for each of the n candidate paths
    nbest_random_mask = torch.randint(
        0, 2, (n, b, s, v), device=log_softmax.device
    )

    # Greedy search decoding for the best path
    top1_score_indices = log_softmax.argmax(dim=-1)

    # Generate a one-hot mask of the top-1 token at every position
    top1_score_indices_mask = torch.zeros(
        (b, s, v), dtype=torch.int, device=log_softmax.device)
    top1_score_indices_mask.scatter_(-1, top1_score_indices.unsqueeze(-1), 1)

    # Generate the sampling mask by applying the random mask to the top-1 tokens
    sampling_mask = nbest_random_mask * top1_score_indices_mask.unsqueeze(0)
    return sampling_mask
def negative_sampling_decoder(
    logit: torch.Tensor,
    nbest: int = 4,
    masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate multiple candidate paths with the negative sampling strategy.

    Ref: <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    https://arxiv.org/abs/2206.08317

    Args:
        logit: logit inputs, float32 (batch, maxlen_out, vocab_size)
        nbest: number of candidate paths, int32
        masks: padding mask of the logits, (batch, maxlen_out)
    Return:
        nbest_log_distribution: the log probability of each candidate path (nbest, batch)
        nbest_pred: the N-best candidate paths (nbest, batch, maxlen_out)
    """
    # Use log-softmax as the probability distribution
    log_softmax = torch.nn.functional.log_softmax(logit, dim=-1)

    # Generate the sampling mask
    with torch.no_grad():
        sampling_mask = create_sampling_mask(log_softmax, nbest)

    # Randomly mask the top-1 score with -float('inf')
    # (nbest, batch, maxlen_out, vocab_size)
    nbest_log_softmax = torch.where(
        sampling_mask != 0, MIN_LOG_VAL.type_as(log_softmax), log_softmax)

    # Greedy search decoding over the sampled log softmax
    nbest_logsoftmax, nbest_pred = nbest_log_softmax.topk(1)
    nbest_pred = nbest_pred.squeeze(-1)
    nbest_logsoftmax = nbest_logsoftmax.squeeze(-1)

    # Construct the N-best log probabilities
    # FIXME (huanglk): Ignore irrelevant probabilities
    # (n, b, s) -> (n, b): log(p1*p2*...*pn) = log(p1)+log(p2)+...+log(pn)
    nbest_log_distribution = torch.sum(
        nbest_logsoftmax.masked_fill(masks, 0), -1)
    return nbest_log_distribution, nbest_pred
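

# A minimal shape sanity check for the decoder above (illustrative only; the
# tensors here are random and are not part of the original gist):
#
#   logit = torch.randn(2, 5, 10)                 # (batch=2, maxlen_out=5, vocab=10)
#   masks = make_pad_mask(torch.tensor([5, 3]))   # (2, 5), True at padded positions
#   dist, pred = negative_sampling_decoder(logit, nbest=4, masks=masks)
#   # dist: (4, 2) log probability per sampled path
#   # pred: (4, 2, 5) token ids of each candidate path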
def compute_mwer_loss(
    nbest_log_distribution: torch.Tensor,
    nbest_pred: torch.Tensor,
    tgt: torch.Tensor,
    tgt_lens: torch.Tensor,
):
    """Compute the Minimum Word Error Rate training loss.

    Ref: <Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models>
    https://arxiv.org/abs/1712.01818

    Args:
        nbest_log_distribution: the log probability of each candidate path (nbest, batch)
        nbest_pred: the N-best candidate paths (nbest, batch, maxlen_out)
        tgt: padded target token ids, int32 (batch, maxlen_out)
        tgt_lens: target token lengths of this batch (batch,)
    Return:
        loss: normalized MWER loss (batch,)
    """
    n, b, s = nbest_pred.size()

    # Mask out padded positions so they do not count as errors
    # (b,) -> (b, s); does not include <eos/sos>
    masks = make_pad_mask(tgt_lens, max_len=tgt.size()[1])
    tgt = tgt.masked_fill(masks, IGNORE_ID)
    # (n, b, s)
    nbest_pred = nbest_pred.masked_fill(masks, IGNORE_ID)

    # Count the number of word errors per hypothesis
    # (b, s) -> (n, b, s)
    tgt = tgt.unsqueeze(0).repeat(n, 1, 1)
    # convert to float for normalization
    # (n, b, s) -> (n, b)
    nbest_word_err_num = torch.sum((tgt != nbest_pred), -1).float()

    # Total probability mass of the N-best list
    # (n, b) -> (b,): log(p1+p2+...+pn) = log(exp(log_p1)+exp(log_p2)+...+exp(log_pn))
    sum_nbest_log_distribution = torch.logsumexp(nbest_log_distribution, 0)

    # Re-normalize over just the N-best hypotheses
    # (n, b) - (b,) -> (n, b): exp(log_p)/exp(log_p_sum) = exp(log_p-log_p_sum)
    normal_nbest_distribution = torch.exp(
        nbest_log_distribution - sum_nbest_log_distribution)

    # Average number of word errors over the N-best hypotheses
    # (n, b) -> (b,)
    mean_word_err_num = torch.mean(nbest_word_err_num, 0)

    # Relative word error counts over just the N-best hypotheses
    # (n, b) - (b,) -> (n, b)
    normal_nbest_word_err_num = nbest_word_err_num - mean_word_err_num

    # Expected number of word errors over the N-best hypotheses
    # (n, b) -> (b,)
    mwer_loss = torch.sum(normal_nbest_distribution *
                          normal_nbest_word_err_num, 0)
    return mwer_loss
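

# For reference, compute_mwer_loss follows the N-best approximation of the MWER
# objective from the paper cited above (notation below is mine, not from the
# original gist):
#
#   L_mwer(x, y*) = sum_{y_i in NBest(x)} P_hat(y_i | x) * (W(y_i, y*) - W_mean)
#
# where P_hat is the model distribution re-normalized over the N-best list,
# W(y_i, y*) is the number of word errors of hypothesis y_i against the
# reference y*, and W_mean is the average error count over the N-best list,
# subtracted as a baseline.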
class Seq2seqMwerLoss(torch.nn.Module):
    """Minimum Word Error Rate training loss based on the negative sampling strategy.

    <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    https://arxiv.org/abs/2206.08317
    <Minimum Word Error Rate Training for Attention-based Sequence-to-Sequence Models>
    https://arxiv.org/abs/1712.01818

    Args:
        candidate_paths_num (int): The number of candidate paths.
        reduction (str): "mean", "sum" or anything else for no batch reduction.
    """

    def __init__(
        self,
        candidate_paths_num: int = 4,
        reduction: str = "mean",
    ):
        super().__init__()
        self.candidate_paths_num = candidate_paths_num
        self.reduction = reduction

    def forward(self, logit: torch.Tensor, tgt: torch.Tensor, tgt_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logit: logit (batch, maxlen_out, vocab_size)
            tgt: padded target token ids, int64 (batch, maxlen_out)
            tgt_lens: target lengths of this batch (batch,)
        Return:
            loss: normalized MWER loss
        """
        assert tgt_lens.size()[0] == tgt.size()[0] == logit.size()[0]
        assert logit.size()[1] == tgt.size()[1]

        # Randomly mask the top-1 score to generate multiple candidate paths
        masks = make_pad_mask(tgt_lens, max_len=tgt.size()[1])
        nbest_log_distribution, nbest_pred = negative_sampling_decoder(
            logit, self.candidate_paths_num, masks)

        # Compute the MWER loss
        mwer_loss = compute_mwer_loss(
            nbest_log_distribution, nbest_pred, tgt, tgt_lens)

        if self.reduction == "sum":
            return torch.sum(mwer_loss)
        elif self.reduction == "mean":
            return torch.mean(mwer_loss)
        else:
            return mwer_loss
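

# A minimal usage sketch (not part of the original gist): training code would
# pass real decoder logits and padded targets; random tensors are used here
# only to show the expected shapes and check that gradients flow.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, maxlen_out, vocab_size = 2, 6, 10

    logit = torch.randn(batch, maxlen_out, vocab_size, requires_grad=True)
    tgt = torch.randint(0, vocab_size, (batch, maxlen_out))
    tgt_lens = torch.tensor([6, 4])

    criterion = Seq2seqMwerLoss(candidate_paths_num=4, reduction="mean")
    loss = criterion(logit, tgt, tgt_lens)
    loss.backward()
    print("MWER loss:", loss.item())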