Skip to content

Instantly share code, notes, and snippets.

@y-tag
Created March 23, 2013 22:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save y-tag/5229592 to your computer and use it in GitHub Desktop.
Save y-tag/5229592 to your computer and use it in GitHub Desktop.
#include <cfloat>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <fstream>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include <jubatus/client.hpp>
// g++ -O2 -o eval_svmdata eval_svmdata.cpp `pkg-config pficommon --libs --cflags` -lmsgpack -ljubatus_mpio -ljubatus_msgpack-rpc
using jubatus::classifier::datum;
using jubatus::classifier::estimate_result;
// Parses one line of LIBSVM/SVMlight-style data:
//   "<label> <name>:<value> <name>:<value> ... [# comment]"
//
// Tokens are separated by any run of spaces, tabs, '\r' or '\n', so repeated
// separators and CRLF line endings do not leak into labels or feature names.
// The first token is the label. Each later "name:value" token adds one
// (name, value) pair to d->num_values. The name is the text before the first
// ':', and the value is the rest of the token, parsed with strtod (unparsable
// or empty values become 0.0). Tokens without a ':' or with an empty name are
// skipped. A token starting with '#' ends the feature list (trailing comment).
//
// Unlike a strtok-based parser this never touches a NULL token and needs no
// variable-length stack buffer, so blank or very long lines are safe.
//
// @param line  one input line
// @param label output: the example's label (cleared first)
// @param d     output: the example's features; string_values and num_values
//              are cleared first, then num_values is filled
// @return 1 if a label was found, 0 if the line is blank
int parse_line(const std::string &line, std::string *label, datum *d) {
  static const char kSeparators[] = " \t\r\n";

  d->string_values.clear();
  d->num_values.clear();
  label->clear();

  bool have_label = false;
  std::string::size_type pos = 0;
  while (true) {
    const std::string::size_type begin = line.find_first_not_of(kSeparators, pos);
    if (begin == std::string::npos) {
      break;  // no more tokens
    }
    std::string::size_type end = line.find_first_of(kSeparators, begin);
    if (end == std::string::npos) {
      end = line.size();
    }
    pos = end;

    if (!have_label) {
      label->assign(line, begin, end - begin);
      have_label = true;
      continue;
    }
    if (line[begin] == '#') {
      break;  // SVMlight-style trailing comment
    }

    const std::string::size_type colon = line.find(':', begin);
    if (colon == std::string::npos || colon >= end || colon == begin) {
      continue;  // not a "name:value" token (or empty name): ignore it
    }
    // Parse only this token's value so whitespace after "name:" cannot pull
    // in the next token.
    const std::string value(line, colon + 1, end - colon - 1);
    d->num_values.push_back(
        std::make_pair(line.substr(begin, colon - begin),
                       std::strtod(value.c_str(), NULL)));
  }
  return have_label ? 1 : 0;
}
int main(int argc, char **argv) {
std::string host = "127.0.0.1";
int port = 9199;
std::string name = "test";
jubatus::classifier::client::classifier client(host, port, 10.0);
client.clear(name);
if (argc < 3) {
fprintf(stderr, "%s train_f test_f\n", argv[0]);
exit(1);
}
const char *train_f = argv[1];
const char *test_f = argv[2];
std::string buff;
std::vector<std::pair<std::string, datum> > train_data;
std::ifstream trfs;
fprintf(stderr, "train start...\n");
trfs.open(train_f);
while (getline(trfs, buff)) {
std::string label;
datum d;
parse_line(buff, &label, &d);
train_data.push_back(std::make_pair(label, d));
if (train_data.size() >= 1000) {
client.train(name, train_data);
train_data.clear();
}
}
if (train_data.size() > 0) {
client.train(name, train_data);
}
std::vector<datum> test_data;
std::ifstream tefs;
int num_correct = 0;
int num_all = 0;
fprintf(stderr, "test start...\n");
tefs.open(test_f);
while (getline(tefs, buff)) {
std::string label;
datum d;
parse_line(buff, &label, &d);
test_data.push_back(d);
std::vector<std::vector<estimate_result> > results = client.classify(name, test_data);
test_data.clear();
std::string predicted_label = "";
double max_score = -DBL_MAX;
for (size_t i = 0; i < results[0].size(); ++i) {
const estimate_result& r = results[0][i];
if (r.score > max_score) {
max_score = r.score;
predicted_label = r.label;
}
}
if (predicted_label == label) {
num_correct += 1;
}
num_all += 1;
}
fprintf(stdout, "%d\t%d\t%f\n", num_correct, num_all, static_cast<double>(num_correct) / num_all);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment