Created
April 23, 2011 03:43
-
-
Save edobashira/938238 to your computer and use it in GitHub Desktop.
Using an explicit state table to extract internal paths from composed FST (OpenFst)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2011 Paul R. Dixon
// Author: paul@edobashira.com
//
#include <algorithm>
#include <iostream>
#include <sstream>
#include <utility>
#include <vector>
#include <fst/compose.h>
#include <fst/shortest-path.h>
using namespace std;
using namespace fst;
// Compile on Linux:
//   g++ decompose.cc -O2 -o decompose -lfst -ldl
//
// This example shows how to supply a user-instantiated state table in the
// composition options and then access the composition tuples after composing
// n FSTs. The composition state table can be used to recover the internal
// path.
//
// The program reads n FSTs, composes them, and runs a modified shortest-path
// search. It then uses the manually supplied state tables to access the
// composition states and recover the best paths through the component FSTs.
//Cut and paste job on the standard OpenFst shortest path | |
//function, added the best vector which contains the best | |
//state sequence in reverse order through ifst | |
template<class Arc, class Queue, class ArcFilter> | |
void SingleShortestPath(const Fst<Arc> &ifst, | |
MutableFst<Arc> *ofst, | |
vector<typename Arc::Weight> *distance, | |
vector<typename Arc::StateId>* best, | |
ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { | |
typedef typename Arc::StateId StateId; | |
typedef typename Arc::Weight Weight; | |
ofst->DeleteStates(); | |
ofst->SetInputSymbols(ifst.InputSymbols()); | |
ofst->SetOutputSymbols(ifst.OutputSymbols()); | |
if (ifst.Start() == kNoStateId) | |
return; | |
vector<char> enqueued; //Changed this to char faster on Windows | |
vector<StateId> parent; | |
vector<Arc> arc_parent; | |
Queue *state_queue = opts.state_queue; | |
StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source; | |
Weight f_distance = Weight::Zero(); | |
StateId f_parent = kNoStateId; | |
distance->clear(); | |
state_queue->Clear(); | |
if (opts.nshortest != 1) | |
LOG(FATAL) << "SingleShortestPath: for nshortest > 1, use ShortestPath" | |
<< " instead"; | |
if (opts.weight_threshold != Weight::Zero() || | |
opts.state_threshold != kNoStateId) | |
LOG(FATAL) << | |
"SingleShortestPath: weight and state thresholds not applicable"; | |
if ((Weight::Properties() & (kPath | kRightSemiring)) | |
!= (kPath | kRightSemiring)) | |
LOG(FATAL) << "SingleShortestPath: Weight needs to have the path" | |
<< " property and be right distributive: " << Weight::Type(); | |
while (distance->size() < source) { | |
distance->push_back(Weight::Zero()); | |
enqueued.push_back(false); | |
parent.push_back(kNoStateId); | |
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); | |
} | |
distance->push_back(Weight::One()); | |
parent.push_back(kNoStateId); | |
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); | |
state_queue->Enqueue(source); | |
enqueued.push_back(true); | |
while (!state_queue->Empty()) { | |
StateId s = state_queue->Head(); | |
state_queue->Dequeue(); | |
enqueued[s] = false; | |
Weight sd = (*distance)[s]; | |
if (ifst.Final(s) != Weight::Zero()) { | |
Weight w = Times(sd, ifst.Final(s)); | |
if (f_distance != Plus(f_distance, w)) { | |
f_distance = Plus(f_distance, w); | |
f_parent = s; | |
} | |
if (opts.first_path) | |
break; | |
} | |
for (ArcIterator< Fst<Arc> > aiter(ifst, s); | |
!aiter.Done(); | |
aiter.Next()) { | |
const Arc &arc = aiter.Value(); | |
while (distance->size() <= arc.nextstate) { | |
distance->push_back(Weight::Zero()); | |
enqueued.push_back(false); | |
parent.push_back(kNoStateId); | |
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), | |
kNoStateId)); | |
} | |
Weight &nd = (*distance)[arc.nextstate]; | |
Weight w = Times(sd, arc.weight); | |
if (nd != Plus(nd, w)) { | |
nd = Plus(nd, w); | |
parent[arc.nextstate] = s; | |
arc_parent[arc.nextstate] = arc; | |
if (!enqueued[arc.nextstate]) { | |
state_queue->Enqueue(arc.nextstate); | |
enqueued[arc.nextstate] = true; | |
} else { | |
state_queue->Update(arc.nextstate); | |
} | |
} | |
} | |
} | |
StateId s_p = kNoStateId, d_p = kNoStateId; | |
for (StateId s = f_parent, d = kNoStateId; | |
s != kNoStateId; | |
d = s, s = parent[s]) { | |
d_p = s_p; | |
s_p = ofst->AddState(); | |
best->push_back(s); | |
if (d == kNoStateId) { | |
ofst->SetFinal(s_p, ifst.Final(f_parent)); | |
} else { | |
arc_parent[d].nextstate = d_p; | |
ofst->AddArc(s_p, arc_parent[d]); | |
} | |
} | |
ofst->SetStart(s_p); | |
ofst->SetProperties( | |
ShortestPathProperties(ofst->Properties(kFstProperties, false)), | |
kFstProperties); | |
} | |
//path is vector containing the path through the composed | |
//Fst. Returns a pair containting the score contribution from each | |
//Transducer and a pair of paths through each machine | |
//WARNING: this might not be correct if there are multiple | |
//arcs with same labels between two identical states | |
template<class Arc, class T> | |
pair<typename Arc::Weight, typename Arc::Weight> | |
Decompose(const Fst<Arc>& fst1, const Fst<Arc>& fst2, | |
const vector<typename Arc::StateId>& path, const T& state_table, | |
vector<typename Arc::StateId>* seq1, vector<typename Arc::StateId>* seq2) { | |
typedef typename Arc::Weight W; | |
typedef typename Arc::StateId S; | |
// | |
W p1 = fst1.Final(state_table.Tuple(path[path.size() - 1]).state_id1); | |
W p2 = fst2.Final(state_table.Tuple(path[path.size() - 1]).state_id2); | |
seq1->push_back(state_table.Tuple(path[0]).state_id1); | |
seq2->push_back(state_table.Tuple(path[0]).state_id1); | |
for (int i = 1; i != path.size(); ++i) { | |
S s1 = state_table.Tuple(path[i - 1]).state_id2; | |
S s2 = state_table.Tuple(path[i]).state_id2; | |
W w = W::Zero(); | |
for (ArcIterator< Fst<Arc> > ai(fst2, s1); !ai.Done(); ai.Next()) { | |
const Arc& arc = ai.Value(); | |
if (arc.nextstate == s2) | |
w = Plus(w, arc.weight); | |
} | |
if (w != W::Zero()) { | |
p2 = Times(p2, w); | |
} // else Epsilon move in the other transducer | |
seq2->push_back(s2); | |
s1 = state_table.Tuple(path[i - 1]).state_id1; | |
s2 = state_table.Tuple(path[i]).state_id1; | |
w = W::Zero(); | |
for (ArcIterator< Fst<Arc> > ai(fst1, s1); !ai.Done(); ai.Next()) { | |
const Arc& arc = ai.Value(); | |
if (arc.nextstate == s2) | |
w = Plus(w, arc.weight); | |
} | |
if (w != W::Zero()) { | |
p1 = Times(p1, w); | |
} //else Epsilon move in the other transducer | |
seq1->push_back(s2); | |
} | |
return make_pair(p1, p2); | |
} | |
int main(int argc, char** argv) { | |
//Compose two fsts run best-path and decompose the best path | |
//into the component sequences | |
//These are the aliases for all default composition options | |
typedef StdArc A; | |
typedef A::StateId S; | |
typedef A::Weight W; | |
typedef Matcher< Fst<A> > M; | |
typedef SequenceComposeFilter<M> F; | |
typedef GenericComposeStateTable<A, F::FilterState> T; | |
typedef ComposeFstOptions<A, M, F, T> COpts; | |
if (argc <= 1) { | |
cerr << "Usage : cmd 1st_fst 2nd_fst .. nth_fst" << endl; | |
exit(0); | |
} | |
vector<StdFst*> fsts; | |
vector<COpts*> copts; | |
for (int i = 1; i != argc; ++i) { | |
StdFst* fst = StdFst::Read(argv[i]); | |
if (fst == NULL) | |
LOG(FATAL) << "Failed to read fst from : " << argv[i]; | |
fsts.push_back(fst); | |
} | |
vector<StdFst*> cfsts; | |
cfsts.push_back(fsts[0]); | |
for (int i = 1; i != fsts.size(); i++) { | |
COpts* opts = new COpts(); | |
opts->state_table = new T(*cfsts[i - 1], *fsts[i]); | |
copts.push_back(opts); | |
//Don't instantiate ComposeFst as an argument to shortestpath function | |
//because the destructor will erase the state tables contents. | |
ComposeFst<A>* cfst = new ComposeFst<A>(*cfsts[i - 1], *fsts[i], *opts); | |
cfsts.push_back(cfst); | |
} | |
StdVectorFst bfst; | |
vector<W> distance; | |
vector<S> best; | |
typedef AutoQueue<S> Q; | |
AnyArcFilter<A> filter; | |
Q q(*cfsts.back(), &distance, filter); | |
ShortestPathOptions<A, Q, AnyArcFilter<A> > spopts(&q, filter); | |
SingleShortestPath(*cfsts.back(), &bfst, &distance, &best, spopts); | |
reverse(best.begin(), best.end()); | |
LOG(INFO) << "Best sequence length " << bfst.NumStates(); | |
pair<W, W> p; | |
for (int i = copts.size(); i > 0; i--) | |
{ | |
vector<S> best1; | |
vector<S> best2; | |
StdFst& fst1 = *cfsts[i - 1]; | |
StdFst& fst2 = *fsts[i]; | |
T& t = *copts[i - 1]->state_table; | |
p = Decompose(fst1, fst2, best, | |
t, &best1, &best2); | |
LOG(INFO) << "Fst " << i + 1 << " score contribution " << | |
p.second.Value() - p.first.Value(); | |
stringstream ss; | |
for (size_t j = 0; j != best2.size(); ++j) | |
ss << best2[j] << " "; | |
LOG(INFO) << "Fst " << i + 1 << " best state sequence : " << ss.str(); | |
best = best1; | |
} | |
LOG(INFO) << "Fst " << 1 << " score contribution " << p.first; | |
stringstream ss; | |
for (size_t i = 0; i != best.size(); ++i) | |
ss << best[i] << " "; | |
LOG(INFO) << "Fst " << 1 << " best state sequence : " << ss.str(); | |
for (int i = copts.size() - 1; i >= 0; --i) { | |
delete cfsts[i + 1]; | |
delete copts[i]; | |
} | |
for (int i = 0; i != fsts.size(); ++i) | |
delete fsts[i]; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment