Skip to content

Instantly share code, notes, and snippets.

@arturbosch
Created January 8, 2024 12:20
Show Gist options
  • Save arturbosch/c3fc75f21f36825bca628b0ead5d0b07 to your computer and use it in GitHub Desktop.
Save arturbosch/c3fc75f21f36825bca628b0ead5d0b07 to your computer and use it in GitHub Desktop.
ASR evaluation table using jiwer for WER and CER
import argparse
import sys
from itertools import zip_longest

from jiwer import cer
from jiwer import transforms as tr
from jiwer import wer
def _build_pipeline(final_reduction):
    """Compose the shared text normalisation with a final reduction step.

    The input is lower-cased, has punctuation removed and is stripped
    before ``final_reduction`` is applied as the last transform.
    Fresh transform instances are created on every call, so the returned
    pipelines share no objects.
    """
    return tr.Compose(
        [
            tr.ToLowerCase(),
            tr.RemovePunctuation(),
            tr.Strip(),
            final_reduction,
        ]
    )


# Pipeline handed to ``wer``: ends by reducing to lists of words.
wer_pipe = _build_pipeline(tr.ReduceToListOfListOfWords())

# Pipeline handed to ``cer``: ends by reducing to lists of characters.
cer_pipe = _build_pipeline(tr.ReduceToListOfListOfChars())
def print_row(sentence_num, model_name, delimiter, reference, prediction):
    """Score one reference/prediction pair and print it as a table row.

    The row holds, in order: ``sentence_num``, ``reference``,
    ``prediction``, ``model_name``, the WER and the CER. Both error values
    are multiplied by 100 and printed with two decimals; columns are
    separated by ``delimiter``.
    """
    # The same pipeline is applied to both sides so they are normalised alike.
    word_error = wer(reference, prediction, wer_pipe, wer_pipe)
    char_error = cer(reference, prediction, cer_pipe, cer_pipe)

    columns = (
        f"{sentence_num}",
        f"{reference}",
        f"{prediction}",
        f"{model_name}",
        f"{word_error * 100:.2f}",
        f"{char_error * 100:.2f}",
    )
    print(f"{delimiter}".join(columns))
def print_table(ref_generator, pred_generator, model_name, delimiter):
    """Print one scored row per reference/prediction pair.

    Args:
        ref_generator: Iterable yielding reference lines.
        pred_generator: Iterable yielding prediction lines, in the same
            order as the references.
        model_name: Value printed in the MODEL column of every row.
        delimiter: Column separator used for the header and the rows.

    A header line is printed first when the module-level ``WITH_HEADER``
    flag is true. Rows are labelled ``1.wav``, ``2.wav``, ... by position.

    If one input runs out before the other, the rows scored so far are
    kept, a warning is written to stderr and processing stops. Standard
    output is therefore the same as with a plain ``zip``, but the
    truncation is no longer silent.
    """
    if WITH_HEADER:
        print(
            f"WAV{delimiter}REFERENCE{delimiter}PREDICTION{delimiter}MODEL{delimiter}WER{delimiter}CER"
        )

    # Unique sentinel so an exhausted input can be told apart from any line.
    missing = object()
    pairs = zip_longest(ref_generator, pred_generator, fillvalue=missing)

    for num, (reference, prediction) in enumerate(pairs, start=1):
        if reference is missing or prediction is missing:
            shorter = "reference" if reference is missing else "prediction"
            print(
                f"warning: {shorter} input ended before line {num}; "
                "the remaining lines of the other input were not scored",
                file=sys.stderr,
            )
            break
        print_row(
            f"{num}.wav",
            model_name,
            delimiter,
            reference.strip(),
            prediction.strip(),
        )
def load_file(file_name):
    """Lazily yield the lines of the UTF-8 text file at ``file_name``.

    Lines are yielded exactly as read, including any trailing newline.
    The file is opened on first iteration and closed when the generator
    is exhausted or closed.
    """
    with open(file_name, mode="r", encoding="utf-8") as handle:
        yield from handle
# Module-level switch read by print_table; set from the --header flag below.
WITH_HEADER = False

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both input paths are mandatory: without them load_file would be handed
    # None and fail with an opaque TypeError instead of a usage message.
    parser.add_argument(
        "-r",
        "--reference",
        required=True,
        help="Path to the reference txt file.",
    )
    parser.add_argument(
        "-p",
        "--predictions",
        required=True,
        help="Path to the predictions txt file.",
    )
    parser.add_argument("-m", "--model", help="Model name to use for the table.")
    parser.add_argument(
        "-d", "--delimiter", help="Delimiter to use. Defaults to TAB.", default="\t"
    )
    parser.add_argument(
        "--header",
        default=False,
        action="store_true",
        help="Print a header line before the rows.",
    )
    args = parser.parse_args()

    # store_true always yields a bool, so the flag can be assigned directly.
    WITH_HEADER = args.header

    delim = args.delimiter
    model = args.model
    ref_gen = load_file(args.reference)
    pred_gen = load_file(args.predictions)
    print_table(ref_gen, pred_gen, model, delim)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment