Created
May 27, 2017 08:16
-
-
Save shikharbhardwaj/c070d704cbe6a7100254d280eac2edad to your computer and use it in GitHub Desktop.
NBC benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <mlpack/core.hpp> | |
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp> | |
using namespace std; | |
using namespace mlpack; | |
using namespace naive_bayes; | |
int main(int argc, char *argv[]) { | |
CLI::ParseCommandLine(argc, argv); | |
const string trainFilename = "adult_data_rev.csv"; | |
const string testFilename = "adult_test_rev.csv"; | |
size_t classes = 2; | |
arma::mat trainData, trainRes, calcMat; | |
// Transpose | |
data::Load(trainFilename, trainData, true); | |
// Extract labels | |
arma::Row<size_t> labels(trainData.n_cols); | |
for (size_t i = 0; i < trainData.n_cols; ++i) { | |
labels[i] = trainData(trainData.n_rows - 1, i); | |
} | |
// Remove the labels from the data | |
trainData.shed_row(trainData.n_rows - 1); | |
Timer::Start("train_time"); | |
NaiveBayesClassifier<> nbc(trainData, labels, classes); | |
Timer::Stop("train_time"); | |
// Test the model | |
arma::mat testData; | |
data::Load(testFilename, testData, true); | |
arma::Row<size_t> testLabels(testData.n_cols); | |
arma::Row<size_t> calcVec; | |
arma::mat calcProbs; | |
for (size_t i = 0; i < testData.n_cols; ++i) { | |
testLabels[i] = testData(testData.n_rows - 1, i); | |
} | |
testData.shed_row(testData.n_rows - 1); | |
nbc.Classify(testData, calcVec, calcProbs); | |
size_t truePositives = 0, falsePositives = 0, falseNegatives = 0; | |
for (size_t i = 0; i < testData.n_cols; ++i) { | |
if (calcVec[i]) { | |
if (testLabels[i]) { | |
++truePositives; | |
} else { | |
++falsePositives; | |
} | |
} else { | |
if (testLabels[i]) { | |
++falseNegatives; | |
} | |
} | |
} | |
Log::Info << "Precision : " | |
<< (float)truePositives / (truePositives + falsePositives) | |
<< endl; | |
Log::Info << "Recall : " | |
<< (float)truePositives / (truePositives + falseNegatives) | |
<< endl; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
λ nbc/adult ∴ ./a.out -v | |
[DEBUG] Compiled with debugging symbols. | |
[INFO ] Loading 'adult_data_rev_noimpute.csv' as CSV data. Size is 15 x 32560. | |
[INFO ] Loading 'adult_test_rev_noimpute.csv' as CSV data. Size is 15 x 16280. | |
[INFO ] Precision : 0.671941 | |
[INFO ] Recall : 0.331253 | |
[INFO ] | |
[INFO ] Execution parameters: | |
[INFO ] help: 0 | |
[INFO ] info: | |
[INFO ] verbose: 1 | |
[INFO ] version: 0 | |
[INFO ] Program timers: | |
[INFO ] loading_data: 0.197992s | |
[INFO ] total_time: 0.255008s | |
[INFO ] train_time: 0.011908s | |
λ nbc/adult ∴ ./a.out -v | |
[DEBUG] Compiled with debugging symbols. | |
[INFO ] Loading 'adult_data_rev.csv' as CSV data. Size is 15 x 32560. | |
[INFO ] Loading 'adult_test_rev.csv' as CSV data. Size is 15 x 16281. | |
[INFO ] Precision : 0.67449 | |
[INFO ] Recall : 0.343734 | |
[INFO ] | |
[INFO ] Execution parameters: | |
[INFO ] help: 0 | |
[INFO ] info: | |
[INFO ] verbose: 1 | |
[INFO ] version: 0 | |
[INFO ] Program timers: | |
[INFO ] loading_data: 0.196525s | |
[INFO ] total_time: 0.252860s | |
[INFO ] train_time: 0.011893s |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment