Skip to content

Instantly share code, notes, and snippets.

@geektoni
Last active July 28, 2017 15:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save geektoni/6b3bd3aafe70fbe477db485faa9cfe74 to your computer and use it in GitHub Desktop.
Save geektoni/6b3bd3aafe70fbe477db485faa9cfe74 to your computer and use it in GitHub Desktop.
#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/kernel/LinearKernel.h>
#include <shogun/regression/KernelRidgeRegression.h>
#include <shogun/evaluation/CrossValidation.h>
#include <shogun/evaluation/CrossValidationSplitting.h>
#include <shogun/evaluation/MeanSquaredError.h>
#include <shogun/lib/parameter_observers/ParameterObserverCV.h>
using namespace shogun;
/**
 * Example: observe the internals of a cross-validation run with
 * ParameterObserverCV.
 *
 * Builds a 1-D noisy linear regression data set, sets up kernel ridge
 * regression with a linear kernel, runs 100 repetitions of 5-fold
 * cross-validation scored by mean squared error, and prints the training
 * indices used by every fold of every run, as recorded by the observer.
 */
int main(int argc, char ** argv)
{
    init_shogun_with_defaults();

    // data matrix dimensions
    index_t num_vectors=100;
    index_t num_features=1;

    // training label data
    SGVector<float64_t> lab(num_vectors);

    // fill data matrix and labels
    SGMatrix<float64_t> train_dat(num_features, num_vectors);
    // Fill every matrix entry (num_features*num_vectors values), not just the
    // first num_vectors, so the data stays fully initialised if
    // num_features is increased.
    SGVector<float64_t>::range_fill_vector(train_dat.matrix,
        num_features*num_vectors);
    for (index_t i=0; i<num_vectors; ++i)
    {
        // labels are linear plus noise
        lab.vector[i]=i+CMath::normal_random(0, 1.0);
    }

    // training features (kept alive by our own reference until the end)
    CDenseFeatures<float64_t>* features = new CDenseFeatures<float64_t>(train_dat);
    SG_REF(features);

    // training labels
    CRegressionLabels* labels=new CRegressionLabels(lab);

    // kernel
    CLinearKernel* kernel=new CLinearKernel();
    kernel->init(features, features);

    // kernel ridge regression
    float64_t tau=0.0001;
    CKernelRidgeRegression* krr=new CKernelRidgeRegression(tau, kernel, labels);

    // evaluation criterion
    CMeanSquaredError* eval_crit= new CMeanSquaredError();

    // train once on the full data set; cross-validation below retrains the
    // machine on each fold
    krr->train(features);

    // splitting strategy
    index_t n_folds=5;
    CCrossValidationSplitting* splitting=
        new CCrossValidationSplitting(labels, n_folds);

    // cross validation instance, 100 runs
    CCrossValidation* cross=new CCrossValidation(krr, features, labels,
        splitting, eval_crit);
    cross->set_num_runs(100);

    // Inner scope: the stack-allocated observer is destroyed here, after
    // `cross` has been released and before exit_shogun() tears down the
    // library.
    {
        // Create the parameter observer
        // By setting false, we disable the observer verbosity
        ParameterObserverCV par {false};
        cross->subscribe_to_parameters(&par);

        // Run the cross-validation: this is what produces the observations.
        // Without this call the observer never receives anything.
        auto* result = cross->evaluate();
        SG_UNREF(result);

        // We get all the observations caught
        auto obs = par.get_observations();
        for (auto o : obs)
        {
            // For each of the observations folds we print the
            // train indices used.
            for (auto fold : o->get_folds_result())
                fold->get_train_indices().display_vector("Train indices ");
        }

        // clean up (releases krr, labels, splitting and eval_crit through
        // their reference counts)
        SG_UNREF(cross);
    }

    SG_UNREF(features);
    exit_shogun();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment