Skip to content

Instantly share code, notes, and snippets.

@AdolfVonKleist
Created July 27, 2012 15:44
Show Gist options
  • Save AdolfVonKleist/3188749 to your computer and use it in GitHub Desktop.
Save AdolfVonKleist/3188749 to your computer and use it in GitHub Desktop.
Expectation and Maximization functions from M2MFstAligner
void M2MFstAligner::expectation( ){
for( int i=0; i<fsas.size(); i++ ){
//Comput Forward and Backward probabilities
ShortestDistance( fsas.at(i), &alpha );
ShortestDistance( fsas.at(i), &beta, true );
//Compute the normalized Gamma probabilities and
// update our running tally
for( StateIterator<VectorFst<LogArc> > siter(fsas.at(i)); !siter.Done(); siter.Next() ){
LogArc::StateId q = siter.Value();
for( ArcIterator<VectorFst<LogArc> > aiter(fsas.at(i),q); !aiter.Done(); aiter.Next() ){
const LogArc& arc = aiter.Value();
const LogWeight& gamma = Divide(Times(Times(alpha[q], arc.weight), beta[arc.nextstate]), beta[0]);
//Check for any BadValue results, otherwise add to the tally.
//We call this 'prev_alignment_model' which may seem misleading, but
// this conventions leads to 'alignment_model' being the final version.
if( gamma.Value()==gamma.Value() ){
prev_alignment_model[arc.ilabel] = Plus(prev_alignment_model[arc.ilabel], gamma);
total = Plus(total, gamma);
}
}
}
alpha.clear();
beta.clear();
}
}
float M2MFstAligner::maximization( bool lastiter ){
//Maximization. Simple count normalization. Probably get an improvement
// by using a more sophisticated regularization approach.
map<LogArc::Label,LogWeight>::iterator it;
float change = abs(total.Value()-prevTotal.Value());
//cout << "Total: " << total << " Change: " << abs(total.Value()-prevTotal.Value()) << endl;
prevTotal = total;
//Normalize and iterate to the next model. We apply it dynamically
// during the expectation step.
for( it=prev_alignment_model.begin(); it != prev_alignment_model.end(); it++ ){
alignment_model[(*it).first] = Divide((*it).second,total);
(*it).second = LogWeight::Zero();
}
for( int i=0; i<fsas.size(); i++ ){
for( StateIterator<VectorFst<LogArc> > siter(fsas[i]); !siter.Done(); siter.Next() ){
LogArc::StateId q = siter.Value();
for( MutableArcIterator<VectorFst<LogArc> > aiter(&fsas[i], q); !aiter.Done(); aiter.Next() ){
LogArc arc = aiter.Value();
arc.weight = alignment_model[arc.ilabel];
aiter.SetValue(arc);
}
}
}
total = LogWeight::Zero();
return change;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment