@AdolfVonKleist
Last active May 29, 2020 08:50
compute-best-mix.py : Python port of the venerated SRILM tool
#!/usr/bin/python
import re, math

def LoadPPLFile (pplfile) :
    """
    Load up the salient info from a -debug 2 PPL file
    generated by the SRILM ngram tool.
    """
    ppl_info = []
    for line in open (pplfile, "r") :
        line = line.strip ()
        if line.startswith ("p(") :
            tok, prob = line.split ("=")
            probs = re.split (r"\s+", re.sub (r"[\[\]]", "", prob).strip ())
            tok = re.sub (r"^p\( ", "", tok)
            tok = re.sub (r" \|.*$", "", tok)
            ppl_info.append ([tok, int (re.sub (r"gram", "", probs[0])), float (probs[2])])
    # Debug output: the entry count and a sample entry
    print len (ppl_info)
    print "\t", ppl_info[2]
    return ppl_info
class MixtureComputer () :
    """
    Python port of the SRILM gawk tool 'compute-best-mix'.
    Should produce the same result as:
      $ compute-best-mix ppl.file1 ppl.file2 ... ppl.fileN
    """
    def __init__ (self, ppl_infos, lambdas=[], precision=0.001, verbose=False) :
        self.M_LN10 = 2.30258509299404568402
        self.logINF = -320
        self.precision = precision
        self.ppl_infos = ppl_infos
        self.lambdas, self.priors = self._init_lambdas (lambdas)
        self.max_iter = 100
        self.post_totals = []
        self.nc = len (self.ppl_infos) # Number of mixture components
        self.log10priors = []
    def _init_lambdas (self, lambdas) :
        if len (lambdas) == 0 :
            # Default to uniform initial weights
            lambdas = [1. / len (self.ppl_infos) for l in xrange (len (self.ppl_infos))]
        lambda_sum = 0.0
        priors_ = [0.0 for l in lambdas]
        for i, l in enumerate (lambdas) :
            priors_[i] = l
            lambda_sum += l
        # Normalize the initial priors so they sum to one, as the
        # original gawk tool does with its lambda sum.
        priors_ = [p / lambda_sum for p in priors_]
        return lambdas, priors_
    def _sum_word (self, i) :
        # Log10 posterior of word i under each mixture component
        log_posts_ = [self.log10priors[j] + self.ppl_infos[j][i][2]
                      for j in xrange (self.nc)]
        # Accumulate the log10 of the sum of the component posteriors
        log_sum_ = log_posts_[0]
        for log_post_ in log_posts_[1:] :
            log_sum_ = math.log (
                (math.pow (10, log_sum_) + math.pow (10, log_post_)),
                10)
        return log_sum_, log_posts_
    def OptimizeLambdas (self) :
        """
        So how does this actually work?  There is no explanation beyond
        the source itself where the original gawk script is concerned.
        It is basically a simple, iterative EM-like estimation procedure:
          1. Load all the PPL results from the component models.
          2. Initialize the mixture weights (uniform by default).
          3. For each word in the test set, compute the lambda-scaled
             log posterior under each of the component models.
             For example, given
               * word    = WORD1
               * models  = M1, M2, M3
               * lambdas = L1, L2, L3
             the log posterior for model Mj is log10(Lj) plus the
             log10 probability Mj assigned to WORD1; then compute
             the log sum of these posteriors for the word.
          4. Compute the per-model posterior totals: each per-model
             posterior from (3.) divided by the total (word-based)
             sum from (3.), a subtraction in the log domain.
          5. Recompute the lambda priors, normalizing by the total
             number of (non-OOV) words in the test set.
          6. Finally, determine the actual, absolute change between
             the previous prior values and the newly recomputed ones.
             If the value for any of the models is larger than the
             precision threshold, and we have not reached the max
             number of iterations, return to Step 3.
        The algorithm terminates when either max_iter is reached or
        the change for every model dips below the precision threshold.
        """
        have_converged = False
        iteration = 0
        while not have_converged :
            iteration += 1
            log_like = 0.0
            post_totals = [0.0 for p in self.ppl_infos]
            self.log10priors = [math.log (self.priors[i], 10)
                                for i in xrange (self.nc)]
            for i in xrange (len (self.ppl_infos[0])) :
                # Compute the sum for this word, across all components
                log_sum, log_posts = self._sum_word (i)
                log_like += log_sum
                for j in xrange (len (self.ppl_infos)) :
                    post_totals[j] += math.pow (10, log_posts[j] - log_sum)
            # Report the iteration, current weights, and mixture perplexity
            print iteration, \
                " ".join ([str (x) for x in self.priors]), \
                math.pow (10, -log_like / len (self.ppl_infos[0]))
            have_converged = True
            for j in xrange (len (self.ppl_infos)) :
                last_prior = self.priors[j]
                self.priors[j] = post_totals[j] / len (self.ppl_infos[0])
                abs_change = abs (last_prior - self.priors[j])
                if abs_change > self.precision :
                    have_converged = False
            if iteration > self.max_iter :
                have_converged = True
        return
if __name__ == "__main__" :
    import sys, argparse

    example = "USAGE: {0} --ppl ppl.1.txt,ppl.2.txt,ppl.3.txt".format (sys.argv[0])
    parser = argparse.ArgumentParser (description = example)
    parser.add_argument ("--ppl", "-p", help="Comma-separated list of ppl files from 'ngram'.", required=True)
    parser.add_argument ("--verbose", "-v", help="Verbose mode.", default=False, action="store_true")
    args = parser.parse_args ()

    pplfiles = args.ppl.split (",")
    pplinfos = []
    for f in pplfiles :
        pplinfos.append (LoadPPLFile (f))
    mixer = MixtureComputer (pplinfos)
    mixer.OptimizeLambdas ()
    print mixer.priors
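
For reference, a typical end-to-end invocation might look like the following (file and model names here are illustrative; -lm, -ppl and -debug 2 are standard SRILM ngram options):

$ ngram -lm model1.lm -ppl test.txt -debug 2 > ppl.1.txt
$ ngram -lm model2.lm -ppl test.txt -debug 2 > ppl.2.txt
$ ./compute-best-mix.py --ppl ppl.1.txt,ppl.2.txt

The script prints one line per EM iteration (iteration number, current weights, mixture perplexity), followed by the final interpolation weights.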
@luuik commented Sep 29, 2015

Hello. Looks cool.
But right now I'm getting:

$ compute-best-mix.py --ppl ppl1,ppl2
Traceback (most recent call last):
  File "compute-best-mix.py", line 137, in <module>
    pplinfos.append (LoadPPLFile (f))
  File "compute-best-mix.py", line 18, in LoadPPLFile
    ppl_info.append ([tok, int(re.sub(r"gram", "",probs[0])), float(probs[2])])
ValueError: invalid literal for int() with base 10: 'OOV'

@dansoutner

You need to change line 18 to something like:
ppl_info.append ([tok, int(re.sub(r"gram", "",probs[0].replace("OOV", "0"))), float(probs[2])])

It is also a good idea to deal with -inf values if you have OOVs in your text.
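
For reference, a minimal sketch of a loader variant that applies both suggestions by skipping OOV entries outright (this helper is hypothetical, not part of the gist; it assumes the same file and the gist's import of re):

def LoadPPLFileSkipOOV (pplfile) :
    """Like LoadPPLFile, but drops OOV / zero-probability entries."""
    ppl_info = []
    for line in open (pplfile, "r") :
        line = line.strip ()
        if not line.startswith ("p(") :
            continue
        tok, prob = line.split ("=")
        # Skip OOVs and -inf (zero-probability) words entirely
        if "OOV" in prob or "inf" in prob :
            continue
        probs = re.split (r"\s+", re.sub (r"[\[\]]", "", prob).strip ())
        tok = re.sub (r"^p\( ", "", tok)
        tok = re.sub (r" \|.*$", "", tok)
        ppl_info.append ([tok, int (re.sub (r"gram", "", probs[0])), float (probs[2])])
    return ppl_info

Note that if you skip words this way, every component's ppl file must skip the same words, or the per-word indices in MixtureComputer will no longer line up across models.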

@ChuanTianML

Hello, do you know what weight-update method the compute-best-mix shell script uses? Gradient descent?
