Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Another approach to testing matrix completion with ParallelSGD (Hogwild!)
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp>
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/constant_step.hpp>
#include <mlpack/core/optimizers/parallel_sgd/sparse_mc_function.hpp>
using namespace std;
using namespace mlpack::optimization;
// Smoke test for ParallelSGD (Hogwild!) on a tiny sparse matrix-completion
// problem: build a square sparse matrix with two +/-1 entries per column, fit
// a rank-`degree` factorization with SparseMCLossFunction, and print the
// summed objective before and after optimization plus the reconstructed
// dense matrix.
int main()
{
  const size_t numRows = 10; // A square matrix.
  const size_t numCols = numRows;
  const size_t degree = 2;   // Rank passed to SparseMCLossFunction.

  // Observed entries: in column i, rows i and (i + 2) % numRows are set to
  // -1 for even i and +1 for odd i.
  arma::sp_mat inputMat(numRows, numCols);
  for (size_t i = 0; i < numCols; ++i)
  {
    const double val = (i % 2) ? 1.0 : -1.0;
    inputMat.at(i, i) = val;
    inputMat.at((i + 2) % numRows, i) = val;
  }
  std::cout << arma::mat(inputMat) << std::endl;

  SparseMCLossFunction f(inputMat, 1e-6, degree);
  ConstantStep decayPolicy(0.2);

  // Per-thread share of work for one ParallelSGD iteration: the number of
  // observed entries split evenly (rounded up) across the OpenMP threads.
  // Derived from the data instead of hard-coding 2 * numRows, so it stays
  // right when the sparsity pattern above is changed.
  // NOTE(review): assumes one separable function per observed entry —
  // confirm against SparseMCLossFunction::NumFunctions().
  const int maxThreads = omp_get_max_threads();
  const size_t numThreads = (maxThreads > 0) ? size_t(maxThreads) : 1;
  const size_t batchSize = (inputMat.n_nonzero + numThreads - 1) / numThreads;

  ParallelSGD<ConstantStep> s(100000, batchSize, 1e-5, decayPolicy);

  // Start from a random point. A bare arma::mat(rows, cols) is uninitialized
  // in older Armadillo and all-zero in newer releases; all-zero factors are a
  // stationary point of a product-of-factors loss, so SGD would never move.
  // Seeding makes the run reproducible.
  arma::arma_rng::set_seed(42);
  arma::mat iterate(degree, numCols + numRows, arma::fill::randu);

  // Objective summed over every separable function at the current iterate.
  auto totalObjective = [&]()
  {
    double total = 0;
    for (size_t i = 0; i < f.NumFunctions(); ++i)
      total += f.Evaluate(iterate, i);
    return total;
  };

  // std::endl (not '\n') so output is flushed before the long optimization.
  std::cout << "Initial objective : " << totalObjective() << std::endl;
  s.Optimize(f, iterate);
  const double finalObjective = totalObjective();
  std::cout << "Final objective : " << finalObjective << std::endl;

  // Reconstruct the dense matrix as the product of the two column blocks of
  // `iterate` (first numRows columns, then the next numCols columns).
  arma::mat gen(arma::trans(iterate.cols(0, numRows - 1)) *
      iterate.cols(numRows, numRows + numCols - 1));
  std::cout << gen << std::endl;
}
#############################################################################################
Observed (row, column) positions set in each column i (rows taken modulo numRows):
(i, i)
(i + 2, i)    (same-row evidence)
#############################################################################################
λ mlpack_spike/hogwild ∴ g++ sparse_mc_test.cpp -O2 -std=c++11 -Wall -larmadillo -lmlpack -fopenmp -g
λ mlpack_spike/hogwild ∴ ./a.out
-1.0000 0 0 0 0 0 0 0 -1.0000 0
0 1.0000 0 0 0 0 0 0 0 1.0000
-1.0000 0 -1.0000 0 0 0 0 0 0 0
0 1.0000 0 1.0000 0 0 0 0 0 0
0 0 -1.0000 0 -1.0000 0 0 0 0 0
0 0 0 1.0000 0 1.0000 0 0 0 0
0 0 0 0 -1.0000 0 -1.0000 0 0 0
0 0 0 0 0 1.0000 0 1.0000 0 0
0 0 0 0 0 0 -1.0000 0 -1.0000 0
0 0 0 0 0 0 0 1.0000 0 1.0000
Initial objective : 18
Final objective : 8.30448e-05
-0.9987 -0.9370 -0.9928 -0.9345 -0.9903 -0.9332 -0.9930 -0.9326 -0.9990 -0.9409
1.0650 0.9992 1.0587 0.9966 1.0560 0.9952 1.0589 0.9945 1.0654 1.0004
-1.0023 -0.9404 -0.9965 -0.9380 -0.9939 -0.9367 -0.9966 -0.9360 -1.0027 -0.9443
1.0672 1.0013 1.0609 0.9986 1.0582 0.9972 1.0611 0.9965 1.0676 1.0054
-1.0083 -0.9460 -1.0024 -0.9435 -0.9998 -0.9422 -1.0025 -0.9415 -1.0086 -0.9499
1.0695 1.0035 1.0633 1.0009 1.0606 0.9995 1.0635 0.9987 1.0700 1.0076
-1.0081 -0.9458 -1.0022 -0.9433 -0.9996 -0.9420 -1.0023 -0.9413 -1.0084 -0.9497
1.0707 1.0046 1.0645 1.0020 1.0618 1.0006 1.0646 0.9999 1.0712 1.0088
-1.0021 -0.9402 -0.9963 -0.9378 -0.9937 -0.9364 -0.9964 -0.9358 -1.0025 -0.9441
1.0711 1.0049 1.0648 1.0023 1.0621 1.0009 1.0650 1.0002 1.0715 0.9998
#############################################################################################
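Observed (row, column) positions set in each column i (rows taken modulo numRows); the two inputMat.at assignments in the code above must be edited to set these four entries instead:
(i, i)
(i + 1, i)
(i + 3, i)
(i + 5, i)    (same-column evidence)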
#############################################################################################
λ mlpack_spike/hogwild ∴ ./a.out
-1.0000 0 0 0 0 1.0000 0 1.0000 0 1.0000
-1.0000 1.0000 0 0 0 0 -1.0000 0 -1.0000 0
0 1.0000 -1.0000 0 0 0 0 1.0000 0 1.0000
-1.0000 0 -1.0000 1.0000 0 0 0 0 -1.0000 0
0 1.0000 0 1.0000 -1.0000 0 0 0 0 1.0000
-1.0000 0 -1.0000 0 -1.0000 1.0000 0 0 0 0
0 1.0000 0 1.0000 0 1.0000 -1.0000 0 0 0
0 0 -1.0000 0 -1.0000 0 -1.0000 1.0000 0 0
0 0 0 1.0000 0 1.0000 0 1.0000 -1.0000 0
0 0 0 0 -1.0000 0 -1.0000 0 -1.0000 1.0000
Initial objective : 152
Final objective : 0.000190717
-1.0019 0.9986 -0.9992 1.0003 -0.9989 1.0012 -0.9992 0.9986 -1.0014 0.9984
-1.0010 0.9978 -0.9983 0.9994 -0.9980 1.0003 -0.9983 0.9977 -1.0005 0.9975
-1.0038 1.0005 -1.0011 1.0022 -1.0007 1.0031 -1.0011 1.0005 -1.0033 1.0003
-1.0006 0.9974 -0.9979 0.9990 -0.9976 0.9999 -0.9979 0.9973 -1.0001 0.9971
-1.0030 0.9998 -1.0003 1.0015 -1.0000 1.0024 -1.0003 0.9997 -1.0026 0.9996
-1.0009 0.9977 -0.9982 0.9994 -0.9979 1.0003 -0.9982 0.9976 -1.0005 0.9975
-1.0024 0.9991 -0.9997 1.0008 -0.9994 1.0017 -0.9997 0.9991 -1.0019 0.9989
-1.0034 1.0002 -1.0008 1.0019 -1.0004 1.0028 -1.0008 1.0002 -1.0030 1.0000
-1.0009 0.9976 -0.9982 0.9993 -0.9978 1.0002 -0.9982 0.9976 -1.0004 0.9974
-1.0028 0.9996 -1.0001 1.0012 -0.9998 1.0021 -1.0001 0.9995 -1.0023 0.9993
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment