Created
August 10, 2016 12:37
-
-
Save Saurabh7/cb66deb288de81c19599af465082562f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <shogun/base/SGObject.h>
#include <shogun/base/init.h>
#include <shogun/base/some.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/evaluation/MulticlassAccuracy.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/io/CSVFile.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/WrappedObjectArray.h>
#include <shogun/lib/external/falconn/wrapper/cpp_wrapper_impl.h>
#include <shogun/multiclass/KNN.h>

using namespace shogun;
int main(int, char*[]) | |
{ | |
init_shogun_with_defaults(); | |
auto f_feats_train = some<CCSVFile>("../../data/classifier_4class_2d_linear_features_train.dat"); | |
auto f_feats_test = some<CCSVFile>("../../data/classifier_4class_2d_linear_features_test.dat"); | |
auto f_labels_train = some<CCSVFile>("../../data/classifier_4class_2d_linear_labels_train.dat"); | |
auto f_labels_test = some<CCSVFile>("../../data/classifier_4class_2d_linear_labels_test.dat"); | |
//![create_features] | |
auto features_train = some<CDenseFeatures<float64_t>>(f_feats_train); | |
auto features_test = some<CDenseFeatures<float64_t>>(f_feats_test); | |
auto labels_train = some<CMulticlassLabels>(f_labels_train); | |
auto labels_test = some<CMulticlassLabels>(f_labels_test); | |
//![create_features] | |
//![choose_distance] | |
auto distance = some<CEuclideanDistance>(features_train, features_train); | |
//![choose_distance] | |
//![create_instance] | |
auto k = 3; | |
auto knn = some<CKNN>(k, distance, labels_train); | |
//![create_instance] | |
//![train_and_apply] | |
knn->train(); | |
auto labels_predict = knn->apply_multiclass(features_test); | |
//![train_and_apply] | |
//![evaluate_accuracy] | |
auto eval = some<CMulticlassAccuracy>(); | |
auto accuracy = eval->evaluate(labels_predict, labels_test); | |
//![evaluate_accuracy] | |
// additional integration testing variables | |
auto output = labels_predict->get_labels(); | |
// Serialize output for integration testing (automatically generated) | |
auto __sg_storage = some<CWrappedObjectArray>(); | |
auto __sg_storage_file = some<CSerializableAsciiFile>("knn.dat", 119); | |
__sg_storage->append_wrapped(k, "k"); | |
__sg_storage->append_wrapped(accuracy, "accuracy"); | |
__sg_storage->append_wrapped(output, "output"); | |
__sg_storage->save_serializable(__sg_storage_file); | |
exit_shogun(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment