Last active
March 29, 2021 14:49
-
-
Save anton-l/e927041c45f97e2e4700af66bc2ca7e1 to your computer and use it in GitHub Desktop.
Evaluation script for wav2vec
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gc
import torch
import torchaudio
import urllib.request
import tarfile
import pandas as pd
from tqdm.auto import tqdm
from datasets import load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Download the raw data instead of using HF datasets to save disk space.
data_url = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/hu.tar.gz"
filestream = urllib.request.urlopen(data_url)
# mode="r|gz" reads the gzipped tar as a non-seekable stream straight from
# the HTTP response, so the archive never has to be saved to disk first.
# NOTE(review): extractall() trusts member paths in the remote archive —
# fine for this known S3 bucket, but unsafe for untrusted URLs.
data_file = tarfile.open(fileobj=filestream, mode="r|gz")
data_file.extractall()

# Word Error Rate metric from the `datasets` library.
wer = load_metric("wer")

# Hungarian fine-tuned XLSR-53 checkpoint; the processor bundles the audio
# feature extractor and the CTC tokenizer used for decoding.
processor = Wav2Vec2Processor.from_pretrained("anton-l/wav2vec2-large-xlsr-53-hungarian")
model = Wav2Vec2ForCTC.from_pretrained("anton-l/wav2vec2-large-xlsr-53-hungarian")
model.to("cuda")  # assumes a CUDA device is available — TODO confirm

# Common Voice test-split metadata (TSV) and the directory holding the clips.
cv_test = pd.read_csv("cv-corpus-6.1-2020-12-11/hu/test.tsv", sep='\t')
clips_path = "cv-corpus-6.1-2020-12-11/hu/clips/"
def clean_sentence(sent):
    """Lowercase *sent*, replace non-alphabetic characters with spaces,
    and collapse runs of whitespace into single spaces.

    The result has no leading or trailing whitespace and contains only
    letters and single-space separators.
    """
    lowered = sent.lower()
    # Every non-letter becomes a word separator.
    spaced = [ch if ch.isalpha() else " " for ch in lowered]
    # split() with no argument also drops leading/trailing whitespace.
    return " ".join("".join(spaced).split())
# Run greedy CTC decoding over the whole test split, collecting cleaned
# reference transcripts and model predictions for WER computation.
targets = []
preds = []

# Cache one Resample transform per source sampling rate instead of
# rebuilding it on every iteration (original constructed a new transform
# per clip — wasted work, since the rate rarely changes between clips).
resamplers = {}

for i, row in tqdm(cv_test.iterrows(), total=cv_test.shape[0]):
    # Normalise the reference text without mutating the pandas row view.
    sentence = clean_sentence(row["sentence"])

    # Load the clip and resample to the 16 kHz the model was trained on.
    speech_array, sampling_rate = torchaudio.load(clips_path + row["path"])
    if sampling_rate not in resamplers:
        resamplers[sampling_rate] = torchaudio.transforms.Resample(sampling_rate, 16_000)
    speech = resamplers[sampling_rate](speech_array).squeeze().numpy()

    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"),
                       attention_mask=inputs.attention_mask.to("cuda")).logits

    # Greedy (argmax) decoding; batch_decode collapses repeats and blanks.
    pred_ids = torch.argmax(logits, dim=-1)

    targets.append(sentence)
    preds.append(processor.batch_decode(pred_ids)[0])

# Free up some memory before computing the metric.
del model
del processor
del cv_test
gc.collect()

# BUG FIX: the original used "{:2f}" — a minimum field width of 2 with the
# default 6-digit precision — not the intended two-decimal "{:.2f}".
print("WER: {:.2f}".format(100 * wer.compute(predictions=preds, references=targets)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment