Created
February 23, 2021 02:11
-
-
Save sunprinceS/d5288cb14b4f0ab4404254b84148277c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import re | |
import yaml | |
import json | |
from pathlib import Path | |
from comet_ml import ExistingExperiment | |
from src.marcos import * | |
from src.utils import run_cmd | |
import src.monitor.logger as logger | |
# Command-line interface: selects which decoded run to score.
parser = argparse.ArgumentParser(
    description='Evaluate the decoded hypothesis via sclite',
)
parser.add_argument('--config', required=True, type=str,
                    help='Path to config file')
parser.add_argument('--lang', required=True, choices=AVAIL_LANGS)
parser.add_argument('--algo', required=True,
                    choices=['no', 'reptile', 'multi', 'fomaml'])
parser.add_argument('--model_name', required=True, choices=['blstm', 'las'])
parser.add_argument('--eval_suffix', required=True, type=str, default=None,
                    help='Evaluation suffix')
parser.add_argument('--decode_mode', type=str, default='greedy',
                    choices=['greedy', 'beam', 'lm_beam'])
# nargs='?' allows the flag to be passed without a value, in which case the
# attribute takes the `const` value.
parser.add_argument('--pretrain_suffix', type=str, nargs='?', const=None,
                    help='Pretrain model suffix')
parser.add_argument('--runs', type=int, nargs='?', const=0,
                    help='run suffix')
parser.add_argument('--decode_suffix', type=str, default=None)  # will remove later
paras = parser.parse_args()

# Decode sub-directory name: "<mode>_decode", plus "_<suffix>" when one is given.
if paras.decode_suffix:
    decode_suffix = f"{paras.decode_mode}_decode_{paras.decode_suffix}"
else:
    decode_suffix = f"{paras.decode_mode}_decode"
# Load some prerequisites
# data/lang.json is loaded into id2lang, which is indexed by the --lang value.
with open(Path('data', 'lang.json'), 'r') as fin:
    id2lang = json.load(fin)
# Read the YAML config inside a context manager so the handle is closed
# deterministically instead of being left to the garbage collector.
with open(paras.config, 'r') as fin:
    config = yaml.safe_load(fin)

# With --algo no, a missing pretrain suffix falls back to the evaluation
# suffix; every other algorithm must be given one explicitly.
if paras.pretrain_suffix is None:
    if paras.algo != 'no':
        # Raise rather than `assert False`: asserts are stripped by `python -O`,
        # which would let the script continue with pretrain_suffix=None.
        raise ValueError("pretrain_suffix should be specified if using pretrained model")
    paras.pretrain_suffix = paras.eval_suffix
# Write to corresponding directory and comet exp
cur_path = Path.cwd()
# <cwd>/<LOG_DIR>/evaluation/<setting>/<algo>/<pretrain>/<eval>/<lang>/<runs>
log_dir = cur_path.joinpath(
    LOG_DIR,
    'evaluation',
    config['solver']['setting'],
    paras.algo,
    paras.pretrain_suffix,
    paras.eval_suffix,
    id2lang[paras.lang],
    str(paras.runs),
)
decode_dir = log_dir / decode_suffix

# The experiment key is read from <log_dir>/exp_key and handed to comet as the
# previous experiment to attach to.
with open(log_dir / 'exp_key', 'r') as key_file:
    exp_key = key_file.read().strip()

comet_exp = ExistingExperiment(
    previous_experiment=exp_key,
    project_name=COMET_PROJECT_NAME,
    workspace=COMET_WORKSPACE,
    auto_output_logging=None,
    auto_metric_logging=None,
    display_summary=False,
)
# Unit vocabulary in both directions; id 0 is pre-assigned to the empty string.
id2ch = {0: ''}
ch2id = {'': 0}
lang_dir = Path(config['solver']['data_root'], id2lang[paras.lang])

# train_units.txt: one "<unit> <id>" pair per line, separated by one space.
with open(lang_dir / 'train_units.txt') as fin:
    for line in fin:
        # Skip blank lines (e.g. a trailing newline at end of file); they would
        # otherwise break the two-way unpack below.
        if not line.strip():
            continue
        ch, i = line.split(' ')
        # int() ignores the trailing newline still attached to `i`.
        id2ch[int(i)] = ch
        ch2id[ch] = int(i)

# non_lang_syms.txt: one symbol per line; their ids are collected in
# non_syms_ids. Every symbol must already be in ch2id, else KeyError.
with open(lang_dir / 'non_lang_syms.txt') as fin:
    non_syms = [line.rstrip() for line in fin]
non_syms_ids = [ch2id[sym] for sym in non_syms]
space_idx = ch2id['<space>']
### CER calculation
def cer_filt(s, vocab=None, skip_ids=None):
    """Convert a string of unit ids into space-separated units.

    Ids that are >= len(vocab) are dropped (to filter out <sos> <eos> idx),
    and so are ids contained in ``skip_ids``.

    Args:
        s: Whitespace-separated integer ids. May be empty or contain runs of
            whitespace; both yield no spurious tokens.
        vocab: Mapping from id to unit string. Defaults to the module-level
            ``id2ch``, looked up at call time.
        skip_ids: Container of ids to drop. Defaults to the module-level
            ``non_syms_ids``, looked up at call time.

    Returns:
        The surviving units joined by single spaces ('' when none survive).

    Raises:
        ValueError: If a token cannot be parsed as an integer.
        KeyError: If an id below len(vocab) is missing from ``vocab``.
    """
    vocab = id2ch if vocab is None else vocab
    skip_ids = non_syms_ids if skip_ids is None else skip_ids
    vocab_size = len(vocab)
    # str.split() with no argument returns [] for an empty string, whereas
    # split(' ') returns [''] and makes int('') raise.
    ids = [int(tok) for tok in s.split()]
    return ' '.join(
        vocab[i] for i in ids if i < vocab_size and i not in skip_ids
    )
### cal CER
logger.notice("CER calculating...")
hyp_trn = Path(decode_dir, 'hyp.trn')
ref_trn = Path(decode_dir, 'ref.trn')
result_txt = Path(decode_dir, 'result.txt')

# best-hyp holds one utterance per line: "<ref ids>\t<hyp ids>", or just
# "<ref ids>" when there is no hypothesis field (rstrip() also removes a
# trailing tab). Each line becomes one entry in ref.trn and one in hyp.trn.
with open(Path(decode_dir, 'best-hyp'), 'r') as hyp_ref_in, \
        open(hyp_trn, 'w') as hyp_out, \
        open(ref_trn, 'w') as ref_out:
    for idx, row in enumerate(hyp_ref_in):
        fields = row.rstrip().split('\t')
        if len(fields) > 2:
            raise ValueError("at most only ref and hyp")
        # Utterance tag appended to every transcript line, e.g. "(1k_1234)".
        utt_id = f"({idx//1000}k_{idx})"
        print(f"{cer_filt(fields[0])} {utt_id}", file=ref_out)
        if len(fields) == 2:
            print(f"{cer_filt(fields[1])} {utt_id}", file=hyp_out)
        else:
            # No hypothesis field: write the bare tag.
            print(utt_id, file=hyp_out)

# Score with sclite and keep the full report on disk.
res = run_cmd(['sclite', '-r', ref_trn, 'trn', '-h', hyp_trn, 'trn',
               '-i', 'rm', '-o', 'all', 'stdout'])
logger.log(f"Write result to {result_txt}", prefix='info')
with open(result_txt, 'w') as fout:
    print(res, file=fout)

# Echo the first two report lines matching "Avg" or "SPKR".
er_rate = run_cmd(['grep', '-e', 'Avg', '-e', 'SPKR', '-m', '2', result_txt])
print(er_rate)

# Take field 10 of the space-collapsed "Sum/Avg" row.
# NOTE(review): presumably the Err column of the sclite summary -- confirm
# against the report layout.
cer = run_cmd(['grep', '-e', 'Sum/Avg', '-m', '1', result_txt])
cer = re.sub(' +', ' ', cer).split(' ')[10]
comet_exp.log_other(f"cer({paras.decode_mode})", cer)
logger.log(f"CER: {cer}", prefix='test')
with open(Path(decode_dir, 'cer'), 'w') as fout:
    print(str(cer), file=fout)
def wer_filt(s, vocab=None, skip_ids=None):
    """Convert a string of unit ids into one concatenated string of units.

    Ids that are >= len(vocab) are dropped (to filter out <sos> <eos> idx),
    and so are ids contained in ``skip_ids``. Units are joined with no
    separator, so any blanks in the result come from the units themselves.

    Args:
        s: Whitespace-separated integer ids. May be empty or contain runs of
            whitespace; both yield no spurious tokens.
        vocab: Mapping from id to unit string. Defaults to the module-level
            ``id2ch``, looked up at call time.
        skip_ids: Container of ids to drop. Defaults to the module-level
            ``non_syms_ids``, looked up at call time.

    Returns:
        The surviving units concatenated ('' when none survive).

    Raises:
        ValueError: If a token cannot be parsed as an integer.
        KeyError: If an id below len(vocab) is missing from ``vocab``.
    """
    vocab = id2ch if vocab is None else vocab
    skip_ids = non_syms_ids if skip_ids is None else skip_ids
    vocab_size = len(vocab)
    # str.split() with no argument returns [] for an empty string, whereas
    # split(' ') returns [''] and makes int('') raise.
    ids = [int(tok) for tok in s.split()]
    return ''.join(
        vocab[i] for i in ids if i < vocab_size and i not in skip_ids
    )
logger.notice("WER calculating...")
### cal WER
# From here on the space unit maps to a literal blank, so the units that
# wer_filt concatenates without a separator split into words.
id2ch[space_idx] = ' '
hyp_word_trn = Path(decode_dir, 'hyp-word.trn')
ref_word_trn = Path(decode_dir, 'ref-word.trn')
result_wrd_txt = Path(decode_dir, 'result.wrd.txt')

# Same best-hyp layout as the CER pass: "<ref ids>\t<hyp ids>", or only
# "<ref ids>" when there is no hypothesis field.
with open(Path(decode_dir, 'best-hyp'), 'r') as hyp_ref_in, \
        open(hyp_word_trn, 'w') as hyp_out, \
        open(ref_word_trn, 'w') as ref_out:
    for idx, row in enumerate(hyp_ref_in):
        fields = row.rstrip().split('\t')
        if len(fields) > 2:
            raise ValueError("at most only ref and hyp")
        # Utterance tag appended to every transcript line, e.g. "(1k_1234)".
        utt_id = f"({idx//1000}k_{idx})"
        print(f"{wer_filt(fields[0])} {utt_id}", file=ref_out)
        if len(fields) == 2:
            print(f"{wer_filt(fields[1])} {utt_id}", file=hyp_out)
        else:
            # No hypothesis field: write the bare tag.
            print(utt_id, file=hyp_out)

# Score with sclite and keep the full report on disk.
res = run_cmd(['sclite', '-r', ref_word_trn, 'trn', '-h', hyp_word_trn, 'trn',
               '-i', 'rm', '-o', 'all', 'stdout'])
logger.log(f"Write result to {result_wrd_txt}", prefix='info')
with open(result_wrd_txt, 'w') as fout:
    print(res, file=fout)

# Echo the first two report lines matching "Avg" or "SPKR".
er_rate = run_cmd(['grep', '-e', 'Avg', '-e', 'SPKR', '-m', '2', result_wrd_txt])
print(er_rate)

# Take field 10 of the space-collapsed "Sum/Avg" row.
# NOTE(review): presumably the Err column of the sclite summary -- confirm
# against the report layout.
wer = run_cmd(['grep', '-e', 'Sum/Avg', '-m', '1', result_wrd_txt])
wer = re.sub(' +', ' ', wer).split(' ')[10]
logger.log(f"WER: {wer}", prefix='test')
comet_exp.log_other(f"wer({paras.decode_mode})", wer)
with open(Path(decode_dir, 'wer'), 'w') as fout:
    print(str(wer), file=fout)

# Final bookkeeping on the comet experiment: mark it completed and record the
# line count of best-hyp as reported by `wc -l`.
comet_exp.log_other('status', 'completed')
wc = run_cmd(['wc', '-l', Path(decode_dir, 'best-hyp')])
wc = wc.split(' ')[0]
comet_exp.log_other("#decode", wc)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment