Last active
July 3, 2017 13:49
-
-
Save shikharbhardwaj/27685f5cbb5d3a465993405b8be3fc6e to your computer and use it in GitHub Desktop.
Matrix completion run
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <mlpack/core.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/constant_step.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/sparse_mc_function.hpp> | |
using namespace std; | |
using namespace mlpack; | |
using namespace mlpack::optimization; | |
int main(int argc, char* argv[]) | |
{ | |
mlpack::CLI::ParseCommandLine(argc, argv); | |
arma::sp_mat ans_mat("5 3 0 1; 4 0 0 1; 1 1 0 5; 1 0 0 4; 0 1 5 4;"); | |
SparseMCLossFunction f(ans_mat, 0.02, 4); | |
ConstantStep decayPolicy(0.1); | |
ParallelSGD<ConstantStep> s(50000, 5, 1e-5, decayPolicy); | |
arma::mat iterate(4, 9, arma::fill::randn); | |
double initial = 0; | |
for(size_t i = 0; i < f.NumFunctions(); ++i) | |
{ | |
initial += f.Evaluate(iterate, i); | |
} | |
cout << "Initial objective : " << initial << endl; | |
s.Optimize(f, iterate); | |
//cout << iterate << endl; | |
arma::mat gen(arma::trans(iterate.cols(0, 4)) * iterate.cols(5, 8)); | |
cout << arma::mat(ans_mat) << endl; | |
cout << gen << endl; | |
double final = 0; | |
for(size_t i = 0; i < f.NumFunctions(); ++i) | |
{ | |
final += f.Evaluate(iterate, i); | |
} | |
cout << "Final objective : " << final << endl; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#############################################################
Without mean
#############################################################
λ mlpack_spike/hogwild ∴ OMP_NUM_THREADS=1 ./a.out
Initial objective : 124.529
5.0000 3.0000 0 1.0000
4.0000 0 0 1.0000
1.0000 1.0000 0 5.0000
1.0000 0 0 4.0000
0 1.0000 5.0000 4.0000
4.9927 2.9825 0.7369 0.9968
3.9875 1.1196 -1.5995 0.9973
1.0016 0.9976 1.0712 4.9785
1.0091 -0.7997 -0.0219 3.9911
0.2612 0.9937 4.9890 3.9918
Final objective : 0.287676
#############################################################
With mean
#############################################################
λ mlpack_spike/hogwild ∴ OMP_NUM_THREADS=1 ./a.out
Initial objective : 30.5357
5.0000 3.0000 0 1.0000
4.0000 0 0 1.0000
1.0000 1.0000 0 5.0000
1.0000 0 0 4.0000
0 1.0000 5.0000 4.0000
2.2052 0.2323 1.3727 -1.7562
1.2226 0.8786 -1.3400 -1.7594
-1.7648 -1.7474 1.5708 2.2262
-1.7624 -0.9664 -0.8608 1.2239
-0.6352 -1.7405 2.2264 1.2243
Final objective : 0.226394
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment