Skip to content

Instantly share code, notes, and snippets.

@pzelasko
Created April 5, 2024 23:44
Show Gist options
  • Save pzelasko/168dbc53bf6b80f97bc20abb828b5e51 to your computer and use it in GitHub Desktop.
Save pzelasko/168dbc53bf6b80f97bc20abb828b5e51 to your computer and use it in GitHub Desktop.
Analyze where the most errors are found in ASR transcripts using a NeMo manifest with `text` and `pred_text` keys.
"""
Make sure to first run:
$ pip install click pandas lhotse kaldialign
"""
import click
import pandas as pd
from lhotse.serialization import load_jsonl
from kaldialign import align, bootstrap_wer_ci
# Gap marker handed to kaldialign's align(): an aligned pair whose reference
# side equals EPS is counted as an insertion, and one whose hypothesis side
# equals EPS as a deletion.
# NOTE(review): a literal '*' token in the transcripts would presumably be
# indistinguishable from a gap -- confirm the data never contains it.
EPS = '*'
def _safe_ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


@click.command()
@click.argument("manifest", type=click.Path(exists=True))
@click.option("-c", "--cer", is_flag=True, help="If true, we won't split the text by whitespace and treat each character as a separate symbol..")
@click.option("-n", "--num-splits", default=3, type=int, help="Number of position splits for detailed analysis. You may increase this number if your data consists mostly of long utterances.")
def analyse_results(manifest: str, cer: bool, num_splits: int) -> None:
    """Report where ASR errors occur in a NeMo manifest.

    MANIFEST is a JSON-lines file whose entries have `text` (reference) and
    `pred_text` (hypothesis) keys.
    \f

    Everything after the form feed above is hidden from ``--help`` by click.

    Args:
        manifest: Path to the JSON-lines manifest.
        cer: When true, every character is one symbol; otherwise the texts
            are split on whitespace into words.
        num_splits: Number of equal-width relative-position buckets used for
            the error location breakdown.

    Raises:
        click.ClickException: If the manifest contains no entries.
    """
    items = [(item['text'], item['pred_text']) for item in load_jsonl(manifest)]
    if not items:
        # zip(*items) below cannot be unpacked for an empty manifest.
        raise click.ClickException(f"No utterances found in manifest: {manifest}")

    # Tokenise once so that the alignment statistics and the bootstrap
    # estimate are computed over the same symbols (characters or words).
    if cer:
        pairs = [(list(ref), list(hyp)) for ref, hyp in items]
    else:
        pairs = [(ref.split(), hyp.split()) for ref, hyp in items]
    metric = "CER" if cer else "WER"

    stats = []
    tot_sym = 0
    for uttidx, (ref, hyp) in enumerate(pairs):
        tot_sym += len(ref)
        ali = align(ref, hyp, EPS)
        tot = len(ali)
        for pos, (r, h) in enumerate(ali):
            if r == h:
                continue
            if r == EPS:
                kind = "ins"
            elif h == EPS:
                kind = "del"
            else:
                kind = "sub"
            stats.append({"ref": r, "hyp": h, "pos": pos, "tot": tot, "relpos": pos / tot, "uttidx": uttidx, "kind": kind})

    refs, hyps = zip(*pairs)
    ans = bootstrap_wer_ci(list(refs), list(hyps))
    click.echo(f"Boostrap {metric}={ans['wer']:.2%}+/-{ans['ci95']:.2%} [{metric}@p0.025={ans['ci95min']:.2%} - {metric}@p0.975={ans['ci95max']:.2%}]")

    # Naming the columns keeps them present even when no errors were found.
    df = pd.DataFrame(stats, columns=["ref", "hyp", "pos", "tot", "relpos", "uttidx", "kind"])
    tot_err = len(df)
    kinds = ("del", "ins", "sub")
    tot_kind = {kind: int((df["kind"] == kind).sum()) for kind in kinds}
    click.echo("\t* " + "".join(f"{kind}={_safe_ratio(val, tot_sym):.2%} " for kind, val in tot_kind.items()))

    click.echo("Error location analysis [relative to utterance length]:")
    for idx in range(num_splits):
        b = idx / num_splits
        # relpos is pos / tot with pos < tot; the last bucket's upper bound
        # is nudged above 1 so the half-open interval covers the final symbol.
        e = 1.0001 if idx + 1 == num_splits else (idx + 1) / num_splits
        subdf = df[(b <= df["relpos"]) & (df["relpos"] < e)]
        num_err = len(subdf)
        click.echo(f"[{b:.2f} - {e:.2f}]")
        click.echo(f"\t* {_safe_ratio(num_err, tot_err):.1%} of all errors.")
        for kind in kinds:
            num_kind = int((subdf["kind"] == kind).sum())
            # Empty buckets and error kinds that never occur report 0.0%
            # instead of raising ZeroDivisionError.
            click.echo(f"\t* {_safe_ratio(num_kind, num_err):.1%} are of type '{kind}' (this constitutes {_safe_ratio(num_kind, tot_kind[kind]):.1%} of all '{kind}').")
if __name__ == "__main__":
    # Called without arguments: the click command reads them from the
    # command line.
    analyse_results()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment