Created
January 8, 2024 12:20
-
-
Save arturbosch/c3fc75f21f36825bca628b0ead5d0b07 to your computer and use it in GitHub Desktop.
ASR evaluation table using jiwer for WER and CER
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from jiwer import cer | |
from jiwer import transforms as tr | |
from jiwer import wer | |
def _build_pipeline(tokenizer):
    """Return a jiwer pipeline ending in ``tokenizer``.

    The shared steps lowercase the text, remove punctuation and strip
    surrounding whitespace; ``tokenizer`` is the final reduction step that
    splits the normalized text into the units being compared.
    """
    return tr.Compose(
        [
            tr.ToLowerCase(),
            tr.RemovePunctuation(),
            tr.Strip(),
            tokenizer,
        ]
    )


# Word-level pipeline, passed to wer() for both reference and prediction.
wer_pipe = _build_pipeline(tr.ReduceToListOfListOfWords())

# Character-level pipeline, passed to cer() for both reference and prediction.
cer_pipe = _build_pipeline(tr.ReduceToListOfListOfChars())
def print_row(sentence_num, model_name, delimiter, reference, prediction):
    """Print one table row with the WER and CER of a single sentence pair.

    Args:
        sentence_num: Label for the first column (the caller passes e.g.
            ``"1.wav"``).
        model_name: Value printed in the model column.
        delimiter: Separator placed between the columns.
        reference: Reference transcript.
        prediction: Predicted transcript to score against ``reference``.

    The columns are: label, reference, prediction, model, WER, CER. Both
    error values are multiplied by 100 and printed with two decimals.
    """
    # The same pipeline is applied to both sides so that reference and
    # prediction are normalized identically before being compared.
    w_error, c_error = (
        wer(reference, prediction, wer_pipe, wer_pipe),
        cer(reference, prediction, cer_pipe, cer_pipe),
    )
    row = f"{sentence_num}{delimiter}{reference}{delimiter}{prediction}{delimiter}{model_name}{delimiter}{w_error * 100:.2f}{delimiter}{c_error * 100:.2f}"
    print(row)
def print_table(ref_generator, pred_generator, model_name, delimiter):
    """Print one scored row per reference/prediction pair.

    Args:
        ref_generator: Iterable of reference lines.
        pred_generator: Iterable of prediction lines, in the same order.
        model_name: Value forwarded to the model column of every row.
        delimiter: Column separator for the header and the rows.

    A header row is printed first when the module-level ``WITH_HEADER`` flag
    is truthy. Rows are labelled ``1.wav``, ``2.wav``, ... in input order.
    The inputs are paired with ``zip``, so iteration stops as soon as the
    shorter of the two is exhausted.
    """
    if WITH_HEADER:
        print(
            f"WAV{delimiter}REFERENCE{delimiter}PREDICTION{delimiter}MODEL{delimiter}WER{delimiter}CER"
        )

    sentence_pairs = zip(ref_generator, pred_generator)
    for position, (ref_line, pred_line) in enumerate(sentence_pairs, start=1):
        label = str(position) + ".wav"
        # Lines come straight from the files, so drop the trailing newline
        # (and any surrounding whitespace) before scoring and printing.
        print_row(label, model_name, delimiter, ref_line.strip(), pred_line.strip())
def load_file(file_name):
    """Lazily yield the lines of a UTF-8 text file.

    Args:
        file_name: Path of the file to read.

    Yields:
        Each line exactly as read, including its trailing newline if present.

    The file is opened on the first iteration, not when this function is
    called, and is closed once the generator is exhausted or closed.
    """
    with open(file_name, "r", encoding="utf-8") as handle:
        yield from handle
# Module-level switch read by print_table(); overwritten from --header when
# this file is run as a script.
WITH_HEADER = False

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Print a per-sentence WER/CER table for reference and predicted transcripts."
    )
    # Both transcript files are mandatory. Without `required=True` a missing
    # option only surfaced later as a TypeError from open(None) inside the
    # lazy file generator, possibly after the header was already printed.
    parser.add_argument(
        "-r", "--reference", required=True, help="Path to the reference txt file."
    )
    parser.add_argument(
        "-p", "--predictions", required=True, help="Path to the predictions txt file."
    )
    parser.add_argument("-m", "--model", help="Model name to use for the table.")
    parser.add_argument(
        "-d", "--delimiter", help="Delimiter to use. Defaults to TAB.", default="\t"
    )
    parser.add_argument(
        "--header",
        default=False,
        action="store_true",
        help="Print a header row before the table.",
    )
    args = parser.parse_args()

    # This runs at module scope, so the assignment rebinds the global that
    # print_table() consults.
    WITH_HEADER = args.header

    print_table(
        load_file(args.reference),
        load_file(args.predictions),
        args.model,
        args.delimiter,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment