Skip to content

Instantly share code, notes, and snippets.

@c4goldsw
Created March 25, 2016 14:33
Show Gist options
  • Save c4goldsw/b16e03b4f67364a37a89 to your computer and use it in GitHub Desktop.
Save c4goldsw/b16e03b4f67364a37a89 to your computer and use it in GitHub Desktop.
#include <shogun/base/init.h>
#include <shogun/base/some.h>
#include <shogun/lib/common.h>
#include <shogun/evaluation/MeanSquaredError.h>
#include <shogun/io/CSVFile.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/lib/WrappedObjectArray.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/lib/SGVector.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/regression/LinearRidgeRegression.h>
using namespace shogun;
int main(int, char*[])
{
init_shogun_with_defaults();
auto f_feats_train = some<CCSVFile>("../../data/regression_1d_linear_features_train.dat");
auto f_feats_test = some<CCSVFile>("../../data/regression_1d_linear_features_test.dat");
auto f_labels_train = some<CCSVFile>("../../data/regression_1d_linear_labels_train.dat");
auto f_labels_test = some<CCSVFile>("../../data/regression_1d_linear_labels_test.dat");
//![create_features]
auto features_train = some<CDenseFeatures<float64_t>>(f_feats_train);
auto features_test = some<CDenseFeatures<float64_t>>(f_feats_test);
auto labels_train = some<CRegressionLabels>(f_labels_train);
auto labels_test = some<CRegressionLabels>(f_labels_test);
//![create_features]
//![create_instance]
auto tau = 0.001;
auto lrr = some<CLinearRidgeRegression>(tau, features_train, labels_train);
//![create_instance]
//![train_and_apply]
lrr->train();
auto labels_predict = lrr->apply_regression(features_test);
auto bias = lrr->get_bias();
//![train_and_apply]
//[!disable_bias]
lrr->set_compute_bias(false);
//[!disable_bias]
//[!set_bias_manually]
lrr->set_bias(bias);
//[!set_bias_manually]
//[!extract_w]
auto weights = lrr->get_w();
//[!extract_w]
//![evaluate_error]
auto eval = some<CMeanSquaredError>();
auto mse = eval->evaluate(labels_predict, labels_test);
//![evaluate_error]
// integration testing variables
auto output = labels_test->get_labels();
auto __sg_storage = some<CWrappedObjectArray>();
auto __sg_storage_file = some<CSerializableAsciiFile>("linear_ridge_regression.dat", 119);
__sg_storage->append_wrapped(tau, "tau");
__sg_storage->append_wrapped(bias, "bias");
__sg_storage->append_wrapped(weights, "weights");
__sg_storage->append_wrapped(mse, "mse");
__sg_storage->append_wrapped(output, "output");
__sg_storage->save_serializable(__sg_storage_file);
exit_shogun();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment