
@anton-l
Last active March 29, 2021 14:49
Evaluation script for wav2vec
import gc
import torch
import torchaudio
import urllib.request
import tarfile
import pandas as pd
from tqdm.auto import tqdm
from datasets import load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Download the raw data instead of using HF datasets to save disk space
data_url = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/hu.tar.gz"
filestream = urllib.request.urlopen(data_url)
data_file = tarfile.open(fileobj=filestream, mode="r|gz")
data_file.extractall()
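# Alternative (an assumption, not part of the original script): the same test split
# could also be pulled through HF datasets, at the cost of extra cache space on disk:
#   from datasets import load_dataset
#   cv_test = load_dataset("common_voice", "hu", split="test")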
wer = load_metric("wer")
processor = Wav2Vec2Processor.from_pretrained("anton-l/wav2vec2-large-xlsr-53-hungarian")
model = Wav2Vec2ForCTC.from_pretrained("anton-l/wav2vec2-large-xlsr-53-hungarian")
model.to("cuda")
cv_test = pd.read_csv("cv-corpus-6.1-2020-12-11/hu/test.tsv", sep='\t')
clips_path = "cv-corpus-6.1-2020-12-11/hu/clips/"
def clean_sentence(sent):
    sent = sent.lower()
    # replace non-alphabetic characters with spaces
    sent = "".join(ch if ch.isalpha() else " " for ch in sent)
    # collapse repeated whitespace
    sent = " ".join(sent.split())
    return sent
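# e.g. clean_sentence("Jó reggelt, világ!") -> "jó reggelt világ"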
targets = []
preds = []
for i, row in tqdm(cv_test.iterrows(), total=cv_test.shape[0]):
    row["sentence"] = clean_sentence(row["sentence"])
    # load the clip and resample it to the 16 kHz rate expected by the model
    speech_array, sampling_rate = torchaudio.load(clips_path + row["path"])
    resampler = torchaudio.transforms.Resample(sampling_rate, 16_000)
    row["speech"] = resampler(speech_array).squeeze().numpy()
    inputs = processor(row["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"),
                       attention_mask=inputs.attention_mask.to("cuda")).logits
    # greedy CTC decoding: take the most likely token at each frame
    pred_ids = torch.argmax(logits, dim=-1)
    targets.append(row["sentence"])
    preds.append(processor.batch_decode(pred_ids)[0])
# free up some memory
del model
del processor
del cv_test
gc.collect()
print("WER: {:2f}".format(100 * wer.compute(predictions=preds, references=targets)))