Created April 5, 2019 11:07
-
-
Save gf712/5bc697cdd19a32648fbdd12529673922 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <initializer_list>
#include <memory>

#include <shogun/base/init.h>
#include <shogun/base/some.h>
#include <shogun/features/Features.h>
#include <shogun/io/File.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/labels/Labels.h>
#include <shogun/lib/DynamicObjectArray.h>
#include <shogun/lib/SGVector.h>
#include <shogun/machine/Machine.h>
#include <shogun/modelselection/NewGridSearch.h>
#include <shogun/util/factory.h>

using namespace shogun;
int main(int, char*[]) | |
{ | |
init_shogun_with_defaults(); | |
auto f_feats_train = wrap(csv_file("../../data/classifier_binary_2d_nonlinear_features_train.dat")); | |
auto f_feats_test = wrap(csv_file("../../data/classifier_binary_2d_nonlinear_features_test.dat")); | |
auto f_labels_train = wrap(csv_file("../../data/classifier_binary_2d_nonlinear_labels_train.dat")); | |
auto f_labels_test = wrap(csv_file("../../data/classifier_binary_2d_nonlinear_labels_test.dat")); | |
//![create_features] | |
auto features_train = wrap(features(f_feats_train)); | |
auto features_test = wrap(features(f_feats_test)); | |
auto labels_train = wrap(labels(f_labels_train)); | |
auto labels_test = wrap(labels(f_labels_test)); | |
//![create_features] | |
//![create_kernel] | |
auto k1 = wrap(kernel("GaussianKernel")); | |
auto k2 = wrap(kernel("PolyKernel")); | |
//![create_kernel] | |
//![create_machine] | |
auto svm = wrap(machine("LibSVM")); | |
svm->put("epsilon", 0.001); | |
svm->put("labels", labels_train); | |
//![create_machine] | |
//![create_master_node] | |
auto node = std::make_shared<GridSearch>(svm); | |
auto c1_param = SGVector<float64_t>(3); | |
c1_param[0] = 0.1; | |
c1_param[1] = 1; | |
c1_param[2] = 10; | |
auto c2_param = SGVector<float64_t>(3); | |
c2_param[0] = 0.1; | |
c2_param[1] = 1; | |
c2_param[2] = 10; | |
node->attach("C1", c1_param); | |
node->attach("C2", c2_param); | |
//![create_master_node] | |
//![create_child_nodes] | |
auto k1_node = std::make_shared<GridParameters>(k1); | |
auto k2_node = std::make_shared<GridParameters>(k2); | |
//![create_child_nodes] | |
//![attach_child_nodes] | |
auto degree_param = SGVector<int32_t>(3); | |
degree_param[0] = 1; | |
degree_param[1] = 2; | |
degree_param[2] = 3; | |
k2_node->attach("degree", degree_param); | |
node->attach("kernel", k1_node); | |
auto width_param = SGVector<float64_t>(3); | |
width_param[0] = 0.1; | |
width_param[1] = 1; | |
width_param[2] = 2; | |
node->attach("kernel::log_width", width_param); | |
node->attach("kernel", k2_node); | |
//![attach_child_nodes] | |
//![train_and_apply] | |
node->train(features_train); | |
//![train_and_apply] | |
// Serialize output for integration testing (automatically generated) | |
auto sg_storage = some<CDynamicObjectArray>(); | |
auto sg_storage_file = some<CSerializableAsciiFile>("grid_search.dat", 'w'); | |
sg_storage->append_element(c1_param, "c1_param"); | |
sg_storage->append_element(c2_param, "c2_param"); | |
sg_storage->append_element(width_param, "width_param"); | |
sg_storage->save_serializable(sg_storage_file); | |
exit_shogun(); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.