Skip to content

Instantly share code, notes, and snippets.

@sonney2k
Last active December 20, 2015 09:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sonney2k/6106780 to your computer and use it in GitHub Desktop.
Save sonney2k/6106780 to your computer and use it in GitHub Desktop.
#include <shogun/base/init.h>
#include <shogun/classifier/svm/SVMOcas.h>
#include <shogun/features/HashedDocDotFeatures.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/lib/NGramTokenizer.h>
#include <shogun/classifier/svm/SVMLight.h>
#include <shogun/lib/DelimiterTokenizer.h>
#include <shogun/io/LineReader.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/classifier/svm/LibLinear.h>
#include <shogun/preprocessor/SortUlongString.h>
#include <shogun/kernel/string/CommUlongStringKernel.h>
#include <shogun/evaluation/PRCEvaluation.h>
#include <shogun/evaluation/ROCEvaluation.h>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <ctime>
using namespace shogun;
/**
 * Reads the first number_of_examples integer labels from the training label
 * file and wraps them in a CBinaryLabels object.
 *
 * A label file that cannot be opened, or that yields fewer labels than
 * requested, is reported through SG_SERROR instead of silently producing
 * uninitialised label values.
 *
 * @param number_of_examples number of labels to read from the start of the file
 * @return newly allocated labels object holding the labels that were read
 */
CBinaryLabels* load_labels(int32_t number_of_examples)
{
    const char* labels_filepath = "/home/user/data/webspam_train.lab";
    std::ifstream lin(labels_filepath);
    if (!lin.is_open())
    {
        SG_SERROR("Could not open label file %s\n", labels_filepath);
        // NOTE(review): presumably SG_SERROR raises and this is never reached - confirm
        return NULL;
    }

    SGVector<int32_t> l_vector(number_of_examples);
    for (index_t doc_id=0; doc_id<number_of_examples; doc_id++)
    {
        // a failed extraction would leave the slot untouched, so stop right away
        if (!(lin >> l_vector.vector[doc_id]))
        {
            SG_SERROR("Label file %s provided only %d of %d labels\n",
                    labels_filepath, doc_id, number_of_examples);
            return NULL;
        }
    }

    // the stream is closed by its destructor when lin goes out of scope
    return new CBinaryLabels(l_vector);
}
/**
 * Loads the labels of the held-out test documents.
 *
 * The test set is taken from the training label file: the first 200000 labels
 * are skipped and the following 500 are returned. A file that cannot be opened
 * or that runs out of labels is reported through SG_SERROR instead of silently
 * producing uninitialised label values.
 *
 * @return newly allocated labels object holding the 500 test labels
 */
CBinaryLabels* load_test_labels()
{
    const index_t num_skipped = 200000;
    const index_t num_test = 500;
    const char* labels_filepath = "/home/user/data/webspam_train.lab";

    std::ifstream lin(labels_filepath);
    if (!lin.is_open())
    {
        SG_SERROR("Could not open label file %s\n", labels_filepath);
        // NOTE(review): presumably SG_SERROR raises and this is never reached - confirm
        return NULL;
    }

    // consume the labels that belong to the training portion of the file
    int32_t skipped_label;
    for (index_t i=0; i<num_skipped; i++)
    {
        if (!(lin >> skipped_label))
        {
            SG_SERROR("Label file %s ended after %d labels while skipping %d\n",
                    labels_filepath, i, num_skipped);
            return NULL;
        }
    }

    SGVector<int32_t> l_vector(num_test);
    for (index_t doc_id=0; doc_id<num_test; doc_id++)
    {
        // a failed extraction would leave the slot untouched, so stop right away
        if (!(lin >> l_vector.vector[doc_id]))
        {
            SG_SERROR("Label file %s provided only %d of %d test labels\n",
                    labels_filepath, doc_id, num_test);
            return NULL;
        }
    }

    // the stream is closed by its destructor when lin goes out of scope
    return new CBinaryLabels(l_vector);
}
/**
 * Loads the held-out test documents as raw-byte string features.
 *
 * Documents in the data file are separated by '\0'. The first 200000 documents
 * are skipped and up to 500 of the following ones are copied into the result.
 * If the file holds fewer documents, only the documents actually read are put
 * into the returned features (and a message is printed), so no uninitialised
 * string slot is ever handed to CStringFeatures.
 *
 * @return newly allocated string features with alphabet RAWBYTE
 */
CStringFeatures<char>* load_test_data()
{
    const index_t num_skipped = 200000;
    const index_t num_test = 500;
    const char* train_filepath = "/home/user/data/webspam_train.dat";
    // upper bound on the length of one document, handed to the reader and the string list
    const int32_t doc_size = 5 * 1024 * 1024;

    // open the file before allocating anything so a failure leaks nothing
    FILE* fin=fopen(train_filepath, "rb");
    if (fin==NULL)
    {
        SG_SERROR("Could not open data file %s\n", train_filepath);
        // NOTE(review): presumably SG_SERROR raises and this is never reached - confirm
        return NULL;
    }

    CDelimiterTokenizer* tzer = new CDelimiterTokenizer();
    tzer->delimiters['\0'] = 1;
    CLineReader* reader=new CLineReader(doc_size, fin, tzer);

    // skip the documents that belong to the training portion of the file
    for (index_t i=0; i<num_skipped && reader->has_next(); i++)
    {
        if (i%10000==0)
            SG_SPRINT("Skipped %d\n", i);
        reader->read_line();
    }

    SGString<char>* strings = SG_MALLOC(SGString<char>, num_test);
    index_t num_read = 0;
    while (num_read<num_test && reader->has_next())
    {
        SGVector<char> tmp_string = reader->read_line();
        strings[num_read] = SGString<char>(tmp_string.vlen);
        if (tmp_string.vlen>0)
            memcpy(strings[num_read].string, tmp_string.vector, tmp_string.vlen);
        num_read++;
    }

    fclose(fin);
    SG_UNREF(tzer);
    SG_UNREF(reader);

    if (num_read<num_test)
        SG_SPRINT("Only %d of %d test documents could be read from %s\n",
                num_read, num_test, train_filepath);

    // report only the slots that were filled; the rest of the array is uninitialised
    SGStringList<char> str_list(strings, num_read, doc_size);
    return new CStringFeatures<char>(str_list, RAWBYTE);
}
/**
 * Loads the first number_of_examples training documents as raw-byte string
 * features.
 *
 * Documents in the data file are separated by '\0'. If the file holds fewer
 * documents than requested, only the documents actually read are put into the
 * returned features (and a message is printed), so no uninitialised string
 * slot is ever handed to CStringFeatures.
 *
 * @param number_of_examples maximum number of documents to read
 * @return newly allocated string features with alphabet RAWBYTE
 */
CStringFeatures<char>* load_data(int32_t number_of_examples)
{
    const char* train_filepath = "/home/user/data/webspam_train.dat";
    // upper bound on the length of one document, handed to the reader and the string list
    const int32_t doc_size = 5 * 1024 * 1024;

    // open the file before allocating anything so a failure leaks nothing
    FILE* fin=fopen(train_filepath, "rb");
    if (fin==NULL)
    {
        SG_SERROR("Could not open data file %s\n", train_filepath);
        // NOTE(review): presumably SG_SERROR raises and this is never reached - confirm
        return NULL;
    }

    CDelimiterTokenizer* tzer = new CDelimiterTokenizer();
    tzer->delimiters['\0'] = 1;
    CLineReader* reader=new CLineReader(doc_size, fin, tzer);

    SGString<char>* strings = SG_MALLOC(SGString<char>, number_of_examples);
    index_t num_read = 0;
    while (num_read<number_of_examples && reader->has_next())
    {
        SGVector<char> tmp_string = reader->read_line();
        strings[num_read] = SGString<char>(tmp_string.vlen);
        if (tmp_string.vlen>0)
            memcpy(strings[num_read].string, tmp_string.vector, tmp_string.vlen);
        num_read++;
    }

    fclose(fin);
    SG_UNREF(tzer);
    SG_UNREF(reader);

    if (num_read<number_of_examples)
        SG_SPRINT("Only %d of %d documents could be read from %s\n",
                num_read, number_of_examples, train_filepath);

    // report only the slots that were filled; the rest of the array is uninitialised
    SGStringList<char> str_list(strings, num_read, doc_size);
    return new CStringFeatures<char>(str_list, RAWBYTE);
}
int main(int argv, char** argc)
{
init_shogun_with_defaults();
int32_t number_of_examples[] = {50, 10000, 50000, 100000};
CNGramTokenizer* tzer = new CNGramTokenizer(8);
SG_SPRINT("Loading test data\n");
CStringFeatures<char>* test_string_feats = load_test_data();
SG_SPRINT("Test data loaded\n");
CStringFeatures<uint64_t>* test_feats = new CStringFeatures<uint64_t>(test_string_feats->get_alphabet());
test_feats->obtain_from_char(test_string_feats, 8-1, 8, 0, false);
CSortUlongString* preproc = new CSortUlongString();
preproc->init(test_feats);
test_feats->add_preprocessor(preproc);
test_feats->apply_preprocessor(); // don't do on-the-fly preprocessing: in-place preprocess the data now
CBinaryLabels* test_labels = load_test_labels();
for (index_t i=0; i<4; i++)
{
int32_t n = number_of_examples[i];
SG_SPRINT("Reading data...\n");
CStringFeatures<char>* string_feats = load_data(n);
CBinaryLabels* labels = load_labels(n);
SG_SPRINT("Data loaded.\n");
CStringFeatures<uint64_t>* feats = new CStringFeatures<uint64_t>(string_feats->get_alphabet());
feats->obtain_from_char(string_feats, 8-1, 8, 0, false);
CSortUlongString* preproc2 = new CSortUlongString();
preproc2->init(feats);
feats->add_preprocessor(preproc2);
feats->apply_preprocessor();
CCommUlongStringKernel* kernel = new CCommUlongStringKernel(feats,feats,false);
CSVM* svm=new CSVMLight(1, kernel, labels);
svm->set_epsilon(0.01);
clock_t t = clock();
SG_REF(labels);
SG_SPRINT("Start training\n");
if (svm->train())
{
CBinaryLabels* out_labels = CLabelsFactory::to_binary(svm->apply());
CROCEvaluation* eval = new CROCEvaluation();
float64_t auPRC = eval->evaluate(out_labels, labels);
SG_SPRINT("training auROC = %f", auPRC);
//SG_UNREF(out_labels);
CPRCEvaluation* eval2 = new CPRCEvaluation();
auPRC = eval2->evaluate(out_labels, labels);
SG_SPRINT(" ----- training auPRC = %f\n", auPRC);
SG_UNREF(out_labels);
SG_REF(tzer);
SG_REF(test_string_feats);
out_labels = CLabelsFactory::to_binary(svm->apply(test_feats));
auPRC = eval->evaluate(out_labels, test_labels);
SG_SPRINT("test auROC = %f", auPRC);
SG_UNREF(eval);
//SG_UNREF(test_feats);
//SG_UNREF(out_labels);
auPRC = eval2->evaluate(out_labels, test_labels);
SG_SPRINT(" ----- test auPRC = %f\n", auPRC);
SG_UNREF(out_labels);
}
else
SG_SERROR("Training machine was unsuccesful..\n");
SG_UNREF(svm);
t = clock() - t;
SG_SPRINT("It took %d clicks or %f seconds\n\n",t,((float)t)/CLOCKS_PER_SEC);
//}
}
exit_shogun();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment