// Using an explicit state table to extract internal paths from a composed FST (OpenFst)
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2011 Paul R. Dixon
// Author: paul@edobashira.com
//
#include <iostream>
#include <fst/compose.h>
#include <fst/shortest-path.h>
using namespace std;
using namespace fst;
// To compile on Linux:
//   g++ decompose.cc -O2 -o decompose -lfst -ldl
//
// This example shows how to supply a user-instantiated state table in the
// composition options and access the composition tuples after composing
// n FSTs. The composition state table can then be used to recover the
// internal path through each component FST.
//
// The program reads n FSTs, composes them, and runs a modified single
// shortest path. It then uses the manually supplied state tables to look up
// the composition states and recover the best path through each component
// FST.
// Single best path search, copied from the standard OpenFst
// SingleShortestPath implementation. The only addition is the `best` output
// vector, which receives the ifst state IDs on the best path in REVERSE order
// (from the best final state back to the source state).
//
// ifst     - FST to search.
// ofst     - output; its states are deleted, then it is filled with the best
//            path as a linear FST (symbol tables copied from ifst).
// distance - output; cleared, then holds the per-state distance from the
//            source as computed by the search.
// best     - output; ifst state IDs on the best path are APPENDED in reverse
//            order. It is not cleared here, so callers should pass it empty.
// opts     - search options; nshortest must be 1 and no weight/state
//            threshold may be set, otherwise LOG(FATAL).
//
// If ifst has no start state, or no final state is reached, ofst ends up
// with no states and `best` is left unchanged.
template<class Arc, class Queue, class ArcFilter>
void SingleShortestPath(const Fst<Arc> &ifst,
MutableFst<Arc> *ofst,
vector<typename Arc::Weight> *distance,
vector<typename Arc::StateId>* best,
ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
ofst->DeleteStates();
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
// Empty input: leave ofst empty and best untouched.
if (ifst.Start() == kNoStateId)
return;
vector<char> enqueued; // vector<char> instead of vector<bool>; reportedly faster on Windows.
vector<StateId> parent;
vector<Arc> arc_parent;
Queue *state_queue = opts.state_queue;
StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
// Best total weight found so far for a complete path, and the final state
// that ends it.
Weight f_distance = Weight::Zero();
StateId f_parent = kNoStateId;
distance->clear();
state_queue->Clear();
// This variant only supports a single best path without pruning.
if (opts.nshortest != 1)
LOG(FATAL) << "SingleShortestPath: for nshortest > 1, use ShortestPath"
<< " instead";
if (opts.weight_threshold != Weight::Zero() ||
opts.state_threshold != kNoStateId)
LOG(FATAL) <<
"SingleShortestPath: weight and state thresholds not applicable";
// Required semiring properties; LOG(FATAL) otherwise.
if ((Weight::Properties() & (kPath | kRightSemiring))
!= (kPath | kRightSemiring))
LOG(FATAL) << "SingleShortestPath: Weight needs to have the path"
<< " property and be right distributive: " << Weight::Type();
// Pad the per-state arrays for states before the source. The source itself
// starts at distance One and is the only state initially enqueued.
while (distance->size() < source) {
distance->push_back(Weight::Zero());
enqueued.push_back(false);
parent.push_back(kNoStateId);
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
}
distance->push_back(Weight::One());
parent.push_back(kNoStateId);
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
state_queue->Enqueue(source);
enqueued.push_back(true);
// Process states in the order chosen by the queue discipline, relaxing the
// outgoing arcs of each dequeued state.
while (!state_queue->Empty()) {
StateId s = state_queue->Head();
state_queue->Dequeue();
enqueued[s] = false;
Weight sd = (*distance)[s];
// A final state: remember it if it strictly improves the best total weight.
if (ifst.Final(s) != Weight::Zero()) {
Weight w = Times(sd, ifst.Final(s));
if (f_distance != Plus(f_distance, w)) {
f_distance = Plus(f_distance, w);
f_parent = s;
}
// With first_path, stop at the first final state that is dequeued.
if (opts.first_path)
break;
}
for (ArcIterator< Fst<Arc> > aiter(ifst, s);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
// Grow the per-state arrays lazily as new states are discovered.
while (distance->size() <= arc.nextstate) {
distance->push_back(Weight::Zero());
enqueued.push_back(false);
parent.push_back(kNoStateId);
arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
kNoStateId));
}
Weight &nd = (*distance)[arc.nextstate];
Weight w = Times(sd, arc.weight);
// Relaxation: if going through s improves nextstate, record s and the arc
// as its parent, then enqueue it (or update its queue position).
if (nd != Plus(nd, w)) {
nd = Plus(nd, w);
parent[arc.nextstate] = s;
arc_parent[arc.nextstate] = arc;
if (!enqueued[arc.nextstate]) {
state_queue->Enqueue(arc.nextstate);
enqueued[arc.nextstate] = true;
} else {
state_queue->Update(arc.nextstate);
}
}
}
}
// Follow parent pointers from the best final state back to the source.
// ofst states are created in that (reverse) order: the first one created
// gets the final weight, each later one gets the parent arc redirected to
// the previously created state, and the last one created becomes the start
// state. Each visited ifst state is appended to best (reverse path order).
StateId s_p = kNoStateId, d_p = kNoStateId;
for (StateId s = f_parent, d = kNoStateId;
s != kNoStateId;
d = s, s = parent[s]) {
d_p = s_p;
s_p = ofst->AddState();
best->push_back(s);
if (d == kNoStateId) {
ofst->SetFinal(s_p, ifst.Final(f_parent));
} else {
arc_parent[d].nextstate = d_p;
ofst->AddArc(s_p, arc_parent[d]);
}
}
ofst->SetStart(s_p);
ofst->SetProperties(
ShortestPathProperties(ofst->Properties(kFstProperties, false)),
kFstProperties);
}
//path is vector containing the path through the composed
//Fst. Returns a pair containting the score contribution from each
//Transducer and a pair of paths through each machine
//WARNING: this might not be correct if there are multiple
//arcs with same labels between two identical states
template<class Arc, class T>
pair<typename Arc::Weight, typename Arc::Weight>
Decompose(const Fst<Arc>& fst1, const Fst<Arc>& fst2,
const vector<typename Arc::StateId>& path, const T& state_table,
vector<typename Arc::StateId>* seq1, vector<typename Arc::StateId>* seq2) {
typedef typename Arc::Weight W;
typedef typename Arc::StateId S;
//
W p1 = fst1.Final(state_table.Tuple(path[path.size() - 1]).state_id1);
W p2 = fst2.Final(state_table.Tuple(path[path.size() - 1]).state_id2);
seq1->push_back(state_table.Tuple(path[0]).state_id1);
seq2->push_back(state_table.Tuple(path[0]).state_id1);
for (int i = 1; i != path.size(); ++i) {
S s1 = state_table.Tuple(path[i - 1]).state_id2;
S s2 = state_table.Tuple(path[i]).state_id2;
W w = W::Zero();
for (ArcIterator< Fst<Arc> > ai(fst2, s1); !ai.Done(); ai.Next()) {
const Arc& arc = ai.Value();
if (arc.nextstate == s2)
w = Plus(w, arc.weight);
}
if (w != W::Zero()) {
p2 = Times(p2, w);
} // else Epsilon move in the other transducer
seq2->push_back(s2);
s1 = state_table.Tuple(path[i - 1]).state_id1;
s2 = state_table.Tuple(path[i]).state_id1;
w = W::Zero();
for (ArcIterator< Fst<Arc> > ai(fst1, s1); !ai.Done(); ai.Next()) {
const Arc& arc = ai.Value();
if (arc.nextstate == s2)
w = Plus(w, arc.weight);
}
if (w != W::Zero()) {
p1 = Times(p1, w);
} //else Epsilon move in the other transducer
seq1->push_back(s2);
}
return make_pair(p1, p2);
}
// Reads n >= 2 FSTs named on the command line and composes them left to
// right: cfsts[i] = cfsts[i - 1] o fsts[i]. It runs a single best-path search
// on the final composition, then walks the compositions back from the
// outermost to the innermost. For each component FST it logs the weight
// contribution and the best state sequence. Returns 1 on a usage error or
// when no successful path exists, 0 otherwise.
int main(int argc, char** argv) {
  // These are the aliases for all default composition options.
  typedef StdArc A;
  typedef A::StateId S;
  typedef A::Weight W;
  typedef Matcher< Fst<A> > M;
  typedef SequenceComposeFilter<M> F;
  typedef GenericComposeStateTable<A, F::FilterState> T;
  typedef ComposeFstOptions<A, M, F, T> COpts;
  // Decomposition needs at least two FSTs. With only one, no composition
  // happens and there would be no decomposed score to report.
  if (argc < 3) {
    cerr << "Usage : cmd 1st_fst 2nd_fst .. nth_fst" << endl;
    return 1;
  }
  vector<StdFst*> fsts;
  vector<COpts*> copts;
  for (int i = 1; i < argc; ++i) {
    StdFst* fst = StdFst::Read(argv[i]);
    if (fst == NULL)
      LOG(FATAL) << "Failed to read fst from : " << argv[i];
    fsts.push_back(fst);
  }
  // cfsts[0] is fsts[0]. cfsts[i] is built with an explicitly supplied state
  // table, kept in copts[i - 1], so that the composition tuples stay
  // accessible after the search.
  vector<StdFst*> cfsts;
  cfsts.push_back(fsts[0]);
  for (size_t i = 1; i < fsts.size(); ++i) {
    COpts* opts = new COpts();
    opts->state_table = new T(*cfsts[i - 1], *fsts[i]);
    copts.push_back(opts);
    // Don't instantiate ComposeFst as an argument to the shortest-path
    // function, because its destructor will erase the state table contents.
    ComposeFst<A>* cfst = new ComposeFst<A>(*cfsts[i - 1], *fsts[i], *opts);
    cfsts.push_back(cfst);
  }
  StdVectorFst bfst;
  vector<W> distance;
  vector<S> best;
  typedef AutoQueue<S> Q;
  AnyArcFilter<A> filter;
  Q q(*cfsts.back(), &distance, filter);
  ShortestPathOptions<A, Q, AnyArcFilter<A> > spopts(&q, filter);
  SingleShortestPath(*cfsts.back(), &bfst, &distance, &best, spopts);
  // SingleShortestPath fills best in reverse order (final state first).
  reverse(best.begin(), best.end());
  LOG(INFO) << "Best sequence length " << bfst.NumStates();
  int status = 0;
  if (best.empty()) {
    // No final state was reached in the composition: nothing to decompose.
    LOG(ERROR) << "No successful path through the composed fst";
    status = 1;
  } else {
    pair<W, W> p;
    // Peel one composition level per iteration, from cfsts.back() down to
    // cfsts[1]. On entry, best is the best state sequence through cfsts[i].
    for (size_t i = copts.size(); i > 0; --i) {
      vector<S> best1;
      vector<S> best2;
      StdFst& fst1 = *cfsts[i - 1];
      StdFst& fst2 = *fsts[i];
      T& t = *copts[i - 1]->state_table;
      p = Decompose(fst1, fst2, best, t, &best1, &best2);
      // p.second is the weight accumulated along the path through fsts[i].
      LOG(INFO) << "Fst " << i + 1 << " score contribution " << p.second;
      stringstream ss;
      for (size_t j = 0; j != best2.size(); ++j)
        ss << best2[j] << " ";
      LOG(INFO) << "Fst " << i + 1 << " best state sequence : " << ss.str();
      best = best1;
    }
    // After the last iteration, fst1 was cfsts[0] == fsts[0].
    LOG(INFO) << "Fst " << 1 << " score contribution " << p.first;
    stringstream ss;
    for (size_t i = 0; i != best.size(); ++i)
      ss << best[i] << " ";
    LOG(INFO) << "Fst " << 1 << " best state sequence : " << ss.str();
  }
  // Destroy the compositions from the outermost inward, then the inputs.
  // NOTE(review): the state tables allocated above are not deleted here;
  // whether ComposeFst takes ownership of them depends on the OpenFst
  // version in use, so confirm before adding deletes.
  for (size_t i = copts.size(); i > 0; --i) {
    delete cfsts[i];
    delete copts[i - 1];
  }
  for (size_t i = 0; i != fsts.size(); ++i)
    delete fsts[i];
  return status;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment