Created
April 21, 2013 14:41
-
-
Save y-tag/5429829 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdio> | |
#include <cstdlib> | |
#include <cfloat> | |
#include <cstring> | |
#include <iostream> | |
#include <string> | |
#include <utility> | |
#include <vector> | |
#include <map> | |
#include <unordered_map> | |
#include <fstream> | |
#include <jubatus/client.hpp> | |
// g++ -O2 -o eval_classifier2 eval_classifier2.cpp `pkg-config pficommon --libs --cflags` -lmsgpack -ljubatus_mpio -ljubatus_msgpack-rpc -std=c++0x | |
using jubatus::classifier::datum; | |
using jubatus::classifier::estimate_result; | |
/**
 * Parses one whitespace-separated text line of the form
 *
 *     <relevance> <name>:<value> <name>:<value> ...
 *
 * The first field is converted with strtod and stored in *relevance.
 * Every following "name:value" field whose name starts with "qid" has its
 * value converted with strtol (base 10) and stored in *qid; every other
 * field is appended to d->num_values as (name, strtod(value)).
 *
 * d->string_values and d->num_values are always cleared first.
 *
 * Parsing stops at the first field that has no ':' or whose name is empty,
 * so trailing free text after the features is ignored.
 *
 * @param line       the input line (without the trailing newline)
 * @param relevance  receives the leading numeric field; untouched on failure
 * @param qid        receives the qid value; untouched when no qid field exists
 * @param d          receives the numeric features
 * @return 1 when a relevance field was found, 0 when the line is empty or
 *         contains only spaces/tabs
 */
int parse_line(const std::string &line, float *relevance, int *qid, datum *d) {
  d->string_values.clear();
  d->num_values.clear();

  static const char kBlanks[] = " \t";
  const std::string::size_type npos = std::string::npos;

  // Locate the leading relevance field. A blank line has none; bail out
  // instead of handing a null pointer to strtod.
  std::string::size_type begin = line.find_first_not_of(kBlanks);
  if (begin == npos) {
    return 0;
  }
  // When `end` is npos, substr clamps the length to the rest of the line.
  std::string::size_type end = line.find_first_of(kBlanks, begin);
  *relevance = static_cast<float>(
      strtod(line.substr(begin, end - begin).c_str(), nullptr));

  // Walk the remaining blank-separated fields.
  while (end != npos) {
    begin = line.find_first_not_of(kBlanks, end);
    if (begin == npos) {
      break;  // only trailing blanks were left
    }
    end = line.find_first_of(kBlanks, begin);

    const std::string field = line.substr(begin, end - begin);
    const std::string::size_type colon = field.find(':');
    if (colon == npos || colon == 0) {
      break;  // not a "name:value" pair: treat the rest as trailing text
    }

    const std::string key = field.substr(0, colon);
    const char *value = field.c_str() + colon + 1;
    if (key.compare(0, 3, "qid") == 0) {
      *qid = static_cast<int>(strtol(value, nullptr, 10));
    } else {
      d->num_values.push_back(std::make_pair(key, strtod(value, nullptr)));
    }
  }
  return 1;
}
int main(int argc, char **argv) { | |
std::string host = "127.0.0.1"; | |
int port = 9199; | |
std::string name = "test"; | |
jubatus::classifier::client::classifier client(host, port, 10.0); | |
client.clear(name); | |
if (argc < 5) { | |
fprintf(stderr, "%s train_in valid_in test_in valid_out test_out\n", argv[0]); | |
exit(1); | |
} | |
const char *train_in = argv[1]; | |
const char *valid_in = argv[2]; | |
const char *test_in = argv[3]; | |
const char *valid_out = argv[4]; | |
const char *test_out = argv[5]; | |
std::string buff; | |
std::map<int, std::vector<std::pair<float, datum> > > train_data; | |
std::ifstream trifs; | |
fprintf(stderr, "read train data...\n"); | |
trifs.open(train_in); | |
while (getline(trifs, buff)) { | |
float rel; | |
int qid; | |
datum d; | |
parse_line(buff, &rel, &qid, &d); | |
train_data[qid].push_back(std::make_pair(rel, d)); | |
} | |
srand(1000); | |
std::vector<std::vector<std::pair<float, datum> > > train_data_vec; | |
for (auto train_itr = train_data.begin(); train_itr != train_data.end(); ++train_itr) { | |
const std::vector<std::pair<float, datum> > &qid_data = train_itr->second; | |
train_data_vec.push_back(qid_data); | |
} | |
fprintf(stderr, "done\n"); | |
fprintf(stderr, "train start...\n"); | |
std::vector<std::pair<std::string, datum> > tmp_vec; | |
int n = 0; | |
while (n < 100000) { | |
tmp_vec.clear(); | |
for (int i = 0; i < 100; ++i) { | |
int qid = rand() % train_data_vec.size(); | |
int j = rand() % train_data_vec[qid].size(); | |
tmp_vec.push_back(std::make_pair("+", train_data_vec[qid][j].second)); | |
++n; | |
} | |
client.train(name, tmp_vec); | |
} | |
fprintf(stderr, "done\n"); | |
std::vector<datum> valid_data; | |
std::ifstream vaifs; | |
std::ofstream vaofs; | |
fprintf(stderr, "validation start...\n"); | |
vaifs.open(valid_in); | |
vaofs.open(valid_out); | |
while (getline(vaifs, buff)) { | |
float rel; | |
int qid; | |
datum d; | |
parse_line(buff, &rel, &qid, &d); | |
valid_data.push_back(d); | |
auto results = client.classify(name, valid_data); | |
double predicts = 0.0; | |
for (size_t i = 0; i < results[0].size(); ++i) { | |
const estimate_result& r = results[0][i]; | |
if (r.label == "+") { | |
predicts = r.score; | |
break; | |
} | |
} | |
vaofs << predicts << std::endl; | |
valid_data.clear(); | |
} | |
fprintf(stderr, "done\n"); | |
std::vector<datum> test_data; | |
std::ifstream teifs; | |
std::ofstream teofs; | |
fprintf(stderr, "test start...\n"); | |
teifs.open(test_in); | |
teofs.open(test_out); | |
while (getline(teifs, buff)) { | |
float rel; | |
int qid; | |
datum d; | |
parse_line(buff, &rel, &qid, &d); | |
test_data.push_back(d); | |
auto results = client.classify(name, test_data); | |
double predicts = 0.0; | |
for (size_t i = 0; i < results[0].size(); ++i) { | |
const estimate_result& r = results[0][i]; | |
if (r.label == "+") { | |
predicts = r.score; | |
break; | |
} | |
} | |
teofs << predicts << std::endl; | |
test_data.clear(); | |
} | |
fprintf(stderr, "done\n"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment