# Minimum error-rate training for statistical machine translation
# (from GitHub gist odashi/788623ad7028a1a53ad0, last active May 1, 2016 14:17)
import math
import random
import sys
from argparse import ArgumentParser
from collections import defaultdict
from util.functions import trace
def parse_args(argv=None):
    """Parse and validate the command-line options of the MERT trainer.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the original
            behaviour.

    Returns:
        argparse.Namespace with ``reference``, ``nbest``, ``model``,
        ``epoch`` and ``init_weight`` attributes.

    Raises:
        ValueError: if ``--epoch`` is smaller than 1.
    """
    def_epoch = 10
    p = ArgumentParser(
        description='MERT trainer',
        usage=
            '\n %(prog)s [options] reference nbest model'
            '\n %(prog)s -h',
    )
    p.add_argument('reference',
        help='[in] reference corpus')
    p.add_argument('nbest',
        help='[in] Travatar hypothesis corpus')
    p.add_argument('model',
        help='[out] model prefix')
    p.add_argument('--epoch',
        default=def_epoch, metavar='INT', type=int,
        help='number of training epoch (default: %(default)d)')
    # argparse maps '--init-weight' to the attribute 'init_weight'.
    p.add_argument('--init-weight',
        default=None, metavar='FILE', type=str,
        help='initial weight file (default: %(default)s)')
    # NOTE: only BLEU is implemented (see calculate_bleu); no --metrics option yet.
    args = p.parse_args(argv)
    # check args
    if not args.epoch >= 1:
        raise ValueError('Not satisfy a condition --epoch >= 1')
    return args
def gen_ref(filename):
    """Lazily yield every line of the reference corpus *filename* as a list of tokens."""
    with open(filename) as handle:
        yield from (text.split() for text in handle)
def gen_hyps(filename):
    """Yield the n-best hypotheses of *filename* grouped by sentence id.

    Each line is split on ``' ||| '`` into four fields:
    ``sid ||| hypothesis ||| <ignored> ||| name=value name=value ...``.
    Consecutive lines with the same ``sid`` form one group.

    Yields:
        ``(sid, hyp_batch, feature_batch)`` where ``hyp_batch`` is a list of
        token lists and ``feature_batch`` the matching list of
        ``defaultdict(float)`` feature maps (absent features read as 0.0).
    """
    # None (not 0) so that a file whose first id is not 0 does not produce
    # a spurious empty group.
    prev_sid = None
    hyp_batch = []
    feature_batch = []
    with open(filename) as fp:
        for line in fp:
            sid, hyp, _, feature = line.strip().split(' ||| ')
            sid = int(sid)
            if sid != prev_sid:
                # A new sentence starts: flush the previous group, if any.
                if hyp_batch:
                    yield prev_sid, hyp_batch, feature_batch
                prev_sid = sid
                hyp_batch = []
                feature_batch = []
            # rsplit so that a '=' inside a feature name does not break parsing.
            pairs = (x.rsplit('=', 1) for x in feature.split())
            # Every parsed line belongs to the current group.
            hyp_batch.append(hyp.split())
            feature_batch.append(defaultdict(float, {k: float(v) for k, v in pairs}))
    if hyp_batch:
        yield prev_sid, hyp_batch, feature_batch
def get_default_stats(N=4):
    """Return an all-zero statistics vector: two slots per n-gram order plus two length slots."""
    return [0] * (2 * N + 2)
def get_bleu_stats(ref, hyp, N=4):
    """Compute sentence-level BLEU sufficient statistics for *hyp* against *ref*.

    Returns a list of length ``2 * N + 2``. For order index ``i`` (n-gram
    length ``i + 1``):

    - ``stats[2i]`` is the number of such n-grams in *hyp*;
    - ``stats[2i+1]`` is how many of them also occur in *ref*, clipped by
      their count in *ref*.

    Orders longer than *hyp* stay 0. The final two entries are
    ``len(ref)`` and ``len(hyp)``.
    """
    stats = [0] * (2 * N)
    for order in range(1, min(len(hyp), N) + 1):
        slot = 2 * (order - 1)
        # Remaining (unclipped) occurrences of each reference n-gram.
        budget = defaultdict(int)
        for start in range(len(ref) - order + 1):
            budget[tuple(ref[start:start + order])] += 1
        hits = 0
        for start in range(len(hyp) - order + 1):
            gram = tuple(hyp[start:start + order])
            if budget[gram] > 0:
                budget[gram] -= 1
                hits += 1
        stats[slot] = len(hyp) - order + 1
        stats[slot + 1] = hits
    stats.extend([len(ref), len(hyp)])
    return stats
def calculate_bleu(stats, N=4):
    """Return BLEU from (possibly accumulated) statistics laid out as in get_bleu_stats.

    The score is the geometric mean of the N modified n-gram precisions,
    multiplied by the brevity penalty ``exp(min(0, 1 - ref_len / hyp_len))``.
    If any order has zero matches, the score is 0.0.
    """
    log_precision = 0.0
    for n in range(N):
        matches = stats[2 * n + 1]
        if matches == 0:
            return 0.0
        log_precision += math.log(matches) - math.log(stats[2 * n])
    length_term = 1.0 - stats[-2] / stats[-1]
    # A penalty is applied only when the hypothesis is shorter than the reference.
    log_bp = 0.0 if length_term > 0.0 else length_term
    return math.exp(log_precision / N + log_bp)
def get_feature_names(filename):
    """Return the set of every feature name that appears in the n-best file *filename*.

    Progress is reported through ``trace`` once per sentence id.
    """
    names = set()
    for sid, _, features_of_sentence in gen_hyps(filename):
        for feature_map in features_of_sentence:
            names.update(feature_map)
        trace(sid, rollback=True)
    return names
def get_grad(feature, key):
    """Return the slope of a hypothesis score along weight *key*, i.e. its value for that feature."""
    slope = feature[key]
    return slope
def get_bias(feature, weights, keys):
    """Return the dot product of *feature* and *weights* restricted to the names in *keys*."""
    total = 0
    for name in keys:
        total += feature[name] * weights[name]
    return total
def accum_stats(dest, src):
    """Add *src* into *dest* element-wise, in place, over the first ``len(dest)`` entries."""
    for idx, _ in enumerate(dest):
        dest[idx] += src[idx]
def get_diff_stats(a, b):
    """Return the element-wise difference ``b - a`` over the length of *a*."""
    return [b[idx] - value for idx, value in enumerate(a)]
def _collect_events(args, weights, target_axis):
    """Sweep each sentence's n-best upper envelope along the weight *target_axis*.

    With every other weight fixed, each hypothesis score is a line
    ``a * w + b`` in the target weight ``w``.

    Returns:
        ``(total_stats, events)``. ``total_stats`` holds the BLEU statistics
        of the 1-best hypotheses as ``w -> -inf``. ``events`` is a list of
        ``(w, diff_stats)`` pairs; each records how those statistics change
        when some sentence's 1-best switches at ``w``.
    """
    const_axis = set(weights) - {target_axis}
    total_stats = get_default_stats()
    events = []
    # NOTE(review): zip stops at the shorter input; a sentence-count mismatch
    # between the reference and the n-best file is not detected.
    for ref, (sid, hyp_batch, feature_batch) in zip(gen_ref(args.reference), gen_hyps(args.nbest)):
        if not hyp_batch:
            continue  # nothing to rank for this sentence
        # aw+b, where w is the target weight
        a_batch = [get_grad(feature, target_axis) for feature in feature_batch]
        b_batch = [get_bias(feature, weights, const_axis) for feature in feature_batch]
        stats_batch = [get_bleu_stats(ref, hyp) for hyp in hyp_batch]
        # Ascending slope, then descending bias: batch[0] is the 1-best at w -> -inf.
        batch = sorted(zip(a_batch, b_batch, stats_batch), key=lambda x: (x[0], -x[1]))
        accum_stats(total_stats, batch[0][2])
        prev_n = 0
        prev_w = -math.inf
        while True:
            prev_a, prev_b, _ = batch[prev_n]
            next_n = None
            next_w = math.inf
            # Only steeper lines (later in the sorted order) can overtake the current one.
            for n in range(prev_n + 1, len(batch)):
                a, b, _ = batch[n]
                if a == prev_a:
                    continue  # parallel lines never intersect
                w = (b - prev_b) / (prev_a - a)
                # '<=' prefers the steeper line when several cross at the same w.
                if prev_w < w <= next_w:
                    next_n = n
                    next_w = w
            if next_n is None:
                break  # the current line stays on top up to +inf
            events.append((next_w, get_diff_stats(batch[prev_n][2], batch[next_n][2])))
            prev_n = next_n
            prev_w = next_w
        trace(sid, rollback=True)
    return total_stats, events


def _line_search(total_stats, events):
    """Pick the target weight that maximizes corpus BLEU along one axis.

    *total_stats* (the statistics at ``w -> -inf``) is updated in place while
    the events are replayed in order of ``w``.

    Returns:
        ``(best_w, best_bleu)``. The weight is the midpoint of the best
        interval, or 1.0 beyond the outermost event for an unbounded
        interval. It is 0.0 when there are no events, because the ranking
        then never changes along this axis.
    """
    best_bleu = calculate_bleu(total_stats)
    if not events:
        # no intersection
        return 0.0, best_bleu
    # find global optimum over the focused axis
    events = sorted(events, key=lambda x: x[0])
    last = len(events) - 1
    best_m = -1
    for m, (w, diff) in enumerate(events):
        accum_stats(total_stats, diff)
        # Score only after all events sharing this w are applied; the stats
        # then hold on the open interval (w, next event's w).
        if m == last or w < events[m + 1][0]:
            bleu = calculate_bleu(total_stats)
            if bleu > best_bleu:
                best_m = m
                best_bleu = bleu
    if best_m == -1:
        best_w = events[0][0] - 1.0
    elif best_m == last:
        best_w = events[-1][0] + 1.0
    else:
        best_w = 0.5 * (events[best_m][0] + events[best_m + 1][0])
    return best_w, best_bleu


def train(args, epoch, weights):
    """Run one MERT epoch: optimize each weight in turn, updating *weights* in place.

    Args:
        args: parsed options; ``args.reference`` and ``args.nbest`` are read.
        epoch: epoch index, used only for progress messages.
        weights: dict mapping feature name to weight; modified in place.
    """
    for target_axis in weights:
        trace('epoch %4d: weight: %s' % (epoch, target_axis))
        total_stats, events = _collect_events(args, weights, target_axis)
        weights[target_axis], best_bleu = _line_search(total_stats, events)
        trace('%8s = %+.6f, BLEU = %.6f' % (target_axis, weights[target_axis], best_bleu))
def init_weights(args):
    """Build the initial weight dict for every feature name found in ``args.nbest``.

    Without ``args.init_weight``, each weight is ``random.uniform(-1, 1)``.
    Otherwise every weight starts at 0.0 and is overridden by the
    whitespace-separated ``name value`` lines of that file. Names not seen
    in the n-best are added too, and blank lines are skipped.
    """
    names = get_feature_names(args.nbest)
    if args.init_weight is None:
        return {k: random.uniform(-1, 1) for k in names}
    weight = {k: 0.0 for k in names}
    with open(args.init_weight) as fp:
        for line in fp:
            if not line.strip():
                continue  # e.g. a trailing empty line
            key, value = line.split()
            weight[key] = float(value)
    return weight
def save(weights, filename):
    """Write *weights* to *filename*, one ``name<TAB>value`` line per feature.

    Names are written in sorted order so the file is reproducible. The dict
    is typically built from a set of feature names, and set iteration order
    can differ between interpreter runs.
    """
    with open(filename, 'w') as fp:
        for k in sorted(weights):
            print('%s\t%+.8e' % (k, weights[k]), file=fp)
def main():
    """Command-line entry point: parse options, initialize weights, then train.

    After each epoch ``i`` the weights are saved to ``<model>.%04d`` with
    ``i + 1`` (``model.0001``, ``model.0002``, ...).
    """
    args = parse_args()
    trace('gathering weights ...')
    weights = init_weights(args)
    trace('start training ...')
    for i in range(args.epoch):
        train(args, i, weights)
        save(weights, args.model + '.%04d' % (i + 1))


if __name__ == '__main__':
    main()
# (GitHub gist page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment