Last active
July 25, 2017 10:00
-
-
Save shikharbhardwaj/3f461947d2a572d6e9697e57f3493bc0 to your computer and use it in GitHub Desktop.
HOGWILD! runs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Run HOGWILD! on the Netflix dataset, with the regularized SVD function. | |
#include <mlpack/core.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/exponential_backoff.hpp> | |
#include <mlpack/methods/regularized_svd/regularized_svd_function.hpp> | |
using namespace std; | |
using namespace mlpack; | |
using namespace mlpack::optimization; | |
using namespace arma; | |
/**
 * Compute the mean squared prediction error of a factorization on a set of
 * (user, item, rating) triples.
 *
 * @param parameters Factor matrix: column `user` is the user's factor and
 *     column `numUsers + item` is the item's factor.
 * @param data One triple per column: row 0 = user index, row 1 = item index,
 *     row 2 = rating.
 * @param numUsers Number of users, i.e. the column offset of the item factors.
 * @return Mean over all columns of (rating - user_factor . item_factor)^2.
 *     Calls Log::Fatal if `data` is empty, since the mean is then undefined.
 */
double MSE(const arma::mat& parameters, const arma::mat& data, size_t numUsers) {
  // An empty set has no defined mean; fail loudly instead of returning 0/0.
  if (data.n_cols == 0)
    Log::Fatal << "MSE(): the evaluation dataset is empty." << std::endl;

  // Diagnostic only. It used to be an unlabelled Info line that looked like a
  // result in the program output, so it is now labelled and at Debug level.
  Log::Debug << "MSE(): parameter matrix has " << parameters.n_cols
      << " columns." << std::endl;

  double cost = 0.0;
  for (size_t i = 0; i < data.n_cols; ++i) {
    // Indices are stored as doubles in the data matrix.
    const size_t user = static_cast<size_t>(data(0, i));
    const size_t item = static_cast<size_t>(data(1, i)) + numUsers;

    // Squared error of the predicted rating.
    const double ratingError = data(2, i) -
        arma::dot(parameters.col(user), parameters.col(item));
    cost += ratingError * ratingError;
  }
  return cost / data.n_cols;
}
int main(int argc, char* argv[]) { | |
mlpack::CLI::ParseCommandLine(argc, argv); | |
// Load the train and test datasets. | |
arma::mat train_dataset, test_dataset; | |
train_dataset.load("./netflix.svd.train.arm"); | |
test_dataset.load("./netflix.svd.test.arm"); | |
Log::Info << "Loaded data."; | |
// Paratmeters for the function. | |
size_t rank = 30; | |
double lambda = 0.01; | |
// Construct the regularized SVD function. | |
svd::RegularizedSVDFunction<arma::mat> f(std::move(train_dataset), rank, lambda); | |
Log::Info << "Functions : " << f.NumFunctions() << std::endl; | |
Log::Info << "Items : " << f.NumItems() << std::endl; | |
Log::Info << "Users: " << f.NumUsers() << std::endl; | |
// The exponential backoff stepsize decay. | |
ExponentialBackoff decay_policy(1, 0.001, 0.95); | |
ParallelSGD<ExponentialBackoff> optimizer(20, f.NumFunctions() / 4, 1e-5, | |
true, decay_policy); | |
arma::mat iterate = f.GetInitialPoint(); | |
Log::Info << "Initial test RMSE : " << sqrt(MSE(iterate, test_dataset, | |
f.NumUsers())) << std::endl; | |
mlpack::Timer::Start("optimize"); | |
optimizer.Optimize(f, iterate); | |
mlpack::Timer::Stop("optimize"); | |
Log::Info << "Final test RMSE : " << sqrt(MSE(iterate, test_dataset, | |
f.NumUsers())) << std::endl; | |
//arma::mat u = iterate.submat(0, f.NumUsers(), rank - 1, f.NumUsers() + | |
//f.NumItems() - 1).t(); | |
//arma::mat v = iterate.submat(0, 0, rank - 1, f.NumUsers() - 1); | |
iterate.save("hogwild_final.arm"); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
λ hogwild/mc ∴ ./a.out -v | |
[INFO ] Loaded data.Functions : 100198806 | |
[INFO ] Items : 17770 | |
[INFO ] Users: 480189 | |
[INFO ] 497959 | |
[INFO ] Initial test RMSE : 4.10353 | |
[INFO ] Parallel SGD: iteration 1, objective 1.82124e+09. | |
[INFO ] Parallel SGD: iteration 2, objective 1.10554e+08. | |
[INFO ] Parallel SGD: iteration 3, objective 1.03337e+08. | |
[INFO ] Parallel SGD: iteration 4, objective 1.00748e+08. | |
[INFO ] Parallel SGD: iteration 5, objective 9.94278e+07. | |
[INFO ] Parallel SGD: iteration 6, objective 9.85979e+07. | |
[INFO ] Parallel SGD: iteration 7, objective 9.80398e+07. | |
[INFO ] Parallel SGD: iteration 8, objective 9.76066e+07. | |
[INFO ] Parallel SGD: iteration 9, objective 9.72745e+07. | |
[INFO ] Parallel SGD: iteration 10, objective 9.69798e+07. | |
[INFO ] Parallel SGD: iteration 11, objective 9.67212e+07. | |
[INFO ] Parallel SGD: iteration 12, objective 9.646e+07. | |
[INFO ] Parallel SGD: iteration 13, objective 9.6214e+07. | |
[INFO ] Parallel SGD: iteration 14, objective 9.59647e+07. | |
[INFO ] Parallel SGD: iteration 15, objective 9.57138e+07. | |
[INFO ] Parallel SGD: iteration 16, objective 9.54541e+07. | |
[INFO ] Parallel SGD: iteration 17, objective 9.5187e+07. | |
[INFO ] Parallel SGD: iteration 18, objective 9.49255e+07. | |
[INFO ] Parallel SGD: iteration 19, objective 9.46596e+07. | |
[INFO ] | |
[INFO ] Parallel SGD terminated with objective : 9.46596e+07 | |
[INFO ] 497959 | |
[INFO ] Final test RMSE : 1.36406 | |
[INFO ] | |
[INFO ] Execution parameters: | |
[INFO ] help: 0 | |
[INFO ] info: | |
[INFO ] verbose: 1 | |
[INFO ] version: 0 | |
[INFO ] Program timers: | |
[INFO ] optimize: 713.535343s (11 mins, 53.5 secs) | |
[INFO ] total_time: 719.364519s (11 mins, 59.3 secs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* A comparison run on the dataset from the authors' implementation, used to | |
* obtain a comparable performance metric. | |
*/ | |
#include <mlpack/core.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp> | |
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/exponential_backoff.hpp> | |
#include <mlpack/methods/sparse_svm/sparse_svm_function.hpp> | |
#include <mlpack/core/optimizers/sgd/sgd.hpp> | |
using namespace std; | |
using namespace mlpack; | |
using namespace mlpack::optimization; | |
/**
 * Sign of a value: 1 if it is greater than zero, -1 if it is less than zero,
 * and 0 otherwise (zero, or any value that compares false both ways).
 * Only `operator<` on T is required.
 */
template <typename T>
int sgn(T val) {
  const T zero = T(0);
  if (zero < val)
    return 1;
  if (val < zero)
    return -1;
  return 0;
}
/**
 * Fraction of points whose predicted sign disagrees with the sign of their
 * label, for a linear model with weight vector `iterate`.
 *
 * Despite its name, this is a 0-1 (misclassification) error rate, not a mean
 * squared error. Each point contributes 1 if
 * sgn(iterate . x) != sgn(label) and 0 otherwise, so the result is in [0, 1].
 * A prediction of exactly 0 therefore counts as an error for any nonzero label.
 *
 * @param test_set Sparse data matrix, one point per column.
 * @param labels Labels; element i belongs to column i of `test_set`.
 * @param iterate Weight vector with one entry per row of `test_set`.
 * @return Misclassification rate. Calls Log::Fatal if `test_set` is empty or
 *     the number of labels does not match the number of points.
 */
float MeanSquaredError( const arma::sp_mat& test_set,
                        const arma::mat& labels,
                        const arma::mat& iterate) {
  // The rate is undefined for an empty set; fail loudly instead of 0/0.
  if (test_set.n_cols == 0)
    Log::Fatal << "MeanSquaredError(): the dataset is empty." << std::endl;
  if (labels.n_elem != test_set.n_cols)
    Log::Fatal << "MeanSquaredError(): " << labels.n_elem
        << " labels given for " << test_set.n_cols << " points." << std::endl;

  // Count mistakes exactly in an integer rather than accumulating in a float.
  size_t errors = 0;
  for (size_t i = 0; i < test_set.n_cols; ++i) {
    if (sgn(arma::dot(iterate, test_set.col(i))) != sgn(labels(i)))
      ++errors;
  }
  return static_cast<float>(errors) / test_set.n_cols;
}
int main(int argc, char* argv[]) { | |
mlpack::CLI::ParseCommandLine(argc, argv); | |
SparseSVMFunction function; | |
// | |
// Load the data | |
function.Dataset().load("./RCV1_train_data.arm"); | |
function.Labels().load("./RCV1_train_labels.arm"); | |
arma::mat test_labels; | |
arma::sp_mat test_data; | |
test_labels.load("./RCV1_test_labels.arm"); | |
test_data.load("./RCV1_test_data.arm"); | |
ExponentialBackoff decayPolicy(15, 0.2, 0.8); | |
arma::mat iterate(function.Dataset().n_rows, 1, arma::fill::randu); | |
Log::Info << "Train RMSE : " | |
<< std::sqrt(MeanSquaredError(function.Dataset(), function.Labels(), iterate)) << std::endl; | |
Log::Info << "Test RMSE : " | |
<< std::sqrt(MeanSquaredError(test_data, test_labels, iterate)) << std::endl; | |
ParallelSGD<ExponentialBackoff> optimizer(20, function.NumFunctions() / 4, | |
1e-5, true, decayPolicy); | |
//mlpack::optimization::StandardSGD optimizer(0.2); | |
mlpack::Timer::Start("optimize"); | |
optimizer.Optimize(function, iterate); | |
mlpack::Timer::Stop("optimize"); | |
Log::Info << "Train RMSE : " | |
<< std::sqrt(MeanSquaredError(function.Dataset(), function.Labels(), iterate)) << std::endl; | |
Log::Info << "Test RMSE : " | |
<< std::sqrt(MeanSquaredError(test_data, test_labels, iterate)) << std::endl; | |
iterate.save("./svm_weights.arm"); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
λ hogwild/sparseSVM ∴ ./a.out -v | |
[INFO ] Train RMSE : 0.792003 | |
[INFO ] Test RMSE : 0.135361 | |
[INFO ] Parallel SGD: iteration 1, objective 2.34139e+06. | |
[INFO ] Parallel SGD: iteration 2, objective 86926.3. | |
[INFO ] Parallel SGD: iteration 3, objective 82015.3. | |
[INFO ] Parallel SGD: iteration 4, objective 79496.5. | |
[INFO ] Parallel SGD: iteration 5, objective 77981.8. | |
[INFO ] Parallel SGD: iteration 6, objective 77066.8. | |
[INFO ] Parallel SGD: iteration 7, objective 76093.6. | |
[INFO ] Parallel SGD: iteration 8, objective 75129.2. | |
[INFO ] Parallel SGD: iteration 9, objective 74762.1. | |
[INFO ] Parallel SGD: iteration 10, objective 74573.7. | |
[INFO ] Parallel SGD: iteration 11, objective 73774.7. | |
[INFO ] Parallel SGD: iteration 12, objective 73420. | |
[INFO ] Parallel SGD: iteration 13, objective 73182.3. | |
[INFO ] Parallel SGD: iteration 14, objective 72772.6. | |
[INFO ] Parallel SGD: iteration 15, objective 72972.6. | |
[INFO ] Parallel SGD: iteration 16, objective 71530.1. | |
[INFO ] Parallel SGD: iteration 17, objective 71437.5. | |
[INFO ] Parallel SGD: iteration 18, objective 71513. | |
[INFO ] Parallel SGD: iteration 19, objective 70952. | |
[INFO ] | |
[INFO ] Parallel SGD terminated with objective : 70952 | |
[INFO ] Train RMSE : 0.186585 | |
[INFO ] Test RMSE : 0.0359565 | |
[INFO ] | |
[INFO ] Execution parameters: | |
[INFO ] help: 0 | |
[INFO ] info: | |
[INFO ] verbose: 1 | |
[INFO ] version: 0 | |
[INFO ] Program timers: | |
[INFO ] optimize: 9.884369s | |
[INFO ] total_time: 10.483783s |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment