Created
December 14, 2012 17:40
-
-
Save anonymous/4287223 to your computer and use it in GitHub Desktop.
(Hopefully) fast variable elimination algorithm.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <vector> | |
#include <set> | |
#include <iostream> | |
#include <functional> | |
#include <algorithm> | |
#include <assert.h> | |
using namespace std; | |
/**
 * Computes the index-offset lookup table for one factor and a sequence of
 * counting variables.
 *
 * Factor tables are stored with the factor's first variable varying fastest:
 * the stride of fvars[i] is the product of the domain sizes of fvars[0..i-1]
 * (factors must adhere to this convention). The counting variables form a
 * mixed-radix counter whose first variable is the least significant digit.
 * Entry t of the table is the amount to add to the factor's flat index when
 * counter digit t is incremented by one and every digit before t wraps from
 * its maximum value back to 0.
 *
 * @param fvars             variable ids of the factor, in table order
 * @param domains           domain size of every variable, indexed by variable id
 * @param counter_variables variable ids of the counter digits, least significant first
 * @param out               output iterator receiving one int per counter variable
 */
template<class OutputIterator>
void compute_factor_lookup(const std::vector<int>& fvars, const std::vector<int>& domains, const std::vector<int>& counter_variables, OutputIterator out){
    const int num_vars_f = static_cast<int>(fvars.size());

    // stride[i] is the flat-index step caused by incrementing fvars[i] by one
    std::vector<int> stride(num_vars_f);
    int running_product = 1;
    for (int i = 0; i < num_vars_f; ++i) {
        stride[i] = running_product;
        running_product *= domains[fvars[i]];
    }

    // wrap_offset accumulates the (negative) index change caused by all
    // earlier counter digits wrapping around to 0.
    int wrap_offset = 0;
    for (const int cv : counter_variables) {
        const auto pos = std::find(fvars.cbegin(), fvars.cend(), cv);
        if (pos == fvars.cend()) {
            // the factor does not depend on cv: only the wraps matter
            *out++ = wrap_offset;
            continue;
        }
        const int cv_stride = stride[pos - fvars.cbegin()];
        // incrementing cv itself moves the index forward by its stride
        *out++ = wrap_offset + cv_stride;
        // a digit wraps from (d - 1) back to 0, i.e. it moves the index
        // back by (d - 1) strides -- not d strides
        wrap_offset -= cv_stride * (domains[cv] - 1);
    }
}

/**
 * Returns the position of the first element of c that compares equal to x,
 * or -1 if there is no such element.
 */
template<class Container>
int index_of(Container const& c, typename Container::value_type x){
    int idx = 0;
    for (auto it = c.cbegin(); it != c.cend(); ++it, ++idx) {
        if (*it == x)
            return idx;
    }
    return -1;
}

/**
 * Eliminates variable v from the combination of the given factors.
 *
 * For every joint assignment of the remaining variables, the values of all
 * factors are combined with prod for each value of v, and the resulting
 * dom_sizes[v] numbers are reduced with sum.
 *
 * @param v             the variable to eliminate; must occur in at least one factor
 * @param dom_sizes     domain size of every variable, indexed by variable id
 * @param factor_vars   for each factor, the variable ids it depends on, in table order
 * @param factor_values for each factor, its flat value table; the factor's first
 *                      variable varies fastest (see compute_factor_lookup)
 * @param prod          combines one value per factor into a single value
 * @param sum           reduces the dom_sizes[v] combined values of one slice
 * @return the table over all variables of the factors except v, taken in
 *         ascending variable-id order with the smallest id varying fastest.
 *         Empty if any involved variable has a domain size below 1.
 */
std::vector<double> marginalize(
    const int v,
    const std::vector<int> &dom_sizes,
    const std::vector<std::vector<int>> &factor_vars,
    const std::vector<std::vector<double>> &factor_values,
    std::function<double(const std::vector<double> &)> prod ,
    std::function<double(const std::vector<double> &)> sum
){
    assert(factor_vars.size() == factor_values.size());
    const int num_factors = static_cast<int>(factor_vars.size());

    // find all variables in the elimination clique (sorted, without duplicates)
    std::set<int> elim_clique_set;
    for (const std::vector<int> &f : factor_vars)
        elim_clique_set.insert(f.cbegin(), f.cend());
    std::vector<int> elim_clique(elim_clique_set.cbegin(), elim_clique_set.cend());
    const int num_elim_vars = static_cast<int>(elim_clique.size());

    // diagnostic output kept from the original implementation
    std::cout << "elimination factor variables:" << std::endl;
    for (const int fv : elim_clique)
        std::cout << fv << std::endl;

    // Make the elimination variable the least significant counter digit.
    // Rotating (instead of swapping) keeps the remaining variables sorted,
    // which fixes the layout of the returned table.
    const int ev_idx = index_of(elim_clique, v);
    assert(ev_idx >= 0);
    std::rotate(elim_clique.begin(), elim_clique.begin() + ev_idx, elim_clique.begin() + ev_idx + 1);
    assert(elim_clique[0] == v);

    // The result has one entry per joint assignment of the remaining variables.
    std::vector<double>::size_type result_size = 1;
    for (int t = 0; t < num_elim_vars; ++t) {
        const int var = elim_clique[t];
        assert(var >= 0 && static_cast<std::vector<int>::size_type>(var) < dom_sizes.size());
        if (dom_sizes[var] < 1)
            return std::vector<double>(); // empty domain: nothing to enumerate
        if (t > 0)
            result_size *= static_cast<std::vector<double>::size_type>(dom_sizes[var]);
    }
    std::vector<double> result;
    result.reserve(result_size);

    // The lookup tables for adjusting the factor indices; they must be built
    // for the final counter order, i.e. after v has been moved to the front.
    std::vector<std::vector<int>> factor_index_lookup(num_factors, std::vector<int>(num_elim_vars));
    for (int f = 0; f < num_factors; ++f)
        compute_factor_lookup(factor_vars[f], dom_sizes, elim_clique, factor_index_lookup[f].begin());

    // mixed-radix counter over elim_clique; digit 0 is the elimination variable
    std::vector<int> count_reg(num_elim_vars, 0);
    // the current flat index into each factor's value table
    std::vector<int> factor_index(num_factors, 0);
    // scratch: one value per factor, combined with prod
    std::vector<double> factor_contribs(num_factors);
    // scratch: one combined value per value of v, reduced with sum
    std::vector<double> slice_result(dom_sizes[v]);

    while (true) {
        // evaluate the current assignment (this includes the initial all-zero one)
        for (int f = 0; f < num_factors; ++f) {
            assert(factor_index[f] >= 0 &&
                   static_cast<std::vector<double>::size_type>(factor_index[f]) < factor_values[f].size());
            factor_contribs[f] = factor_values[f][factor_index[f]];
        }
        slice_result[count_reg[0]] = prod(factor_contribs);

        // Advance the counter. t ends up marking the digit that gets
        // incremented; all digits before t overflowed and were reset to 0.
        int t = 0;
        while (t < num_elim_vars && ++count_reg[t] == dom_sizes[elim_clique[t]]) {
            count_reg[t] = 0;
            ++t;
        }

        // the elimination variable wrapped around, so one slice is complete
        if (t > 0)
            result.push_back(sum(slice_result));

        // every digit wrapped: the enumeration is finished
        if (t >= num_elim_vars)
            break;

        // move every factor index to the new assignment
        for (int f = 0; f < num_factors; ++f)
            factor_index[f] += factor_index_lookup[f][t];
    }
    return result;
}
int main(void){ | |
int doms[] = {2,2,3,3}; | |
vector<int> dom_sizes = vector<int>{2,2,3,3}; | |
int f1v[] = {0,1}; | |
int f2v[] = {1,2}; | |
vector<int> f1vars = vector<int>(f1v,f1v+2); | |
vector<int> f2vars = vector<int>(f2v,f2v+2); | |
double f1d[] = {0,1.3,3,1}; | |
double f2d[] = {1,1,1.5,1,0.5,0.5}; | |
vector<double> f1data = vector<double>(f1d,f1d+4); | |
vector<double> f2data = vector<double>(f2d,f2d+6); | |
auto fvars = vector<vector<int>>(2); | |
fvars[0] = f1vars; | |
fvars[1] = f2vars; | |
vector<vector<double>> fdata = vector<vector<double>>(2); | |
fdata[0] = f1data; | |
fdata[1] = f2data; | |
cout << index_of(f2data,1.5) << endl; | |
auto result = marginalize(1,dom_sizes,fvars,fdata, | |
[](const vector<double> xs) -> double { | |
double r = 1; | |
for_each(xs.cbegin(),xs.cend(),[&] (double d) {r *= d;}); | |
return r; | |
}, | |
[](const vector<double> xs) -> double { | |
double r = 0; | |
for_each(xs.cbegin(),xs.cend(),[&] (double d) {r += d;}); | |
return r; | |
} | |
); | |
for_each(result.begin(),result.end(),[] (double x) {cout << x << endl;}); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment