"""
Multi-language ASR using Huggingface transformer models.
Python dependencies:
pip install transformers==4.5.0 librosa soundfile torch
"""
from typing import NamedTuple
from functools import lru_cache
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
from librosa.core import resample
import soundfile as sf
import torch

class Model(NamedTuple):
    lang: str
    name: str

MODELS = {
    # find and add more language paths here:
    # https://huggingface.co/models?filter=wav2vec2,pytorch&pipeline_tag=automatic-speech-recognition
    Model(lang="eng", name="facebook/wav2vec2-large-960h-lv60-self"),
    Model(lang="fra", name="facebook/wav2vec2-large-xlsr-53-french"),
    Model(lang="spa", name="facebook/wav2vec2-large-xlsr-53-spanish"),
    Model(lang="nld", name="facebook/wav2vec2-large-xlsr-53-dutch"),
    Model(lang="deu", name="facebook/wav2vec2-large-xlsr-53-german"),
}

# wav2vec2 models expect 16 kHz mono audio
W2V_SR = 16_000


@lru_cache()
def _get_models(lang):
    # find the model name registered for this language
    matches = [m.name for m in MODELS if m.lang == lang]
    if not matches:
        raise ValueError(f"Could not find a model for language {lang!r}")
    name = matches[0]
    # load model and tokenizer (weights are downloaded on first use)
    tokenizer = Wav2Vec2Tokenizer.from_pretrained(name)
    model = Wav2Vec2ForCTC.from_pretrained(name)
    return tokenizer, model


def _loadwav(wavfn):
    # load wav as a float numpy array
    wav, sr = sf.read(wavfn)
    # ensure mono by keeping the first channel
    if wav.ndim > 1:
        wav = wav[:, 0]
    # resample to the 16 kHz rate the models expect
    if sr != W2V_SR:
        wav = resample(wav, orig_sr=sr, target_sr=W2V_SR)
    return wav


def rec_files(wavfns, lang):
    # load and normalize all wav files
    wavs = list(map(_loadwav, wavfns))
    # load the tokenizer and model for this language (cached)
    tokenizer, model = _get_models(lang)
    # tokenize into a single padded batch
    input_values = tokenizer(wavs, return_tensors="pt", padding="longest").input_values
    # retrieve logits; no gradients needed for inference
    with torch.no_grad():
        logits = model(input_values).logits
    # take the argmax over the vocabulary and decode to text
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)
    return transcription
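

# A minimal usage sketch: the wav paths below are hypothetical placeholders.
# Pass a list of wav file paths and a language code defined in MODELS above;
# one transcription string is returned per input file.
if __name__ == "__main__":
    transcripts = rec_files(["example1.wav", "example2.wav"], lang="eng")
    for txt in transcripts:
        print(txt)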