Last active
December 20, 2015 09:19
-
-
Save sonney2k/6106780 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <shogun/base/init.h>
#include <shogun/classifier/svm/LibLinear.h>
#include <shogun/classifier/svm/SVMLight.h>
#include <shogun/classifier/svm/SVMOcas.h>
#include <shogun/evaluation/PRCEvaluation.h>
#include <shogun/evaluation/ROCEvaluation.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/features/HashedDocDotFeatures.h>
#include <shogun/io/LineReader.h>
#include <shogun/kernel/string/CommUlongStringKernel.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/lib/DelimiterTokenizer.h>
#include <shogun/lib/NGramTokenizer.h>
#include <shogun/preprocessor/SortUlongString.h>

#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>

using namespace shogun;
/**
 * Load the first `number_of_examples` labels of the webspam training set.
 *
 * The label file is read as a sequence of whitespace-separated integers,
 * one per document, in document order.
 *
 * A file that cannot be opened, or that ends (or fails to parse) before
 * `number_of_examples` values were read, is reported through SG_SERROR
 * instead of silently producing uninitialised labels.
 * NOTE(review): SG_SERROR is expected to abort the call by raising a Shogun
 * error - confirm against the Shogun version in use.
 *
 * @param number_of_examples number of labels to read from the start of the file
 * @return newly allocated binary labels; the caller is responsible for them
 */
CBinaryLabels* load_labels(int32_t number_of_examples)
{
	const char* labels_filepath = "/home/user/data/webspam_train.lab";

	std::ifstream lin(labels_filepath);
	if (!lin.is_open())
		SG_SERROR("Could not open label file %s\n", labels_filepath);

	SGVector<int32_t> l_vector(number_of_examples);
	for (index_t doc_id=0; doc_id<number_of_examples; doc_id++)
	{
		// A failed extraction would otherwise leave this entry uninitialised.
		if (!(lin >> l_vector.vector[doc_id]))
		{
			SG_SERROR("Could not read label %d from %s\n",
					doc_id, labels_filepath);
		}
	}

	// The stream is closed by its destructor when it goes out of scope.
	return new CBinaryLabels(l_vector);
}
/**
 * Load the labels of the held-out test documents.
 *
 * The test set is carved out of the training label file: the first 200000
 * entries are skipped and the following 500 are returned. These two numbers
 * mirror the ones used when loading the test documents themselves.
 *
 * A file that cannot be opened, or that runs out of parsable integers while
 * skipping or reading, is reported through SG_SERROR instead of silently
 * producing uninitialised labels.
 * NOTE(review): SG_SERROR is expected to abort the call by raising a Shogun
 * error - confirm against the Shogun version in use.
 *
 * @return newly allocated binary labels; the caller is responsible for them
 */
CBinaryLabels* load_test_labels()
{
	const index_t num_skipped = 200000;
	const index_t num_test = 500;
	const char* labels_filepath = "/home/user/data/webspam_train.lab";

	std::ifstream lin(labels_filepath);
	if (!lin.is_open())
		SG_SERROR("Could not open label file %s\n", labels_filepath);

	// Consume the labels that belong to the training portion of the file.
	int32_t ignored_label;
	for (index_t i=0; i<num_skipped; i++)
	{
		if (!(lin >> ignored_label))
		{
			SG_SERROR("Label file %s ended after %d entries while skipping\n",
					labels_filepath, i);
		}
	}

	SGVector<int32_t> l_vector(num_test);
	for (index_t doc_id=0; doc_id<num_test; doc_id++)
	{
		if (!(lin >> l_vector.vector[doc_id]))
		{
			SG_SERROR("Could not read test label %d from %s\n",
					doc_id, labels_filepath);
		}
	}

	// The stream is closed by its destructor when it goes out of scope.
	return new CBinaryLabels(l_vector);
}
/**
 * Load the held-out test documents as raw-byte string features.
 *
 * Documents in the data file are separated by '\0'. The first 200000
 * documents are skipped and up to 500 of the following ones are copied into
 * the returned feature object.
 *
 * If the data file cannot be opened this is reported through SG_SERROR.
 * NOTE(review): SG_SERROR is expected to abort the call by raising a Shogun
 * error - confirm against the Shogun version in use.
 *
 * If the file holds fewer documents than requested, the returned features
 * contain only the documents that were actually read (a notice is printed),
 * so no uninitialised string entries are ever exposed.
 *
 * @return newly allocated string features; the caller is responsible for them
 */
CStringFeatures<char>* load_test_data()
{
	const index_t num_skipped = 200000;
	const index_t num_test = 500;
	const char* train_filepath = "/home/user/data/webspam_train.dat";
	// Upper bound on the size of a single document, also used as reader buffer size.
	const int32_t doc_size = 5 * 1024 * 1024;

	FILE* fin=fopen(train_filepath, "rb");
	if (fin==NULL)
		SG_SERROR("Could not open data file %s\n", train_filepath);

	CDelimiterTokenizer* tzer = new CDelimiterTokenizer();
	tzer->delimiters['\0'] = 1;
	CLineReader* reader=new CLineReader(doc_size, fin, tzer);

	// Discard the documents that belong to the training portion of the file.
	for (index_t i=0; i<num_skipped && reader->has_next(); i++)
	{
		if (i%10000==0)
			SG_SPRINT("Skipped %d\n", i);
		reader->read_line();
	}

	SGString<char>* strings = SG_MALLOC(SGString<char>, num_test);
	index_t num_read=0;
	while (num_read<num_test && reader->has_next())
	{
		SGVector<char> doc = reader->read_line();
		strings[num_read] = SGString<char>(doc.vlen);
		std::memcpy(strings[num_read].string, doc.vector, doc.vlen);
		num_read++;
	}

	fclose(fin);
	SG_UNREF(tzer);
	SG_UNREF(reader);

	if (num_read<num_test)
	{
		SG_SPRINT("Only %d of %d test documents could be read from %s\n",
				num_read, num_test, train_filepath);
	}

	// Report the number of documents actually read: entries past num_read
	// were never initialised and must not be handed to the feature object.
	SGStringList<char> str_list(strings, num_read, doc_size);
	return new CStringFeatures<char>(str_list, RAWBYTE);
}
/**
 * Load the first `number_of_examples` training documents as raw-byte
 * string features.
 *
 * Documents in the data file are separated by '\0' and are read from the
 * beginning of the file.
 *
 * If the data file cannot be opened this is reported through SG_SERROR.
 * NOTE(review): SG_SERROR is expected to abort the call by raising a Shogun
 * error - confirm against the Shogun version in use.
 *
 * If the file holds fewer documents than requested, the returned features
 * contain only the documents that were actually read (a notice is printed),
 * so no uninitialised string entries are ever exposed.
 *
 * @param number_of_examples maximum number of documents to load
 * @return newly allocated string features; the caller is responsible for them
 */
CStringFeatures<char>* load_data(int32_t number_of_examples)
{
	const char* train_filepath = "/home/user/data/webspam_train.dat";
	// Upper bound on the size of a single document, also used as reader buffer size.
	const int32_t doc_size = 5 * 1024 * 1024;

	FILE* fin=fopen(train_filepath, "rb");
	if (fin==NULL)
		SG_SERROR("Could not open data file %s\n", train_filepath);

	CDelimiterTokenizer* tzer = new CDelimiterTokenizer();
	tzer->delimiters['\0'] = 1;
	CLineReader* reader=new CLineReader(doc_size, fin, tzer);

	SGString<char>* strings = SG_MALLOC(SGString<char>, number_of_examples);
	index_t num_read=0;
	while (num_read<number_of_examples && reader->has_next())
	{
		SGVector<char> doc = reader->read_line();
		strings[num_read] = SGString<char>(doc.vlen);
		std::memcpy(strings[num_read].string, doc.vector, doc.vlen);
		num_read++;
	}

	fclose(fin);
	SG_UNREF(tzer);
	SG_UNREF(reader);

	if (num_read<number_of_examples)
	{
		SG_SPRINT("Only %d of %d documents could be read from %s\n",
				num_read, number_of_examples, train_filepath);
	}

	// Report the number of documents actually read: entries past num_read
	// were never initialised and must not be handed to the feature object.
	SGStringList<char> str_list(strings, num_read, doc_size);
	return new CStringFeatures<char>(str_list, RAWBYTE);
}
int main(int argv, char** argc) | |
{ | |
init_shogun_with_defaults(); | |
int32_t number_of_examples[] = {50, 10000, 50000, 100000}; | |
CNGramTokenizer* tzer = new CNGramTokenizer(8); | |
SG_SPRINT("Loading test data\n"); | |
CStringFeatures<char>* test_string_feats = load_test_data(); | |
SG_SPRINT("Test data loaded\n"); | |
CStringFeatures<uint64_t>* test_feats = new CStringFeatures<uint64_t>(test_string_feats->get_alphabet()); | |
test_feats->obtain_from_char(test_string_feats, 8-1, 8, 0, false); | |
CSortUlongString* preproc = new CSortUlongString(); | |
preproc->init(test_feats); | |
test_feats->add_preprocessor(preproc); | |
test_feats->apply_preprocessor(); // don't do on-the-fly preprocessing: in-place preprocess the data now | |
CBinaryLabels* test_labels = load_test_labels(); | |
for (index_t i=0; i<4; i++) | |
{ | |
int32_t n = number_of_examples[i]; | |
SG_SPRINT("Reading data...\n"); | |
CStringFeatures<char>* string_feats = load_data(n); | |
CBinaryLabels* labels = load_labels(n); | |
SG_SPRINT("Data loaded.\n"); | |
CStringFeatures<uint64_t>* feats = new CStringFeatures<uint64_t>(string_feats->get_alphabet()); | |
feats->obtain_from_char(string_feats, 8-1, 8, 0, false); | |
CSortUlongString* preproc2 = new CSortUlongString(); | |
preproc2->init(feats); | |
feats->add_preprocessor(preproc2); | |
feats->apply_preprocessor(); | |
CCommUlongStringKernel* kernel = new CCommUlongStringKernel(feats,feats,false); | |
CSVM* svm=new CSVMLight(1, kernel, labels); | |
svm->set_epsilon(0.01); | |
clock_t t = clock(); | |
SG_REF(labels); | |
SG_SPRINT("Start training\n"); | |
if (svm->train()) | |
{ | |
CBinaryLabels* out_labels = CLabelsFactory::to_binary(svm->apply()); | |
CROCEvaluation* eval = new CROCEvaluation(); | |
float64_t auPRC = eval->evaluate(out_labels, labels); | |
SG_SPRINT("training auROC = %f", auPRC); | |
//SG_UNREF(out_labels); | |
CPRCEvaluation* eval2 = new CPRCEvaluation(); | |
auPRC = eval2->evaluate(out_labels, labels); | |
SG_SPRINT(" ----- training auPRC = %f\n", auPRC); | |
SG_UNREF(out_labels); | |
SG_REF(tzer); | |
SG_REF(test_string_feats); | |
out_labels = CLabelsFactory::to_binary(svm->apply(test_feats)); | |
auPRC = eval->evaluate(out_labels, test_labels); | |
SG_SPRINT("test auROC = %f", auPRC); | |
SG_UNREF(eval); | |
//SG_UNREF(test_feats); | |
//SG_UNREF(out_labels); | |
auPRC = eval2->evaluate(out_labels, test_labels); | |
SG_SPRINT(" ----- test auPRC = %f\n", auPRC); | |
SG_UNREF(out_labels); | |
} | |
else | |
SG_SERROR("Training machine was unsuccesful..\n"); | |
SG_UNREF(svm); | |
t = clock() - t; | |
SG_SPRINT("It took %d clicks or %f seconds\n\n",t,((float)t)/CLOCKS_PER_SEC); | |
//} | |
} | |
exit_shogun(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment