Last active
March 18, 2017 12:42
-
-
Save MikeLing/50c96b402b0f8a42a5355ae962fb4fd0 to your computer and use it in GitHub Desktop.
rewrite test for SVM by Gtest fixture
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <shogun/classifier/svm/SVMLight.h> | |
#include <shogun/features/DenseFeatures.h> | |
#include <shogun/kernel/GaussianKernel.h> | |
#include <gtest/gtest.h> | |
#include "../../utils/MockDataForSVM.h" | |
using namespace shogun; | |
#ifdef USE_GPL_SHOGUN | |
#ifdef HAVE_LAPACK | |
class SVMLightFixture: public ::testing::test { | |
public: | |
CLabels* pred; | |
SVMLightFixture( ) { | |
// init a kernel for SVM traning | |
CGaussianKernel* gauss_kernel = new CGaussianKernel(features_train, features_train, 15); | |
CSVMLight* svml = new CSVMLight(C, gauss_kernel, ground_truth); | |
svml->set_epsilon(epsilon); | |
svml->train(); | |
CLabels* pred = svml->apply(features_test); | |
} | |
void SetUp( ) { | |
index_t num_samples = 100; | |
auto C = 1.0; | |
auto epsilon = 0.001; | |
CMockDataForSVM mockData = CMockDataForSVM(); | |
mockData.generate_data(num_samples); | |
CDenseFeatures<float64_t>* features_train = mockData.get_features_train(); | |
CDenseFeatures<float64_t>* features_test = mockData.get_features_test(); | |
CBinaryLabels* ground_truth = mockData.get_labels_test(); | |
} | |
void TearDown( ) { | |
SG_UNREF(svml); | |
SG_UNREF(features_train); | |
SG_UNREF(features_test); | |
SG_UNREF(pred); | |
SG_UNREF(ground_truth); | |
} | |
~SVMLightFixture( ) { | |
} | |
protected: | |
CDenseFeatures<float64_t>* features_train; | |
CDenseFeatures<float64_t>* features_test; | |
CBinaryLabels* ground_truth; | |
}; | |
/**
 * Every label predicted by the fixture's trained SVM must equal the
 * corresponding ground-truth label.
 *
 * TEST_F (rather than TEST) is required so that the body runs inside the
 * SVMLightFixture and can reach its `pred` and `ground_truth` members.
 */
TEST_F(SVMLightFixture, train)
{
	ASSERT_TRUE(pred != NULL);
	ASSERT_TRUE(ground_truth != NULL);

	// Take the count from the labels themselves and make sure both sides
	// agree before indexing into them.
	const index_t num_labels = ground_truth->get_num_labels();
	ASSERT_EQ(num_labels, pred->get_num_labels());

	CBinaryLabels* predicted = static_cast<CBinaryLabels*>(pred);
	for (index_t i = 0; i < num_labels; ++i)
		EXPECT_EQ(ground_truth->get_int_label(i), predicted->get_int_label(i));
}
#endif // HAVE_LAPACK | |
#endif //USE_GPL_SHOGUN | |
#include <shogun/labels/BinaryLabels.h> | |
#include <shogun/features/DenseFeatures.h> | |
#include <shogun/features/DataGenerator.h> | |
using namespace shogun; | |
class CMockDataForSVM | |
{ | |
public: | |
CMockDataForSVM() {/*Nothing to do in here*/} | |
~CMockDataForSVM() {/*Nothing to do in here*/} | |
/** | |
* This function is about to create features simpled from (2) gaussians distribution | |
* | |
* @num_samples the number of samples | |
*/ | |
void generate_data(const index_t num_samples) | |
{ | |
CMath::init_random(5); | |
SGMatrix<float64_t> data = | |
CDataGenerator::generate_gaussians(num_samples, 2, 2); | |
CDenseFeatures<float64_t> features(data); | |
SGVector<index_t> train_idx(num_samples), test_idx(num_samples); | |
SGVector<float64_t> labels(num_samples); | |
for (index_t i = 0, j = 0; i < data.num_cols; ++i) | |
{ | |
if (i % 2 == 0) | |
train_idx[j] = i; | |
else | |
test_idx[j++] = i; | |
labels[i/2] = (i < data.num_cols/2) ? 1.0 : -1.0; | |
} | |
features_train = (CDenseFeatures<float64_t>*)features.copy_subset(train_idx); | |
features_test = (CDenseFeatures<float64_t>*)features.copy_subset(test_idx); | |
CBinaryLabels temp_labels = CBinaryLabels(labels); | |
labels_train = (CBinaryLabels*)temp_labels.clone(); | |
labels_test = (CBinaryLabels*)temp_labels.clone(); | |
} | |
/* get the traning features */ | |
CDenseFeatures<float64_t>* get_features_train() | |
{ | |
return features_train; | |
} | |
/* get the test features */ | |
CDenseFeatures<float64_t>* get_features_test() | |
{ | |
return features_test; | |
} | |
/* get the test labels */ | |
CBinaryLabels* get_labels_train() | |
{ | |
return labels_train; | |
} | |
/* get the traning labels */ | |
CBinaryLabels* get_labels_test() | |
{ | |
return labels_test; | |
} | |
protected: | |
// data for training | |
CDenseFeatures<float64_t>* features_train; | |
// data for testing | |
CDenseFeatures<float64_t>* features_test; | |
// traning label | |
CBinaryLabels* labels_train; | |
// testing label | |
CBinaryLabels* labels_test; | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment