Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Sparse SVM run on the RCV1 dataset
λ mlpack_spike/hogwild ∴ g++ svm_main.cpp -O2 -std=c++11 -Wall -larmadillo -lmlpack -fopenmp
λ mlpack_spike/hogwild ∴ ./a.out
RMSE : 0.797936
Initial loss : 70176.4
Final loss : 114.332
RMSE : 0.050055
λ mlpack_spike/hogwild ∴ ./a.out
RMSE : 0.797936
Initial loss : 70176.4
Final loss : 114.369
RMSE : 0.0496217
# HOGWILD!
λ mlpack_spike/hogwild ∴ time OMP_NUM_THREADS=4 ./a.out
RMSE : 0.797936
Initial loss : 70176.4
Final loss : 114.276
RMSE : 0.0496217
OMP_NUM_THREADS=4 ./a.out 16.12s user 0.18s system 392% cpu 4.155 total
# StandardSGD
λ mlpack_spike/hogwild ∴ time ./a.out
RMSE : 0.797936
Initial loss : 70176.4
Final loss : 5969.06
RMSE : 0.227869
./a.out 7.53s user 1.44s system 102% cpu 8.763 total
#include <mlpack/core.hpp>
#include <mlpack/core/optimizers/parallel_sgd/parallel_sgd.hpp>
#include <mlpack/core/optimizers/parallel_sgd/decay_policies/exponential_backoff.hpp>
#include <mlpack/core/optimizers/parallel_sgd/sparse_svm_function.hpp>
using namespace mlpack;
using namespace mlpack::optimization;
/// Sign of a value as an int: +1 when `val` compares greater than T(0),
/// -1 when it compares less than T(0), and 0 otherwise.
template <typename T>
int sgn(T val) {
  const T zero = T(0);
  const bool isPositive = zero < val;
  const bool isNegative = val < zero;
  return static_cast<int>(isPositive) - static_cast<int>(isNegative);
}
/// Fraction of samples whose predicted sign disagrees with the label's sign.
///
/// Despite the name, no squared differences are computed here: sample i
/// contributes 1 when sgn(dot(iterate, Dataset().col(i))) differs from
/// sgn(Labels()(i)) and 0 otherwise, and the total is divided by
/// function.NumFunctions().
///
/// @param function Object exposing NumFunctions(), Dataset() and Labels().
/// @param iterate  Weight matrix dotted against each dataset column.
/// @return Mismatch count divided by NumFunctions(), as a float.
template <typename FunctionType>
float MeanSquaredError(FunctionType& function, arma::mat& iterate) {
  float mismatches = 0.f;
  for (size_t idx = 0; idx < function.NumFunctions(); ++idx) {
    const int predictedSign =
        sgn(arma::dot(iterate, function.Dataset().col(idx)));
    const int labelSign = sgn(function.Labels()(idx));
    if (predictedSign != labelSign) {
      mismatches += 1.f;
    }
  }
  return mismatches / function.NumFunctions();
}
/// Loads the RCV1 training data and labels, runs ParallelSGD on a
/// SparseSVMLossFunction starting from random weights, prints the error
/// metric and loss before and after optimization, and saves the weights.
///
/// @return 0 on success; 1 if a data file cannot be loaded or the weights
///         cannot be saved.
int main(int argc, char* argv[]) {
  mlpack::CLI::ParseCommandLine(argc, argv);
  SparseSVMLossFunction function;

  // Load the data. Armadillo's load() signals failure through its return
  // value; continuing after a failed load would run on an empty dataset and
  // print meaningless numbers, so stop here instead.
  if (!function.Dataset().load("./RCV1_train_data.arm")) {
    std::cerr << "Failed to load ./RCV1_train_data.arm" << std::endl;
    return 1;
  }
  if (!function.Labels().load("./RCV1_train_labels.arm")) {
    std::cerr << "Failed to load ./RCV1_train_labels.arm" << std::endl;
    return 1;
  }

  // NOTE(review): the meaning of these three constants is defined by
  // ExponentialBackoff's constructor, which is not visible here - confirm
  // against its header before tuning.
  ExponentialBackoff decayPolicy(15, 0.2, 0.8);

  // One weight per dataset row, initialised with arma::fill::randu.
  arma::mat iterate(function.Dataset().n_rows, 1, arma::fill::randu);
  std::cout << "RMSE : " << std::sqrt(MeanSquaredError(function , iterate)) << std::endl;

  // Objective before optimization: sum of Evaluate() over every function index.
  float initialLoss = 0.f;
  for (size_t i = 0; i < function.NumFunctions(); ++i) {
    initialLoss += function.Evaluate(iterate, i);
  }
  std::cout << "Initial loss : " << initialLoss << std::endl;

  // NOTE(review): presumably (max iterations, per-thread batch size,
  // tolerance, decay policy) - verify against the ParallelSGD constructor.
  ParallelSGD<ExponentialBackoff> optimizer(200, function.NumFunctions() / 4, 1e-5, decayPolicy);
  std::cout << "Final loss : " << optimizer.Optimize(function, iterate) << std::endl;
  std::cout << "RMSE : " << std::sqrt(MeanSquaredError(function , iterate)) << std::endl;

  // Persist the trained weights; report a failed write instead of exiting
  // as if the result had been stored.
  if (!iterate.save("./svm_weights.arm")) {
    std::cerr << "Failed to save ./svm_weights.arm" << std::endl;
    return 1;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment