Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Matrix completion run
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/constant_step.hpp>
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp>
#include <mlpack/core/optimizers/parallel_sgd/sparse_mc_function.hpp>
using namespace std;
using namespace mlpack;
using namespace mlpack::optimization;
/**
 * Matrix-completion experiment.
 *
 * Factors the small sparse matrix `ans_mat` with ParallelSGD on a
 * SparseMCLossFunction. It prints the objective before optimization, then the
 * original matrix, the reconstruction from the learned factors, and the
 * objective after optimization.
 *
 * The factor layout is derived from `ans_mat` and `rank` rather than being
 * hard-coded, so the dimensions stay consistent if either changes. The layout
 * is `rank` rows, with one column per row of `ans_mat` followed by one column
 * per column of `ans_mat`.
 */
int main(int argc, char* argv[])
{
  mlpack::CLI::ParseCommandLine(argc, argv);

  arma::sp_mat ans_mat("5 3 0 1; 4 0 0 1; 1 1 0 5; 1 0 0 4; 0 1 5 4;");

  // Number of latent factors. It is used both as the third
  // SparseMCLossFunction argument and as the iterate's row count, matching the
  // original run where both were the literal 4.
  // NOTE(review): assumes that constructor argument is the factorization
  // rank; confirm against SparseMCLossFunction.
  const size_t rank = 4;
  const size_t numRows = ans_mat.n_rows;
  const size_t numCols = ans_mat.n_cols;

  SparseMCLossFunction f(ans_mat, 0.02, rank);
  ConstantStep decayPolicy(0.1);
  ParallelSGD<ConstantStep> s(50000, 5, 1e-5, decayPolicy);

  // Columns [0, numRows) hold the row factors.
  // Columns [numRows, numRows + numCols) hold the column factors.
  arma::mat iterate(rank, numRows + numCols, arma::fill::randn);

  // Total objective: the sum of all f.NumFunctions() separable components,
  // evaluated at the current iterate.
  auto objective = [&]()
  {
    double total = 0;
    for (size_t i = 0; i < f.NumFunctions(); ++i)
      total += f.Evaluate(iterate, i);
    return total;
  };

  // Flush so the starting value is visible before the long optimization runs.
  cout << "Initial objective : " << objective() << endl;

  s.Optimize(f, iterate);

  // Reconstruction = (row factors)^T * (column factors), shape numRows x numCols.
  arma::mat gen(arma::trans(iterate.cols(0, numRows - 1)) *
                iterate.cols(numRows, numRows + numCols - 1));

  cout << arma::mat(ans_mat) << '\n';
  cout << gen << '\n';

  const double finalObjective = objective();
  cout << "Final objective : " << finalObjective << '\n';
}
#############################################################
Without mean
#############################################################
λ mlpack_spike/hogwild ∴ OMP_NUM_THREADS=1 ./a.out
Initial objective : 124.529
5.0000 3.0000 0 1.0000
4.0000 0 0 1.0000
1.0000 1.0000 0 5.0000
1.0000 0 0 4.0000
0 1.0000 5.0000 4.0000
4.9927 2.9825 0.7369 0.9968
3.9875 1.1196 -1.5995 0.9973
1.0016 0.9976 1.0712 4.9785
1.0091 -0.7997 -0.0219 3.9911
0.2612 0.9937 4.9890 3.9918
Final objective : 0.287676
#############################################################
With mean
#############################################################
λ mlpack_spike/hogwild ∴ OMP_NUM_THREADS=1 ./a.out
Initial objective : 30.5357
5.0000 3.0000 0 1.0000
4.0000 0 0 1.0000
1.0000 1.0000 0 5.0000
1.0000 0 0 4.0000
0 1.0000 5.0000 4.0000
2.2052 0.2323 1.3727 -1.7562
1.2226 0.8786 -1.3400 -1.7594
-1.7648 -1.7474 1.5708 2.2262
-1.7624 -0.9664 -0.8608 1.2239
-0.6352 -1.7405 2.2264 1.2243
Final objective : 0.226394
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.