Skip to content

Instantly share code, notes, and snippets.

@sunprinceS
Created February 23, 2021 02:11
Show Gist options
  • Save sunprinceS/d5288cb14b4f0ab4404254b84148277c to your computer and use it in GitHub Desktop.
Save sunprinceS/d5288cb14b4f0ab4404254b84148277c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import re
import yaml
import json
from pathlib import Path
from comet_ml import ExistingExperiment
from src.marcos import *
from src.utils import run_cmd
import src.monitor.logger as logger
# Command-line interface: these options select which decoded run gets scored.
parser = argparse.ArgumentParser(
    description='Evaluate the decoded hypothesis via sclite',
)
parser.add_argument('--config', required=True, type=str,
                    help='Path to config file')
parser.add_argument('--lang', required=True, choices=AVAIL_LANGS)
parser.add_argument('--algo', required=True,
                    choices=['no', 'reptile', 'multi', 'fomaml'])
parser.add_argument('--model_name', required=True,
                    choices=['blstm', 'las'])
parser.add_argument('--eval_suffix', required=True, type=str, default=None,
                    help='Evaluation suffix')
parser.add_argument('--decode_mode', type=str, default='greedy',
                    choices=['greedy', 'beam', 'lm_beam'])
# A bare `--pretrain_suffix` (flag without a value) gives None, exactly like
# leaving the flag out.
parser.add_argument('--pretrain_suffix', type=str, nargs='?', const=None,
                    help='Pretrain model suffix')
# A bare `--runs` gives 0; leaving the flag out entirely gives None.
parser.add_argument('--runs', type=int, nargs='?', const=0,
                    help='run suffix')
parser.add_argument('--decode_suffix', type=str, default=None)  # will remove later
paras = parser.parse_args()

# Name of the decode sub-directory: "<mode>_decode" or "<mode>_decode_<suffix>".
if paras.decode_suffix:
    decode_suffix = f"{paras.decode_mode}_decode_{paras.decode_suffix}"
else:
    decode_suffix = f"{paras.decode_mode}_decode"
# Load some prerequisites
# id2lang is indexed with the --lang value; the result is used further down
# as a directory name.
with open(Path('data', 'lang.json'), 'r') as fin:
    id2lang = json.load(fin)

# Open the config through a context manager so the handle is closed as soon
# as parsing finishes instead of being left to the garbage collector.
with open(paras.config, 'r') as fin:
    config = yaml.safe_load(fin)

# When --algo is 'no' the pretrain suffix falls back to the evaluation suffix;
# every other algorithm has to pass --pretrain_suffix explicitly.
if paras.pretrain_suffix is None:
    if paras.algo == 'no':
        paras.pretrain_suffix = paras.eval_suffix
    else:
        # Raised explicitly rather than via `assert False`, which would be
        # stripped (and the check silently skipped) under `python -O`.
        raise AssertionError(
            "pretrain_suffix should be specified if using pretrained model")
# Write to corresponding directory and comet exp
cur_path = Path.cwd()
# Layout: <cwd>/<LOG_DIR>/evaluation/<setting>/<algo>/<pretrain_suffix>/
#         <eval_suffix>/<language>/<runs>
log_dir_parts = (
    LOG_DIR,
    'evaluation',
    config['solver']['setting'],
    paras.algo,
    paras.pretrain_suffix,
    paras.eval_suffix,
    id2lang[paras.lang],
    str(paras.runs),
)
log_dir = Path(cur_path, *log_dir_parts)
decode_dir = log_dir / decode_suffix

# The key stored in <log_dir>/exp_key names the Comet experiment to re-attach to.
with open(log_dir / 'exp_key', 'r') as key_file:
    exp_key = key_file.read().strip()

comet_exp = ExistingExperiment(
    previous_experiment=exp_key,
    project_name=COMET_PROJECT_NAME,
    workspace=COMET_WORKSPACE,
    auto_output_logging=None,
    auto_metric_logging=None,
    display_summary=False,
)
# Unit tables for the selected language. Id 0 is pre-seeded with the empty
# string before the unit file is read.
id2ch = {0: ''}
ch2id = {'': 0}
lang_data_dir = Path(config['solver']['data_root'], id2lang[paras.lang])

# train_units.txt: one "<unit> <id>" pair per line, separated by one space.
with open(lang_data_dir / 'train_units.txt') as units_file:
    for entry in units_file:
        unit, unit_id = entry.split(' ')
        unit_id = int(unit_id)  # int() tolerates the trailing newline
        id2ch[unit_id] = unit
        ch2id[unit] = unit_id

# non_lang_syms.txt: one symbol per line; their ids are removed from every
# transcript by cer_filt / wer_filt.
with open(lang_data_dir / 'non_lang_syms.txt') as syms_file:
    non_syms = [entry.rstrip() for entry in syms_file]
non_syms_ids = [ch2id[sym] for sym in non_syms]

space_idx = ch2id['<space>']
### CER calculation
def cer_filt(s):
    """Turn a string of unit ids into a space-separated string of units.

    Args:
        s: Integer unit ids separated by whitespace. An empty or
            whitespace-only string is accepted and yields ''.

    Returns:
        The units looked up in the module-level ``id2ch`` table, joined by
        single spaces. Ids that are >= ``len(id2ch)`` (per the original note,
        this is meant to drop the <sos>/<eos> indices) and ids listed in the
        module-level ``non_syms_ids`` are left out.

    Raises:
        ValueError: If a token cannot be parsed as an integer.
    """
    # split() without a separator ignores leading, trailing and repeated
    # whitespace, so an empty transcript gives [] rather than [''], which
    # int() would reject.
    ids = [int(tok) for tok in s.split()]
    vocab_size = len(id2ch)
    kept = [i for i in ids if i < vocab_size and i not in non_syms_ids]
    return ' '.join(id2ch[i] for i in kept)
### cal CER
logger.notice("CER calculating...")
# Each best-hyp line is "<ref ids>\t<hyp ids>". A line with a single field is
# written as a reference whose hypothesis is empty.
with open(Path(decode_dir, 'best-hyp'), 'r') as pairs_in, \
        open(Path(decode_dir, 'hyp.trn'), 'w') as hyp_out, \
        open(Path(decode_dir, 'ref.trn'), 'w') as ref_out:
    for idx, raw in enumerate(pairs_in):
        fields = raw.rstrip().split('\t')
        if len(fields) > 2:
            raise ValueError("at most only ref and hyp")
        # Tag appended to both the reference and the hypothesis line so the
        # two files line up utterance by utterance.
        utt_tag = f"({idx//1000}k_{idx})"
        ref_out.write(f"{cer_filt(fields[0])} {utt_tag}\n")
        if len(fields) == 2:
            hyp_out.write(f"{cer_filt(fields[1])} {utt_tag}\n")
        else:
            hyp_out.write(f"{utt_tag}\n")

# Score with sclite and keep its full report next to the transcripts.
cer_result_path = Path(decode_dir, 'result.txt')
res = run_cmd([
    'sclite',
    '-r', Path(decode_dir, 'ref.trn'), 'trn',
    '-h', Path(decode_dir, 'hyp.trn'), 'trn',
    '-i', 'rm',
    '-o', 'all', 'stdout',
])
logger.log(f"Write result to {cer_result_path}", prefix='info')
with open(cer_result_path, 'w') as fout:
    print(res, file=fout)

# Echo the first two report lines that mention "Avg" or "SPKR".
er_rate = run_cmd(['grep', '-e', 'Avg', '-e', 'SPKR', '-m', '2', cer_result_path])
print(er_rate)

# Collapse runs of spaces in the first "Sum/Avg" row and take field 10.
# NOTE(review): presumably the error-rate column of the sclite summary table;
# confirm against an actual report.
cer_summary_row = run_cmd(['grep', '-e', 'Sum/Avg', '-m', '1', cer_result_path])
cer = re.sub(' +', ' ', cer_summary_row).split(' ')[10]

comet_exp.log_other(f"cer({paras.decode_mode})", cer)
logger.log(f"CER: {cer}", prefix='test')
with open(Path(decode_dir, 'cer'), 'w') as fout:
    fout.write(f"{cer}\n")
def wer_filt(s):
    """Turn a string of unit ids into text by concatenating the units.

    Units are joined with no separator, so any spacing in the result comes
    from what ``id2ch`` maps to. The script rebinds ``id2ch[space_idx]`` to
    ' ' before this function is first called.

    Args:
        s: Integer unit ids separated by whitespace. An empty or
            whitespace-only string is accepted and yields ''.

    Returns:
        The concatenation of the units looked up in the module-level
        ``id2ch`` table. Ids that are >= ``len(id2ch)`` (per the original
        note, this is meant to drop the <sos>/<eos> indices) and ids listed
        in the module-level ``non_syms_ids`` are left out.

    Raises:
        ValueError: If a token cannot be parsed as an integer.
    """
    # split() without a separator ignores leading, trailing and repeated
    # whitespace, so an empty transcript gives [] rather than [''], which
    # int() would reject.
    ids = [int(tok) for tok in s.split()]
    vocab_size = len(id2ch)
    kept = [i for i in ids if i < vocab_size and i not in non_syms_ids]
    return ''.join(id2ch[i] for i in kept)
logger.notice("WER calculating...")
### cal WER
# Map the <space> unit to a literal blank: wer_filt concatenates units with
# no separator, so this is what makes its output space-delimited.
id2ch[space_idx] = ' '

# Same best-hyp layout as for CER: "<ref ids>\t<hyp ids>", or a lone
# reference whose hypothesis is written as empty.
with open(Path(decode_dir, 'best-hyp'), 'r') as pairs_in, \
        open(Path(decode_dir, 'hyp-word.trn'), 'w') as hyp_out, \
        open(Path(decode_dir, 'ref-word.trn'), 'w') as ref_out:
    for idx, raw in enumerate(pairs_in):
        fields = raw.rstrip().split('\t')
        if len(fields) > 2:
            raise ValueError("at most only ref and hyp")
        # Tag appended to both the reference and the hypothesis line so the
        # two files line up utterance by utterance.
        utt_tag = f"({idx//1000}k_{idx})"
        ref_out.write(f"{wer_filt(fields[0])} {utt_tag}\n")
        if len(fields) == 2:
            hyp_out.write(f"{wer_filt(fields[1])} {utt_tag}\n")
        else:
            hyp_out.write(f"{utt_tag}\n")

# Score with sclite and keep its full report next to the transcripts.
wer_result_path = Path(decode_dir, 'result.wrd.txt')
res = run_cmd([
    'sclite',
    '-r', Path(decode_dir, 'ref-word.trn'), 'trn',
    '-h', Path(decode_dir, 'hyp-word.trn'), 'trn',
    '-i', 'rm',
    '-o', 'all', 'stdout',
])
logger.log(f"Write result to {wer_result_path}", prefix='info')
with open(wer_result_path, 'w') as fout:
    print(res, file=fout)

# Echo the first two report lines that mention "Avg" or "SPKR".
er_rate = run_cmd(['grep', '-e', 'Avg', '-e', 'SPKR', '-m', '2', wer_result_path])
print(er_rate)

# Collapse runs of spaces in the first "Sum/Avg" row and take field 10.
# NOTE(review): presumably the error-rate column of the sclite summary table;
# confirm against an actual report.
wer_summary_row = run_cmd(['grep', '-e', 'Sum/Avg', '-m', '1', wer_result_path])
wer = re.sub(' +', ' ', wer_summary_row).split(' ')[10]

logger.log(f"WER: {wer}", prefix='test')
comet_exp.log_other(f"wer({paras.decode_mode})", wer)
with open(Path(decode_dir, 'wer'), 'w') as fout:
    fout.write(f"{wer}\n")

comet_exp.log_other('status', 'completed')

# Record the number of lines in best-hyp (first field of `wc -l` output).
wc = run_cmd(['wc', '-l', Path(decode_dir, 'best-hyp')]).split(' ')[0]
comet_exp.log_other("#decode", wc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment