Skip to content

Instantly share code, notes, and snippets.

@soulslicer
Last active February 20, 2017 06:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soulslicer/b19458b8e6285c4c2e933df0c9376f08 to your computer and use it in GitHub Desktop.
Save soulslicer/b19458b8e6285c4c2e933df0c9376f08 to your computer and use it in GitHub Desktop.
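Patch for LIBSVM that adds a CMakeLists.txt build, adds OpenMP parallelism to svm.cpp, adjusts python/svm.py to find the installed shared library, and adds svm_load_model_ex / svm_predict_probability_ex entry points.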
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e69de29..4600a6b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -0,0 +1,49 @@
+cmake_minimum_required(VERSION 2.8.3)
+project(libsvm)
+
+# An empty CMAKE_BUILD_TYPE compiles without optimization flags; default
+# single-configuration generators to an optimized Release build.
+if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
+  set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
+endif()
+
+# OpenMP is optional: the OpenMP changes to svm.cpp are "#pragma omp"
+# directives, which compilers ignore when the OpenMP flags are absent.
+find_package(OpenMP)
+if(OPENMP_FOUND)
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+
+# Explicit source list: file(GLOB) silently skips a missing file instead of
+# failing at configure time.
+set(libsvm_SRC
+  svm.h
+  svm.cpp
+)
+
+add_library(svm SHARED ${libsvm_SRC})
+
+# SOVERSION 2 produces libsvm.so.2 on Linux, the file name python/svm.py loads.
+set(LIBSVM_SOVERSION 2)
+set_target_properties(svm PROPERTIES
+  VERSION ${LIBSVM_SOVERSION}
+  SOVERSION ${LIBSVM_SOVERSION}
+)
+
+# Python bindings destination, relative to CMAKE_INSTALL_PREFIX. With the
+# default, svm.py's '../../libsvm.so.2' lookup resolves to lib/libsvm.so.2.
+set(LIBSVM_PYTHON_INSTALL_DIR "lib/python2.7/dist-packages" CACHE STRING
+  "Install directory for svm.py and svmutil.py, relative to CMAKE_INSTALL_PREFIX")
+
+install(FILES python/svmutil.py python/svm.py DESTINATION ${LIBSVM_PYTHON_INSTALL_DIR})
+install(FILES svm.h DESTINATION include)
+# RUNTIME places the DLL on Windows; LIBRARY/ARCHIVE keep the Unix layout (lib/).
+install(TARGETS svm
+  RUNTIME DESTINATION bin
+  LIBRARY DESTINATION lib
+  ARCHIVE DESTINATION lib
+)
+
diff --git a/python/svm.py b/python/svm.py
index 577160d..f2d1677 100644
--- a/python/svm.py
+++ b/python/svm.py
@@ -19,7 +19,15 @@ try:
if sys.platform == 'win32':
libsvm = CDLL(path.join(dirname, r'..\windows\libsvm.dll'))
else:
- libsvm = CDLL(path.join(dirname, '../libsvm.so.2'))
+		# Prefer the in-tree build (python/../libsvm.so.2) so a freshly built
+		# library is not shadowed by an installed copy, then the CMake install
+		# layout (<prefix>/lib/python2.7/dist-packages/../../libsvm.so.2). If
+		# neither exists, CDLL raises and the except clause below falls back.
+		for _libsvm_path in ('../libsvm.so.2', '../../libsvm.so.2'):
+			_libsvm_path = path.join(dirname, _libsvm_path)
+			if path.exists(_libsvm_path):
+				break
+		libsvm = CDLL(_libsvm_path)
except:
# For unix the prefix 'lib' is not considered.
if find_library('svm'):
diff --git a/svm.cpp b/svm.cpp
index 2bfae57..ebf5326 100644
--- a/svm.cpp
+++ b/svm.cpp
@@ -1282,6 +1282,7 @@ public:
int start, j;
if((start = cache->get_data(i,&data,len)) < len)
{
+ #pragma omp parallel for private(j) schedule(guided) // iterations are independent (each writes only data[j]); NOTE(review): kernel_function now runs concurrently, so it must only read shared state - confirm for every kernel type
for(j=start;j<len;j++)
data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
}
@@ -2507,6 +2508,7 @@ double svm_predict_values(const svm_model *model, const svm_node *x, double* dec
{
double *sv_coef = model->sv_coef[0];
double sum = 0;
+ #pragma omp parallel for private(i) reduction(+:sum) schedule(guided) // per-thread partial sums are combined into sum; the floating-point addition order (so the low-order bits of sum) may differ between runs
for(i=0;i<model->l;i++)
sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
sum -= model->rho[0];
@@ -2523,6 +2525,7 @@ double svm_predict_values(const svm_model *model, const svm_node *x, double* dec
int l = model->l;
double *kvalue = Malloc(double,l);
+ #pragma omp parallel for private(i) schedule(guided) // iterations are independent: each writes only kvalue[i]
for(i=0;i<l;i++)
kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
@@ -2634,6 +2637,17 @@ double svm_predict_probability(
return svm_predict(model, x);
}
+/*
+ * Exported through svm.h next to svm_predict_probability(). Delegates to it
+ * instead of carrying a second copy of the probability-estimation code, so the
+ * two cannot drift apart. Same arguments and return value as that function.
+ */
+double svm_predict_probability_ex(
+	const svm_model *model, const svm_node *x, double *prob_estimates)
+{
+	return svm_predict_probability(model, x, prob_estimates);
+}
+
static const char *svm_type_table[] =
{
"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
@@ -2884,8 +2932,18 @@ bool read_model_header(FILE *fp, svm_model* model)
}
+/*
+ * Exported through svm.h next to svm_load_model(). Delegates to it instead of
+ * carrying a second copy of the model-file parser, so the two cannot drift
+ * apart. Returns NULL whenever svm_load_model() does (e.g. when the file cannot be opened).
+ */
+svm_model *svm_load_model_ex(const char *model_file_name)
+{
+	return svm_load_model(model_file_name);
+}
+
svm_model *svm_load_model(const char *model_file_name)
{
FILE *fp = fopen(model_file_name,"rb");
if(fp==NULL) return NULL;
diff --git a/svm.h b/svm.h
index 9251ea8..d4a38cc 100644
--- a/svm.h
+++ b/svm.h
@@ -76,6 +76,7 @@ void svm_cross_validation(const struct svm_problem *prob, const struct svm_param
int svm_save_model(const char *model_file_name, const struct svm_model *model);
struct svm_model *svm_load_model(const char *model_file_name);
+struct svm_model *svm_load_model_ex(const char *model_file_name); /* same signature as svm_load_model(); returns NULL if the model file cannot be opened */
int svm_get_svm_type(const struct svm_model *model);
int svm_get_nr_class(const struct svm_model *model);
@@ -87,6 +88,7 @@ double svm_get_svr_probability(const struct svm_model *model);
double svm_predict_values(const struct svm_model *model, const struct svm_node *x, double* dec_values);
double svm_predict(const struct svm_model *model, const struct svm_node *x);
double svm_predict_probability(const struct svm_model *model, const struct svm_node *x, double* prob_estimates);
+double svm_predict_probability_ex(const struct svm_model *model, const struct svm_node *x, double* prob_estimates); /* same signature as svm_predict_probability(); NOTE(review): document how this _ex variant differs, if it does */
void svm_free_model_content(struct svm_model *model_ptr);
void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment