Skip to content

Instantly share code, notes, and snippets.

@Saurabh7
Created August 10, 2016 12:37
Show Gist options
  • Save Saurabh7/cb66deb288de81c19599af465082562f to your computer and use it in GitHub Desktop.
Save Saurabh7/cb66deb288de81c19599af465082562f to your computer and use it in GitHub Desktop.
#include <shogun/base/init.h>
#include <shogun/base/some.h>
#include <shogun/base/SGObject.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/lib/SGVector.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/lib/external/falconn/wrapper/cpp_wrapper_impl.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/io/CSVFile.h>
#include <shogun/evaluation/MulticlassAccuracy.h>
#include <shogun/lib/WrappedObjectArray.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/multiclass/KNN.h>
using namespace shogun;
int main(int, char*[])
{
init_shogun_with_defaults();
auto f_feats_train = some<CCSVFile>("../../data/classifier_4class_2d_linear_features_train.dat");
auto f_feats_test = some<CCSVFile>("../../data/classifier_4class_2d_linear_features_test.dat");
auto f_labels_train = some<CCSVFile>("../../data/classifier_4class_2d_linear_labels_train.dat");
auto f_labels_test = some<CCSVFile>("../../data/classifier_4class_2d_linear_labels_test.dat");
//![create_features]
auto features_train = some<CDenseFeatures<float64_t>>(f_feats_train);
auto features_test = some<CDenseFeatures<float64_t>>(f_feats_test);
auto labels_train = some<CMulticlassLabels>(f_labels_train);
auto labels_test = some<CMulticlassLabels>(f_labels_test);
//![create_features]
//![choose_distance]
auto distance = some<CEuclideanDistance>(features_train, features_train);
//![choose_distance]
//![create_instance]
auto k = 3;
auto knn = some<CKNN>(k, distance, labels_train);
//![create_instance]
//![train_and_apply]
knn->train();
auto labels_predict = knn->apply_multiclass(features_test);
//![train_and_apply]
//![evaluate_accuracy]
auto eval = some<CMulticlassAccuracy>();
auto accuracy = eval->evaluate(labels_predict, labels_test);
//![evaluate_accuracy]
// additional integration testing variables
auto output = labels_predict->get_labels();
// Serialize output for integration testing (automatically generated)
auto __sg_storage = some<CWrappedObjectArray>();
auto __sg_storage_file = some<CSerializableAsciiFile>("knn.dat", 119);
__sg_storage->append_wrapped(k, "k");
__sg_storage->append_wrapped(accuracy, "accuracy");
__sg_storage->append_wrapped(output, "output");
__sg_storage->save_serializable(__sg_storage_file);
exit_shogun();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment