Skip to content

Instantly share code, notes, and snippets.

@tomhicks
Created February 20, 2020 17:50
Show Gist options
  • Save tomhicks/bd2f7a7fe90744e7e0183632ae8ef58f to your computer and use it in GitHub Desktop.
Save tomhicks/bd2f7a7fe90744e7e0183632ae8ef58f to your computer and use it in GitHub Desktop.
Node 12 node-svm fix
#include "node-svm.h"
#include "training-worker.h"
#include "prediction-worker.h"
#include "probability-prediction-worker.h"
using v8::FunctionTemplate;
using v8::Object;
using v8::String;
using v8::Array;
// JS constructor function for NodeSvm. Set in NodeSvm::Init and read by
// NodeSvm::New to turn a plain call `NodeSvm()` into a construct call.
Nan::Persistent<Function> NodeSvm::constructor;
// Destructor. The body is empty; any cleanup is left to the Nan::ObjectWrap
// base class.
// NOTE(review): params, model and trainingProblem are not released here, so
// they leak when the wrapping JS object is collected. They come from different
// allocators (new[] in setParameters/setModel/setSvmProblem, libsvm in
// svm_train/svm_load_model, and params may alias &model->param), so freeing
// them needs an explicit ownership scheme rather than a plain delete here.
NodeSvm::~NodeSvm() = default;
/**
 * JS constructor for `NodeSvm`.
 *
 * When invoked with `new`, wraps a fresh native NodeSvm in `this` and returns
 * it. When invoked as a plain function, re-dispatches as a construct call
 * through the stored `constructor`, so `NodeSvm()` behaves like `new NodeSvm()`.
 * No arguments are forwarded in either case.
 */
NAN_METHOD(NodeSvm::New) {
  Nan::HandleScope scope;
  if (info.IsConstructCall()) {
    // Invoked as constructor: `new NodeSvm(...)`
    NodeSvm* obj = new NodeSvm();
    obj->Wrap(info.This());
    info.GetReturnValue().Set(info.This());
    return;
  }
  // Invoked as plain function `NodeSvm(...)`: turn it into a construct call.
  // argc is 0, so pass a null argv. A zero-length array is not valid ISO C++,
  // and the previous Windows path passed an uninitialized pointer.
  Local<Function> cons = Nan::New<Function>(constructor);
  info.GetReturnValue().Set(Nan::NewInstance(cons, 0, nullptr).ToLocalChecked());
}
/**
 * JS: `svm.setParameters(params)`.
 *
 * Asserts that the first argument is an object and hands it to the instance
 * method NodeSvm::setParameters, which validates and stores the libsvm
 * parameters.
 */
NAN_METHOD(NodeSvm::SetParameters) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(info[0]->IsObject());
  self->setParameters(info[0].As<Object>());
}
/**
 * JS: `svm.train(dataset)`. Trains synchronously on the calling thread.
 *
 * Preconditions (asserted): parameters have been set, and `dataset` is an
 * Array of [inputs, target] pairs (validated in detail by setSvmProblem).
 */
NAN_METHOD(NodeSvm::Train) {
  Nan::HandleScope scope;
  NodeSvm *obj = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  // check obj
  assert(obj->hasParameters());
  // check params: the value is cast to Array and setSvmProblem reads its
  // Length(), so require a real Array, not just any object.
  assert(info[0]->IsArray());
  Local<Array> dataset = info[0].As<Array>();
  obj->setSvmProblem(dataset);
  obj->train();
}
/**
 * JS: `svm.getModel()`.
 *
 * Returns the current model serialized as a plain JS object (built by the
 * instance method NodeSvm::getModel). Asserts that a model is present.
 */
NAN_METHOD(NodeSvm::GetModel) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(self->isTrained());
  Local<Object> serialized = self->getModel();
  info.GetReturnValue().Set(serialized);
}
/**
 * JS: `svm.trainAsync(dataset, callback)`. Queues a TrainingWorker.
 *
 * Preconditions (asserted): parameters have been set, `dataset` is an Array
 * of [inputs, target] pairs, and `callback` is a function.
 */
NAN_METHOD(NodeSvm::TrainAsync) {
  Nan::HandleScope scope;
  NodeSvm *obj = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  // check obj
  assert(obj->hasParameters());
  // check params: the dataset is cast to Array below, so require a real Array.
  assert(info[0]->IsArray());
  assert(info[1]->IsFunction());
  Local<Array> dataset = info[0].As<Array>();
  Nan::Callback *callback = new Nan::Callback(info[1].As<Function>());
  // Nan::AsyncQueueWorker takes ownership of the worker; presumably the worker
  // (a Nan::AsyncWorker) owns `callback` — confirm in training-worker.h.
  Nan::AsyncQueueWorker(new TrainingWorker(obj, dataset, callback));
}
/**
 * JS: `svm.getKernelType()`.
 *
 * Returns the libsvm kernel_type of the current parameters as a Number.
 * Asserts that parameters are set.
 */
NAN_METHOD(NodeSvm::GetKernelType) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(self->hasParameters());
  const int kernelType = self->getKernelType();
  info.GetReturnValue().Set(Nan::New<Number>(kernelType));
}
/**
 * JS: `svm.getSvmType()`.
 *
 * Returns the libsvm svm_type of the current parameters as a Number.
 * Asserts that parameters are set.
 */
NAN_METHOD(NodeSvm::GetSvmType) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(self->hasParameters());
  const int svmType = self->getSvmType();
  info.GetReturnValue().Set(Nan::New<Number>(svmType));
}
/**
 * JS: `svm.isTrained()`.
 *
 * Returns true when a model is present (trained, loaded or set from JS).
 */
NAN_METHOD(NodeSvm::IsTrained) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  const bool trained = self->isTrained();
  info.GetReturnValue().Set(Nan::New<Boolean>(trained));
}
/**
 * JS: `svm.getLabels()`.
 *
 * Returns an Array with one entry per class of the current model, each from
 * NodeSvm::getLabel. Asserts that a model is present.
 */
NAN_METHOD(NodeSvm::GetLabels) {
  Nan::HandleScope scope;
  NodeSvm *obj = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  // check obj
  assert(obj->isTrained());
  const int nbClasses = obj->getClassNumber();
  Local<Array> labels = Nan::New<Array>(nbClasses);
  for (int j = 0; j < nbClasses; j++){
    // Nan::Set maps to the context-taking Object::Set; the context-free
    // overload is deprecated in Node 12's V8.
    Nan::Set(labels, j, Nan::New<Number>(obj->getLabel(j)));
  }
  info.GetReturnValue().Set(labels);
}
/**
 * JS: `svm.loadFromModel(model)`.
 *
 * Restores a model from the plain object produced by `getModel()`.
 * Preconditions (asserted): exactly one argument, and it is an object.
 */
NAN_METHOD(NodeSvm::SetModel) {
  Nan::HandleScope scope;
  NodeSvm *obj = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(info.Length() == 1);
  assert(info[0]->IsObject());
  // The serialized model is a plain object (see NodeSvm::getModel), not an
  // Array, so cast to Object. setModel takes a Local<Object>.
  Local<Object> model = info[0].As<Object>();
  obj->setModel(model);
}
/**
 * JS: `svm.predict(inputs)`. Predicts a single sample synchronously.
 *
 * Preconditions (asserted): a model is present, and `inputs` is a non-empty
 * Array. Returns the value from NodeSvm::predict as a Number.
 */
NAN_METHOD(NodeSvm::Predict) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(self->isTrained());
  assert(info[0]->IsObject());
  Local<Array> sample = info[0].As<Array>();
  assert(sample->IsArray());
  assert(sample->Length() > 0);

  // One node per feature plus the index = -1 terminator written by getSvmNodes.
  svm_node *nodes = new svm_node[sample->Length() + 1];
  self->getSvmNodes(sample, nodes);
  const double result = self->predict(nodes);
  delete[] nodes;

  info.GetReturnValue().Set(Nan::New<Number>(result));
}
/**
 * JS: `svm.predictAsync(inputs, callback)`. Queues a PredictionWorker for one
 * sample.
 *
 * Preconditions (asserted): a model is present, `inputs` is a non-empty
 * Array, and `callback` is a function.
 */
NAN_METHOD(NodeSvm::PredictAsync) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(self->isTrained());
  assert(info[0]->IsObject());
  Local<Array> sample = info[0].As<Array>();
  assert(sample->IsArray());
  assert(sample->Length() > 0);
  assert(info[1]->IsFunction());

  Nan::Callback *done = new Nan::Callback(info[1].As<Function>());
  PredictionWorker *worker = new PredictionWorker(self, sample, done);
  Nan::AsyncQueueWorker(worker);
}
/**
 * JS: `svm.predictProbabilities(inputs)`.
 *
 * Computes per-class probability estimates for one sample synchronously.
 * Preconditions (asserted): a model is present, and `inputs` is a non-empty
 * Array. Returns an Array with getClassNumber() entries.
 */
NAN_METHOD(NodeSvm::PredictProbabilities) {
  Nan::HandleScope scope;
  NodeSvm *obj = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  // check obj
  assert(obj->isTrained());
  // check params
  assert(info[0]->IsObject());
  Local<Array> inputs = info[0].As<Array>();
  assert(inputs->IsArray());
  assert(inputs->Length() > 0);

  // One node per feature plus the index = -1 terminator written by getSvmNodes.
  svm_node *x = new svm_node[inputs->Length() + 1];
  obj->getSvmNodes(inputs, x);

  const int nbClass = obj->getClassNumber();
  // Value-initialized (all 0.0). Per the libsvm docs, svm_predict_probability
  // does not fill the estimates when the model lacks probability information,
  // and these values are returned to JS, so they must not be left uninitialized.
  double *prob_estimates = new double[nbClass]();
  obj->predictProbabilities(x, prob_estimates);

  // Create the result array
  Local<Array> probs = Nan::New<Array>(nbClass);
  for (int j = 0; j < nbClass; j++){
    Nan::Set(probs, j, Nan::New<Number>(prob_estimates[j]));
  }
  delete[] prob_estimates;
  delete[] x;
  info.GetReturnValue().Set(probs);
}
/**
 * JS: `svm.predictProbabilitiesAsync(inputs, callback)`. Queues a
 * ProbabilityPredictionWorker for one sample.
 *
 * Preconditions (asserted): a model is present, `inputs` is a non-empty
 * Array, and `callback` is a function.
 */
NAN_METHOD(NodeSvm::PredictProbabilitiesAsync) {
  Nan::HandleScope scope;
  NodeSvm *self = Nan::ObjectWrap::Unwrap<NodeSvm>(info.This());
  assert(self->isTrained());
  assert(info[0]->IsObject());
  Local<Array> sample = info[0].As<Array>();
  assert(sample->IsArray());
  assert(sample->Length() > 0);
  assert(info[1]->IsFunction());

  Nan::Callback *done = new Nan::Callback(info[1].As<Function>());
  ProbabilityPredictionWorker *worker = new ProbabilityPredictionWorker(self, sample, done);
  Nan::AsyncQueueWorker(worker);
}
/**
 * Registers the `NodeSvm` class on the module's exports.
 *
 * Builds the constructor template with one internal field for the wrapped
 * native object and attaches the prototype methods. It then stores the
 * constructor in `NodeSvm::constructor`, which New uses when called without
 * `new`, and exports it as `exports.NodeSvm`.
 */
void NodeSvm::Init(Local<Object> exports){
  // Prepare constructor template
  Local<FunctionTemplate> tpl = Nan::New<FunctionTemplate>(NodeSvm::New);
  tpl->SetClassName(Nan::New<String>("NodeSvm").ToLocalChecked());
  tpl->InstanceTemplate()->SetInternalFieldCount(1);

  // Prototype methods (JS name -> native callback). Nan::SetPrototypeMethod
  // also attaches a receiver signature, so calling a method with a `this` that
  // is not a NodeSvm throws a TypeError instead of reaching ObjectWrap::Unwrap
  // with an unwrapped object.
  Nan::SetPrototypeMethod(tpl, "setParameters", NodeSvm::SetParameters);
  Nan::SetPrototypeMethod(tpl, "train", NodeSvm::Train);
  Nan::SetPrototypeMethod(tpl, "trainAsync", NodeSvm::TrainAsync);
  Nan::SetPrototypeMethod(tpl, "isTrained", NodeSvm::IsTrained);
  Nan::SetPrototypeMethod(tpl, "getLabels", NodeSvm::GetLabels);
  Nan::SetPrototypeMethod(tpl, "getSvmType", NodeSvm::GetSvmType);
  Nan::SetPrototypeMethod(tpl, "getKernelType", NodeSvm::GetKernelType);
  Nan::SetPrototypeMethod(tpl, "predict", NodeSvm::Predict);
  Nan::SetPrototypeMethod(tpl, "predictAsync", NodeSvm::PredictAsync);
  Nan::SetPrototypeMethod(tpl, "predictProbabilities", NodeSvm::PredictProbabilities);
  Nan::SetPrototypeMethod(tpl, "predictProbabilitiesAsync", NodeSvm::PredictProbabilitiesAsync);
  Nan::SetPrototypeMethod(tpl, "loadFromModel", NodeSvm::SetModel);
  Nan::SetPrototypeMethod(tpl, "getModel", NodeSvm::GetModel);

  // Materialize the constructor once, keep it for NodeSvm::New, and export it.
  // Nan::Set uses the context-taking Object::Set; exports->Set(key, value) is
  // deprecated in Node 12's V8.
  Local<Function> cons = Nan::GetFunction(tpl).ToLocalChecked();
  constructor.Reset(cons);
  Nan::Set(exports, Nan::New<String>("NodeSvm").ToLocalChecked(), cons);
}
#ifndef _NODE_SVM_H
#define _NODE_SVM_H
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <node.h>
#include <assert.h>
#include <nan.h>
#include "../libsvm/svm.h"
using namespace v8;
/**
 * Native state behind the JS `NodeSvm` object.
 *
 * Holds the libsvm training parameters, the training problem built from a JS
 * dataset, and the trained, loaded or JS-provided model. The static
 * NAN_METHODs are the JS-facing entry points. The instance methods convert
 * between V8 values and libsvm structs.
 *
 * JS property and element access here uses Nan::Get / Nan::Set / Nan::To
 * instead of the context-free v8::Object::Get(index),
 * v8::Object::Set(key, value) and v8::Value::BooleanValue(context) overloads.
 * Those overloads are deprecated in the V8 shipped with Node 12 and removed in
 * later V8 releases.
 *
 * NOTE(review): setParameters, setModel, setSvmProblem, train and
 * loadModelFromFile overwrite params/model/trainingProblem without freeing the
 * previous values, and the destructor releases none of them.
 */
class NodeSvm : public Nan::ObjectWrap
{
public:
  static void Init(Local<Object> exports);
  static NAN_METHOD(SetParameters);
  static NAN_METHOD(Train);
  static NAN_METHOD(TrainAsync);
  static NAN_METHOD(IsTrained);
  static NAN_METHOD(GetLabels);
  static NAN_METHOD(GetKernelType);
  static NAN_METHOD(GetSvmType);
  static NAN_METHOD(Predict);
  static NAN_METHOD(PredictAsync);
  static NAN_METHOD(PredictProbabilities);
  static NAN_METHOD(PredictProbabilitiesAsync);
  static NAN_METHOD(SaveToFile);
  static NAN_METHOD(LoadFromFile);
  static NAN_METHOD(SetModel);
  static NAN_METHOD(GetModel);
  static NAN_METHOD(New);

  /// True once a model is present (trained, loaded from file or set from JS).
  bool isTrained(){ return model != nullptr; }

  /// True once training parameters are available.
  bool hasParameters(){ return params != nullptr; }

  /// True for every svm_type except the regression types NU_SVR and
  /// EPSILON_SVR. False when no parameters are set.
  bool isClassificationSVM(){
    if (!hasParameters()){
      return false;
    }
    int svm_type = params->svm_type;
    if (svm_type == NU_SVR || svm_type == EPSILON_SVR){
      return false;
    }
    return true;
  }

  /// Negation of isClassificationSVM(). Note this is also true when no
  /// parameters are set.
  bool isRegressionSVM(){ return !isClassificationSVM(); }

  /// libsvm svm_type of the current parameters (requires hasParameters()).
  int getSvmType(){ return params->svm_type; }

  /// libsvm kernel_type of the current parameters (requires hasParameters()).
  int getKernelType(){ return params->kernel_type; }

  /// Number of classes of the current model, or 0 when there is no model.
  int getClassNumber(){
    if (model == nullptr){
      return 0;
    }
    return model->nr_class;
  }

  /// Label of class `index` (no bounds check). Returns 0.0 when there is no
  /// model or the model has no label array. setModel leaves it null when
  /// "labels" is absent, and per the libsvm docs regression models have no
  /// labels either, so it must not be dereferenced blindly.
  double getLabel(int index){
    if (model == nullptr || model->label == nullptr){
      return 0.0;
    }
    return model->label[index];
  }

  /// Writes the current model with libsvm's svm_save_model and returns its
  /// status code.
  int saveModel(const char *fileName){
    return svm_save_model(fileName, model);
  }

  /// Loads a model with svm_load_model (asserting success) and points params
  /// at the model's embedded parameters, so params aliases the model here.
  void loadModelFromFile(const char *fileName){
    model = svm_load_model(fileName);
    assert(model != nullptr);
    params = &model->param;
    assert(params != nullptr);
  }

  /// Validates the JS parameter object and stores a newly allocated
  /// svm_parameter in `params`.
  ///
  /// Required keys depend on the types:
  /// - svmType, plus c (C_SVC/EPSILON_SVR/NU_SVR), nu (NU_SVC/NU_SVR/ONE_CLASS)
  ///   and epsilon (EPSILON_SVR).
  /// - kernelType, plus degree (POLY), gamma (POLY/RBF/SIGMOID) and
  ///   r (POLY/SIGMOID).
  ///
  /// Optional keys: cacheSize (default 100), eps (default 1e-3),
  /// shrinking (default true) and probability (default false).
  /// Fields not read stay zero from value-initialization.
  void setParameters(Local<Object> obj){
    struct svm_parameter *svm_params = new svm_parameter();
    svm_params->nr_weight = 0;
    svm_params->weight_label = nullptr;
    svm_params->weight = nullptr;
    Local<Context> context = Nan::GetCurrentContext();

    // check classifier and its options
    Local<String> svm_type_name = Nan::New<String>("svmType").ToLocalChecked();
    assert(Nan::Has(obj, svm_type_name).FromJust());
    svm_params->svm_type = Nan::Get(obj, svm_type_name).ToLocalChecked()->IntegerValue(context).FromJust();
    assert(svm_params->svm_type == C_SVC ||
           svm_params->svm_type == NU_SVC ||
           svm_params->svm_type == ONE_CLASS ||
           svm_params->svm_type == EPSILON_SVR ||
           svm_params->svm_type == NU_SVR);
    if (svm_params->svm_type == C_SVC ||
        svm_params->svm_type == EPSILON_SVR ||
        svm_params->svm_type == NU_SVR){
      Local<String> str_c = Nan::New<String>("c").ToLocalChecked();
      assert(Nan::Has(obj, str_c).FromJust());
      svm_params->C = Nan::Get(obj, str_c).ToLocalChecked()->NumberValue(context).FromJust();
    }
    if (svm_params->svm_type == NU_SVC ||
        svm_params->svm_type == NU_SVR ||
        svm_params->svm_type == ONE_CLASS){
      Local<String> str_nu = Nan::New<String>("nu").ToLocalChecked();
      assert(Nan::Has(obj, str_nu).FromJust());
      svm_params->nu = Nan::Get(obj, str_nu).ToLocalChecked()->NumberValue(context).FromJust();
      assert(svm_params->nu > 0 &&
             svm_params->nu <= 1);
    }
    if (svm_params->svm_type == EPSILON_SVR){
      Local<String> str_epsilon = Nan::New<String>("epsilon").ToLocalChecked();
      assert(Nan::Has(obj, str_epsilon).FromJust());
      svm_params->p = Nan::Get(obj, str_epsilon).ToLocalChecked()->NumberValue(context).FromJust();
      assert(svm_params->p >= 0);
    }

    // check kernel and its options
    Local<String> str_kernel_type = Nan::New<String>("kernelType").ToLocalChecked();
    assert(Nan::Has(obj, str_kernel_type).FromJust());
    svm_params->kernel_type = Nan::Get(obj, str_kernel_type).ToLocalChecked()->IntegerValue(context).FromJust();
    assert(svm_params->kernel_type == LINEAR ||
           svm_params->kernel_type == POLY ||
           svm_params->kernel_type == RBF ||
           // svm_params->kernel_type == PRECOMPUTED || // not supported (yet)
           svm_params->kernel_type == SIGMOID);
    if (svm_params->kernel_type == POLY){
      Local<String> str_degree = Nan::New<String>("degree").ToLocalChecked();
      assert(Nan::Has(obj, str_degree).FromJust());
      svm_params->degree = Nan::Get(obj, str_degree).ToLocalChecked()->IntegerValue(context).FromJust();
      assert(svm_params->degree >= 0);
    }
    if (svm_params->kernel_type == POLY ||
        svm_params->kernel_type == RBF ||
        svm_params->kernel_type == SIGMOID){
      Local<String> str_gamma = Nan::New<String>("gamma").ToLocalChecked();
      assert(Nan::Has(obj, str_gamma).FromJust());
      svm_params->gamma = Nan::Get(obj, str_gamma).ToLocalChecked()->NumberValue(context).FromJust();
      assert(svm_params->gamma >= 0);
    }
    if (svm_params->kernel_type == POLY ||
        svm_params->kernel_type == SIGMOID){
      Local<String> str_r = Nan::New<String>("r").ToLocalChecked();
      assert(Nan::Has(obj, str_r).FromJust());
      svm_params->coef0 = Nan::Get(obj, str_r).ToLocalChecked()->NumberValue(context).FromJust();
    }

    // check training options
    Local<String> str_cache_size = Nan::New<String>("cacheSize").ToLocalChecked();
    svm_params->cache_size = Nan::Has(obj, str_cache_size).FromJust() ?
      Nan::Get(obj, str_cache_size).ToLocalChecked()->NumberValue(context).FromJust() :
      100;
    assert(svm_params->cache_size > 0);

    Local<String> str_eps = Nan::New<String>("eps").ToLocalChecked();
    svm_params->eps = Nan::Has(obj, str_eps).FromJust() ?
      Nan::Get(obj, str_eps).ToLocalChecked()->NumberValue(context).FromJust() :
      1e-3;
    assert(svm_params->eps > 0);

    // Nan::To<bool> replaces Value::BooleanValue(context), which is
    // deprecated in Node 12's V8.
    Local<String> str_shrinking = Nan::New<String>("shrinking").ToLocalChecked();
    svm_params->shrinking = // enabled by default
      (Nan::Has(obj, str_shrinking).FromJust() &&
       !Nan::To<bool>(Nan::Get(obj, str_shrinking).ToLocalChecked()).FromJust()) ? 0 : 1;

    Local<String> str_probability = Nan::New<String>("probability").ToLocalChecked();
    svm_params->probability = // disabled by default
      (Nan::Has(obj, str_probability).FromJust() &&
       Nan::To<bool>(Nan::Get(obj, str_probability).ToLocalChecked()).FromJust()) ? 1 : 0;
    if (svm_params->svm_type == ONE_CLASS){
      assert(svm_params->probability == 0); // one-class SVM probability output not supported (yet)
    }
    params = svm_params;
  }

  /// Rebuilds a libsvm model from the plain object produced by getModel():
  /// params, l, nrClass, rho, optional labels (+ nbSupportVectors),
  /// optional probA / probB, and supportVectors as [inputs, coefficients]
  /// pairs. Also replaces `params` via setParameters.
  void setModel(Local<Object> obj){
    Local<String> str_params = Nan::New<String>("params").ToLocalChecked();
    assert(Nan::Has(obj, str_params).FromJust());
    Local<Value> js_params = Nan::Get(obj, str_params).ToLocalChecked();
    assert(js_params->IsObject());
    setParameters(js_params.As<Object>());
    assert(params != nullptr);

    struct svm_model *new_model = new svm_model();
    // NOTE(review): with free_sv = 1, libsvm's model-freeing code assumes its
    // own allocation layout for SV. These rows are allocated with new[] per
    // row, so this model must not be handed to svm_free_and_destroy_model.
    // Confirm against svm.cpp.
    new_model->free_sv = 1; // XXX
    new_model->rho = nullptr;
    new_model->probA = nullptr;
    new_model->probB = nullptr;
    new_model->sv_indices = nullptr;
    new_model->label = nullptr;
    new_model->nSV = nullptr;

    Local<String> str_l = Nan::New<String>("l").ToLocalChecked();
    assert(Nan::Has(obj, str_l).FromJust());
    Local<Value> js_l = Nan::Get(obj, str_l).ToLocalChecked();
    assert(js_l->IsInt32());
    new_model->l = js_l->IntegerValue(Nan::GetCurrentContext()).FromJust();

    // nrClass sizes every array below, so validate it like "l".
    Local<String> str_nr_class = Nan::New<String>("nrClass").ToLocalChecked();
    assert(Nan::Has(obj, str_nr_class).FromJust());
    Local<Value> js_nr_class = Nan::Get(obj, str_nr_class).ToLocalChecked();
    assert(js_nr_class->IsInt32());
    new_model->nr_class = js_nr_class->IntegerValue(Nan::GetCurrentContext()).FromJust();
    // one entry per pair of classes
    unsigned int n = new_model->nr_class * (new_model->nr_class - 1) / 2;

    // rho
    Local<String> str_rho = Nan::New<String>("rho").ToLocalChecked();
    assert(Nan::Has(obj, str_rho).FromJust());
    Local<Value> js_rho = Nan::Get(obj, str_rho).ToLocalChecked();
    assert(js_rho->IsArray());
    Local<Array> rho = js_rho.As<Array>();
    assert(rho->Length() == n);
    new_model->rho = readNumberArray(rho, n);

    // classes
    Local<String> str_labels = Nan::New<String>("labels").ToLocalChecked();
    if (Nan::Has(obj, str_labels).FromJust()){
      Local<Value> js_labels = Nan::Get(obj, str_labels).ToLocalChecked();
      assert(js_labels->IsArray());
      // The length of "labels" is not checked against nrClass. Missing
      // elements read as undefined and trip the per-element IsInt32 assert.
      new_model->label = readInt32Array(js_labels.As<Array>(), new_model->nr_class);

      // nSV
      Local<String> str_nb_support_vectors = Nan::New<String>("nbSupportVectors").ToLocalChecked();
      assert(Nan::Has(obj, str_nb_support_vectors).FromJust());
      Local<Value> js_nb_sv = Nan::Get(obj, str_nb_support_vectors).ToLocalChecked();
      assert(js_nb_sv->IsArray());
      Local<Array> nbSupportVectors = js_nb_sv.As<Array>();
      assert((int)nbSupportVectors->Length() == new_model->nr_class);
      new_model->nSV = readInt32Array(nbSupportVectors, new_model->nr_class);
    }

    // probA
    Local<String> str_prob_a = Nan::New<String>("probA").ToLocalChecked();
    if (Nan::Has(obj, str_prob_a).FromJust()){
      Local<Value> js_prob_a = Nan::Get(obj, str_prob_a).ToLocalChecked();
      assert(js_prob_a->IsArray());
      Local<Array> probA = js_prob_a.As<Array>();
      assert(probA->Length() == n);
      new_model->probA = readNumberArray(probA, n);
    }

    // probB
    Local<String> str_prob_b = Nan::New<String>("probB").ToLocalChecked();
    if (Nan::Has(obj, str_prob_b).FromJust()){
      Local<Value> js_prob_b = Nan::Get(obj, str_prob_b).ToLocalChecked();
      assert(js_prob_b->IsArray());
      Local<Array> probB = js_prob_b.As<Array>();
      assert(probB->Length() == n);
      new_model->probB = readNumberArray(probB, n);
    }

    // SV: each entry is [inputs, coefficients]; inputs are dense and become
    // svm_node rows with 1-based indices and a -1 terminator.
    Local<String> str_support_vectors = Nan::New<String>("supportVectors").ToLocalChecked();
    assert(Nan::Has(obj, str_support_vectors).FromJust());
    Local<Value> js_support_vectors = Nan::Get(obj, str_support_vectors).ToLocalChecked();
    assert(js_support_vectors->IsArray());
    Local<Array> supportVectors = js_support_vectors.As<Array>();
    assert((int)supportVectors->Length() == new_model->l);
    int m = new_model->nr_class - 1;
    int l = new_model->l;
    new_model->sv_coef = new double *[m];
    for (int i = 0; i < m; i++)
      new_model->sv_coef[i] = new double[l];
    new_model->SV = new svm_node*[l];
    for (int i = 0; i < l; i++) {
      Local<Value> js_ex = Nan::Get(supportVectors, i).ToLocalChecked();
      assert(js_ex->IsArray());
      Local<Array> ex = js_ex.As<Array>();
      assert(ex->Length() == 2);
      Local<Value> js_x = Nan::Get(ex, 0).ToLocalChecked();
      Local<Value> js_y = Nan::Get(ex, 1).ToLocalChecked();
      assert(js_x->IsArray());
      assert(js_y->IsArray());
      Local<Array> x = js_x.As<Array>();
      Local<Array> y = js_y.As<Array>();
      new_model->SV[i] = new svm_node[x->Length() + 1];
      for (unsigned j = 0; j < x->Length(); ++j) {
        new_model->SV[i][j].index = j + 1;
        new_model->SV[i][j].value = Nan::Get(x, j).ToLocalChecked()->NumberValue(Nan::GetCurrentContext()).FromJust();
      }
      new_model->SV[i][x->Length()].index = -1;
      for (int j = 0; j < m; j++)
        new_model->sv_coef[j][i] = Nan::Get(y, j).ToLocalChecked()->NumberValue(Nan::GetCurrentContext()).FromJust();
    }
    model = new_model;
    model->param = *params;
    assert(model != nullptr);
  }

  /// Builds the libsvm training problem from `dataset`, an Array of
  /// [inputs, target] pairs in which every inputs Array has the same length.
  /// Feature j of a sample becomes svm_node {index: j + 1, value}, and each
  /// row ends with an index = -1 terminator.
  void setSvmProblem(Local<Array> dataset){
    Nan::HandleScope scope;
    struct svm_problem *prob = new svm_problem();
    prob->l = 0;
    assert(dataset->Length() > 0);

    // First pass: check data structure and feature count; allocate Y.
    prob->l = dataset->Length();
    prob->y = new double[prob->l];
    int nb_features = -1;
    for (unsigned i = 0; i < dataset->Length(); i++) {
      Local<Value> t = Nan::Get(dataset, i).ToLocalChecked();
      assert(t->IsArray());
      Local<Array> ex = t.As<Array>();
      assert(ex->Length() == 2);
      Local<Value> tin = Nan::Get(ex, 0).ToLocalChecked();
      Local<Value> tout = Nan::Get(ex, 1).ToLocalChecked();
      assert(tin->IsArray());
      assert(tout->IsNumber());
      Local<Array> x = tin.As<Array>();
      if (nb_features == -1){
        nb_features = x->Length();
      }
      else {
        assert(nb_features == (int)x->Length()); // Incorrect dataset: all inputs must have the same length
      }
    }

    // Second pass: assign X and Y.
    prob->x = new svm_node*[dataset->Length()];
    for (unsigned i = 0; i < dataset->Length(); i++) {
      prob->x[i] = new svm_node[nb_features + 1];
      Local<Array> ex = Nan::Get(dataset, i).ToLocalChecked().As<Array>();
      Local<Array> x = Nan::Get(ex, 0).ToLocalChecked().As<Array>();
      for (unsigned j = 0; j < x->Length(); ++j) {
        prob->x[i][j].index = j + 1;
        prob->x[i][j].value = Nan::Get(x, j).ToLocalChecked()->NumberValue(Nan::GetCurrentContext()).FromJust();
      }
      prob->x[i][x->Length()].index = -1;
      prob->y[i] = Nan::Get(ex, 1).ToLocalChecked()->NumberValue(Nan::GetCurrentContext()).FromJust();
    }
    trainingProblem = prob;
  }

  /// Trains a new model on trainingProblem with the current params.
  void train(){
    model = svm_train(trainingProblem, params);
  }

  /// Predicts one sample; `x` must be terminated by index = -1.
  double predict(svm_node *x){
    return svm_predict(model, x);
  }

  /// Writes per-class probability estimates for sample `x` into
  /// `prob_estimates` (getClassNumber() entries). Per the libsvm docs, the
  /// buffer is only filled when the model has probability information.
  void predictProbabilities(svm_node *x, double* prob_estimates){
    svm_predict_probability(model, x, prob_estimates);
  }

  /// Converts a JS Array of numbers into svm_nodes with 1-based indices.
  /// `nodes` must have room for inputs->Length() + 1 entries; the last one
  /// gets the index = -1 terminator.
  void getSvmNodes(Local<Array> inputs, svm_node *nodes){
    for (unsigned j = 0; j < inputs->Length(); j++){
      nodes[j].index = j + 1;
      nodes[j].value = Nan::Get(inputs, j).ToLocalChecked()->NumberValue(Nan::GetCurrentContext()).FromJust();
    }
    nodes[inputs->Length()].index = -1;
  }

  /// Serializes the current model (requires isTrained()) into a plain JS
  /// object with the keys setModel reads: nrClass, l, supportVectors, the
  /// optional nbSupportVectors / labels / probA / probB / rho, and params.
  Local<Object> getModel(){
    Local<Object> obj = Nan::New<Object>();
    Nan::Set(obj, Nan::New<String>("nrClass").ToLocalChecked(), Nan::New<Number>(model->nr_class));
    Nan::Set(obj, Nan::New<String>("l").ToLocalChecked(), Nan::New<Number>(model->l));

    // Support vectors: each entry is [inputs, outputs]. The sparse svm_node
    // row is expanded to a dense array up to its largest index, with missing
    // indices written as 0.
    Local<Array> supportVectors = Nan::New<Array>(model->l);
    const double * const *sv_coef = model->sv_coef;
    const svm_node * const *SV = model->SV;
    const int nb_outputs = model->nr_class - 1;
    for (int i = 0; i < model->l; i++){
      Local<Array> outputs = Nan::New<Array>(nb_outputs);
      for (int j = 0; j < nb_outputs; j++)
        Nan::Set(outputs, j, Nan::New<Number>(sv_coef[j][i]));

      const svm_node *p = SV[i];
      int max_index = 0;
      int nb_index = 0;
      while (p->index != -1){
        nb_index++;
        if (p->index > max_index)
          max_index = p->index;
        p++;
      }
      Local<Array> inputs = Nan::New<Array>(nb_index);
      int p_i = 0;
      for (int k = 0; k < max_index; k++){
        if (k + 1 == SV[i][p_i].index){
          Nan::Set(inputs, k, Nan::New<Number>(SV[i][p_i].value));
          p_i++;
        }
        else {
          Nan::Set(inputs, k, Nan::New<Number>(0));
        }
      }
      Local<Array> example = Nan::New<Array>(2);
      Nan::Set(example, 0, inputs);
      Nan::Set(example, 1, outputs);
      Nan::Set(supportVectors, i, example);
    }
    Nan::Set(obj, Nan::New<String>("supportVectors").ToLocalChecked(), supportVectors);

    if (model->nSV) {
      Nan::Set(obj, Nan::New<String>("nbSupportVectors").ToLocalChecked(),
               toJsArray(model->nSV, model->nr_class));
    }
    if (model->label) {
      Nan::Set(obj, Nan::New<String>("labels").ToLocalChecked(),
               toJsArray(model->label, model->nr_class));
    }
    // probA / probB / rho hold one entry per pair of classes.
    const int nb_pairs = model->nr_class * (model->nr_class - 1) / 2;
    if (model->probA) { // regression has probA only
      Nan::Set(obj, Nan::New<String>("probA").ToLocalChecked(),
               toJsArray(model->probA, nb_pairs));
    }
    if (model->probB) {
      Nan::Set(obj, Nan::New<String>("probB").ToLocalChecked(),
               toJsArray(model->probB, nb_pairs));
    }
    if (model->rho) {
      Nan::Set(obj, Nan::New<String>("rho").ToLocalChecked(),
               toJsArray(model->rho, nb_pairs));
    }

    // Parameters, mirroring the keys setParameters reads.
    Local<Object> parameters = Nan::New<Object>();
    Nan::Set(parameters, Nan::New<String>("svmType").ToLocalChecked(), Nan::New<Number>(model->param.svm_type));
    if (model->param.svm_type == C_SVC ||
        model->param.svm_type == EPSILON_SVR ||
        model->param.svm_type == NU_SVR){
      Nan::Set(parameters, Nan::New<String>("c").ToLocalChecked(), Nan::New<Number>(model->param.C));
    }
    if (model->param.svm_type == NU_SVC ||
        model->param.svm_type == NU_SVR ||
        model->param.svm_type == ONE_CLASS){
      Nan::Set(parameters, Nan::New<String>("nu").ToLocalChecked(), Nan::New<Number>(model->param.nu));
    }
    if (model->param.svm_type == EPSILON_SVR){
      Nan::Set(parameters, Nan::New<String>("epsilon").ToLocalChecked(), Nan::New<Number>(model->param.p));
    }
    Nan::Set(parameters, Nan::New<String>("kernelType").ToLocalChecked(), Nan::New<Number>(model->param.kernel_type));
    if (model->param.kernel_type == POLY){
      Nan::Set(parameters, Nan::New<String>("degree").ToLocalChecked(), Nan::New<Number>(model->param.degree)); /* for poly */
    }
    if (model->param.kernel_type == POLY ||
        model->param.kernel_type == RBF ||
        model->param.kernel_type == SIGMOID){
      Nan::Set(parameters, Nan::New<String>("gamma").ToLocalChecked(), Nan::New<Number>(model->param.gamma));
    }
    if (model->param.kernel_type == POLY ||
        model->param.kernel_type == SIGMOID){
      Nan::Set(parameters, Nan::New<String>("r").ToLocalChecked(), Nan::New<Number>(model->param.coef0));
    }
    // Class weights (nr_weight / weight_label / weight) are not serialized.
    Nan::Set(parameters, Nan::New<String>("cacheSize").ToLocalChecked(), Nan::New<Number>(model->param.cache_size));
    Nan::Set(parameters, Nan::New<String>("eps").ToLocalChecked(), Nan::New<Number>(model->param.eps));
    if (model->param.shrinking == 1){
      Nan::Set(parameters, Nan::New<String>("shrinking").ToLocalChecked(), Nan::True());
    }
    else {
      Nan::Set(parameters, Nan::New<String>("shrinking").ToLocalChecked(), Nan::False());
    }
    if (model->param.probability == 1){
      Nan::Set(parameters, Nan::New<String>("probability").ToLocalChecked(), Nan::True());
    }
    else {
      Nan::Set(parameters, Nan::New<String>("probability").ToLocalChecked(), Nan::False());
    }
    Nan::Set(obj, Nan::New<String>("params").ToLocalChecked(), parameters);
    return obj;
  }

private:
  ~NodeSvm();

  /// Copies the first `count` elements of `arr` into a new[]-allocated array
  /// owned by the caller, asserting that each element is a Number.
  static double *readNumberArray(Local<Array> arr, unsigned int count){
    double *out = new double[count];
    for (unsigned int i = 0; i < count; i++){
      Local<Value> elt = Nan::Get(arr, i).ToLocalChecked();
      assert(elt->IsNumber());
      out[i] = elt->NumberValue(Nan::GetCurrentContext()).FromJust();
    }
    return out;
  }

  /// Copies the first `count` elements of `arr` into a new[]-allocated array
  /// owned by the caller, asserting that each element is an int32.
  static int *readInt32Array(Local<Array> arr, int count){
    int *out = new int[count];
    for (int i = 0; i < count; i++){
      Local<Value> elt = Nan::Get(arr, i).ToLocalChecked();
      assert(elt->IsInt32());
      out[i] = elt->IntegerValue(Nan::GetCurrentContext()).FromJust();
    }
    return out;
  }

  /// Builds a JS Array of Numbers from the first `count` entries of `values`.
  template <typename T>
  static Local<Array> toJsArray(const T *values, int count){
    Local<Array> arr = Nan::New<Array>(count);
    for (int i = 0; i < count; i++){
      Nan::Set(arr, i, Nan::New<Number>(values[i]));
    }
    return arr;
  }

  // Explicitly null until set; isTrained()/hasParameters() depend on it.
  struct svm_parameter *params = nullptr;
  struct svm_model *model = nullptr;
  struct svm_problem *trainingProblem = nullptr;
  static Nan::Persistent<Function> constructor;
};
#endif /* _NODE_SVM_H */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment