Skip to content

Instantly share code, notes, and snippets.

@shikharbhardwaj
Created May 27, 2017 08:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shikharbhardwaj/c070d704cbe6a7100254d280eac2edad to your computer and use it in GitHub Desktop.
Save shikharbhardwaj/c070d704cbe6a7100254d280eac2edad to your computer and use it in GitHub Desktop.
NBC (Naive Bayes Classifier) benchmark
#include <mlpack/core.hpp>
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
using namespace std;
using namespace mlpack;
using namespace naive_bayes;
int main(int argc, char *argv[]) {
CLI::ParseCommandLine(argc, argv);
const string trainFilename = "adult_data_rev.csv";
const string testFilename = "adult_test_rev.csv";
size_t classes = 2;
arma::mat trainData, trainRes, calcMat;
// Transpose
data::Load(trainFilename, trainData, true);
// Extract labels
arma::Row<size_t> labels(trainData.n_cols);
for (size_t i = 0; i < trainData.n_cols; ++i) {
labels[i] = trainData(trainData.n_rows - 1, i);
}
// Remove the labels from the data
trainData.shed_row(trainData.n_rows - 1);
Timer::Start("train_time");
NaiveBayesClassifier<> nbc(trainData, labels, classes);
Timer::Stop("train_time");
// Test the model
arma::mat testData;
data::Load(testFilename, testData, true);
arma::Row<size_t> testLabels(testData.n_cols);
arma::Row<size_t> calcVec;
arma::mat calcProbs;
for (size_t i = 0; i < testData.n_cols; ++i) {
testLabels[i] = testData(testData.n_rows - 1, i);
}
testData.shed_row(testData.n_rows - 1);
nbc.Classify(testData, calcVec, calcProbs);
size_t truePositives = 0, falsePositives = 0, falseNegatives = 0;
for (size_t i = 0; i < testData.n_cols; ++i) {
if (calcVec[i]) {
if (testLabels[i]) {
++truePositives;
} else {
++falsePositives;
}
} else {
if (testLabels[i]) {
++falseNegatives;
}
}
}
Log::Info << "Precision : "
<< (float)truePositives / (truePositives + falsePositives)
<< endl;
Log::Info << "Recall : "
<< (float)truePositives / (truePositives + falseNegatives)
<< endl;
}
λ nbc/adult ∴ ./a.out -v
[DEBUG] Compiled with debugging symbols.
[INFO ] Loading 'adult_data_rev_noimpute.csv' as CSV data. Size is 15 x 32560.
[INFO ] Loading 'adult_test_rev_noimpute.csv' as CSV data. Size is 15 x 16280.
[INFO ] Precision : 0.671941
[INFO ] Recall : 0.331253
[INFO ]
[INFO ] Execution parameters:
[INFO ] help: 0
[INFO ] info:
[INFO ] verbose: 1
[INFO ] version: 0
[INFO ] Program timers:
[INFO ] loading_data: 0.197992s
[INFO ] total_time: 0.255008s
[INFO ] train_time: 0.011908s
λ nbc/adult ∴ ./a.out -v
[DEBUG] Compiled with debugging symbols.
[INFO ] Loading 'adult_data_rev.csv' as CSV data. Size is 15 x 32560.
[INFO ] Loading 'adult_test_rev.csv' as CSV data. Size is 15 x 16281.
[INFO ] Precision : 0.67449
[INFO ] Recall : 0.343734
[INFO ]
[INFO ] Execution parameters:
[INFO ] help: 0
[INFO ] info:
[INFO ] verbose: 1
[INFO ] version: 0
[INFO ] Program timers:
[INFO ] loading_data: 0.196525s
[INFO ] total_time: 0.252860s
[INFO ] train_time: 0.011893s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment