Skip to content

Instantly share code, notes, and snippets.

@csukuangfj
Last active August 22, 2022 16:48
Show Gist options
  • Save csukuangfj/c68697cd144c8f063cc7ec4fd885fd6f to your computer and use it in GitHub Desktop.
Save csukuangfj/c68697cd144c8f063cc7ec4fd885fd6f to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
from snowfall.training.ctc_graph import build_ctc_topo2
from speechbrain.pretrained import EncoderDecoderASR
import k2
import torch
def load_model():
    """Load the pretrained SpeechBrain transformer ASR model for LibriSpeech.

    Returns:
      The object returned by ``EncoderDecoderASR.from_hparams`` for the
      ``speechbrain/asr-transformer-transformerlm-librispeech`` source,
      with ``pretrained_models/asr-transformer-transformerlm-librispeech``
      passed as ``savedir``.
    """
    hub_source = "speechbrain/asr-transformer-transformerlm-librispeech"
    local_dir = "pretrained_models/asr-transformer-transformerlm-librispeech"
    # To run on a GPU, additionally pass: run_opts={'device': 'cuda:0'}
    asr = EncoderDecoderASR.from_hparams(source=hub_source, savedir=local_dir)
    return asr
@torch.no_grad()
def main(sound_file='./example.wav'):
    """Decode one audio file with a SpeechBrain encoder and k2 CTC search.

    The pretrained model's encoder and CTC output layer produce frame-level
    log-probabilities; these are intersected with a CTC topology built over
    the whole tokenizer vocabulary, the best path is extracted, and its
    output labels are turned back into text with the model's tokenizer.
    The decoded hypothesis is printed to stdout.

    Args:
      sound_file:
        Path of the audio file to decode. The default is the example file,
        see https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech/blob/main/example.wav
    """
    model = load_model()
    device = model.device

    wav = model.load_audio(sound_file)
    # wav is a 1-d tensor, e.g., [52173]

    wavs = wav.unsqueeze(0).float().to(device)
    # wavs is a 2-d tensor, e.g., [1, 52173]

    # NOTE(review): 1.0 is presumably SpeechBrain's *relative* length, i.e.
    # the single utterance in the batch is used in full -- confirm.
    wav_lens = torch.tensor([1.0]).to(device)

    encoder_out = model.modules.encoder(wavs, wav_lens)
    # encoder_out.shape [N, T, C], e.g., [1, 82, 768]

    logits = model.hparams.ctc_lin(encoder_out)
    # logits.shape [N, T, C], e.g., [1, 82, 5000]

    log_probs = model.hparams.log_softmax(logits)
    # log_probs.shape [N, T, C], e.g., [1, 82, 5000]

    # The decoding graph is a plain CTC topology over every token ID.
    vocab_size = model.tokenizer.vocab_size()
    ctc_topo = build_ctc_topo2(list(range(vocab_size)))
    ctc_topo = k2.create_fsa_vec([ctc_topo]).to(device)

    # One row per utterance: (sequence index, start frame, number of frames).
    # Here: utterance 0, starting at frame 0, covering all T frames.
    supervision_segments = torch.tensor(
        [[0, 0, log_probs.size(1)]],
        dtype=torch.int32,
    )
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

    lattices = k2.intersect_dense_pruned(
        ctc_topo,
        dense_fsa_vec,
        search_beam=20.0,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
    )
    best_path = k2.shortest_path(lattices, use_double_scores=True)

    aux_labels = best_path[0].aux_labels
    # Keep only real token IDs: zeros are dropped (as the original
    # ``nonzero()`` filter did) and so is the trailing -1 of the final arc.
    # A boolean mask always yields a 1-d tensor, so this also works when the
    # best path contains no token at all; the former
    # ``nonzero().squeeze()`` + ``[:-1]`` combination collapsed to a 0-d
    # tensor in that case and failed on the slice.
    aux_labels = aux_labels[aux_labels > 0]

    hyp = model.tokenizer.decode(aux_labels.tolist())
    print(hyp)
# Run the decoding demo only when this file is executed as a script,
# not when it is imported as a module.
if '__main__' == __name__:
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment