Internal Language Model Subtraction
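Fast beam search for a transducer (RNN-T) model using k2's RNN-T decoding streams, extended with internal language model (ILM) subtraction: the ILM score is estimated by running the joiner with the acoustic embedding zeroed out, so only the prediction network contributes, and each frame's scores become log p(k | x, y) - ILM_lambda * log p_ILM(k | y). Setting ILM_lambda = 0.0 recovers plain fast beam search.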
import k2
import torch

# `Transducer` is the RNN-T model class this function decodes with; the
# import path below is an assumption (icefall keeps it in `model.py`).
from model import Transducer


def fast_beam_search(
    model: Transducer,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    temperature: float = 1.0,
    ILM_lambda: float = 0.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1. | |
Args: | |
model: | |
An instance of `Transducer`. | |
decoding_graph: | |
Decoding graph used for decoding, may be a TrivialGraph or a LG. | |
encoder_out: | |
A tensor of shape (N, T, C) from the encoder. | |
encoder_out_lens: | |
A tensor of shape (N,) containing the number of frames in `encoder_out` | |
before padding. | |
beam: | |
Beam value, similar to the beam used in Kaldi.. | |
max_states: | |
Max states per stream per frame. | |
max_contexts: | |
Max contexts pre stream per frame. | |
temperature: | |
Softmax temperature. | |
Returns: | |
Return an FsaVec with axes [utt][state][arc] containing the decoded | |
lattice. Note: When the input graph is a TrivialGraph, the returned | |
lattice is actually an acceptor. | |
""" | |
    assert encoder_out.ndim == 3

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    B, T, C = encoder_out.shape

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    individual_streams = []
    for i in range(B):
        individual_streams.append(k2.RnntDecodingStream(decoding_graph))
    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(T):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)
        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # current_encoder_out is of shape
        # (shape.NumElements(), 1, joiner_dim)
        # fmt: off
        current_encoder_out = torch.index_select(
            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
        )
        # fmt: on
        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)
        log_probs = (logits / temperature).log_softmax(dim=-1)
        # Internal LM subtraction: estimate the internal LM by zeroing the
        # acoustic input to the joiner, so only the decoder contributes.
        current_ILM_encoder_out = torch.zeros_like(current_encoder_out)
        ILM_logits = model.joiner(
            current_ILM_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        ILM_logits = ILM_logits.squeeze(1).squeeze(1)
        ILM_log_probs = ILM_logits.log_softmax(dim=-1)
        # log p'(k) = log p(k | x, y) - ILM_lambda * log p_ILM(k | y)
        log_probs -= ILM_log_probs * ILM_lambda
        decoding_streams.advance(log_probs)
    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
    return lattice
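A minimal sketch of a call site, assuming a trained `model` and a prior encoder forward pass producing `encoder_out` and `encoder_out_lens`; the trivial-graph construction follows k2's `trivial_graph`, and the beam settings are illustrative values, not tuned:

# A minimal sketch: `model`, `encoder_out`, and `encoder_out_lens` are
# assumed to come from a trained Transducer and its encoder forward pass.
device = encoder_out.device

# Trivial decoding graph over the vocabulary (the returned lattice is then
# an acceptor); `k2.trivial_graph` takes the largest token id.
decoding_graph = k2.trivial_graph(model.decoder.vocab_size - 1, device=device)

lattice = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4.0,
    max_states=32,
    max_contexts=8,
    ILM_lambda=0.2,  # subtract 0.2 x internal LM score; 0.0 disables it
)

# One-best decoding from the returned lattice, as in k2 recipes.
best_path = k2.shortest_path(lattice, use_double_scores=True)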