Skip to content

Instantly share code, notes, and snippets.

@wugh
Last active February 15, 2020 05:04
Show Gist options
  • Save wugh/25c156009cb02ce94d1c268ffd5c2db7 to your computer and use it in GitHub Desktop.
Save wugh/25c156009cb02ce94d1c268ffd5c2db7 to your computer and use it in GitHub Desktop.
xgboost c++ predictor
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <boost/lexical_cast.hpp>
#include <boost/regex.hpp>
/// One node of a decision tree.
///
/// A leaf carries only `value_`; its child links and split feature are all -1.
/// A split node carries the ids of its three possible successors and the
/// feature/threshold pair used to pick between them; its `value_` is 0.
struct Node {
    // Defaults describe a leaf; the split-node constructor overrides them.
    int left_ = -1;
    int right_ = -1;
    int miss_ = -1;
    int split_feature_ = -1;
    float split_threshold_ = 0.0f;
    float value_ = 0.0f;

    /// Builds a split (internal) node.
    Node(int left, int right, int miss, int split_feature, float threshold)
        : left_(left),
          right_(right),
          miss_(miss),
          split_feature_(split_feature),
          split_threshold_(threshold) {}

    /// Builds a leaf node holding `value`.
    Node(float value) : value_(value) {}
};

using NodePtr = std::shared_ptr<Node>;
class Tree {
public:
void add_line(const std::string &line) {
if (line.find("leaf") == std::string::npos) {
add_internal_node(line);
} else {
add_leaf_node(line);
}
}
float predict(std::vector<float> &inst, float missing=NAN) {
int nnode = 0;
while (tree_.find(nnode) != tree_.end()) {
auto& node_ptr = tree_[nnode];
if (node_ptr->left_ == -1) {
return node_ptr->value_;
}
if (std::isnan(missing) &&
std::isnan(inst[node_ptr->split_feature_])) {
nnode = node_ptr->miss_;
} else if (inst[node_ptr->split_feature_] == missing) {
nnode = node_ptr->miss_;
} else if (inst[node_ptr->split_feature_] <
(node_ptr->split_threshold_)) {
nnode = node_ptr->left_;
} else {
nnode = node_ptr->right_;
}
}
// if can not find leaf, some thing wrong
std::cerr << "error predict" << std::endl;
return 0.0;
}
private:
static boost::regex internal_expr_;
static boost::regex leaf_expr_;
std::unordered_map<int, NodePtr> tree_;
void add_leaf_node(const std::string &line) {
boost::smatch what;
if (boost::regex_search(line, what, leaf_expr_)) {
int id = std::stoi(what[1]);
float value = boost::lexical_cast<float>(what[2]);
NodePtr node_ptr(new Node(value));
tree_.emplace(std::make_pair(id, node_ptr));
}
}
void add_internal_node(const std::string &line) {
boost::smatch what;
if (boost::regex_search(line, what, internal_expr_)) {
int id = std::stoi(what[1]);
auto split_feature = std::stoi(what[2]);
auto threshold = boost::lexical_cast<float>(what[3]);
auto left = std::stoi(what[4]);
auto right = std::stoi(what[5]);
auto missing = std::stoi(what[6]);
NodePtr node_ptr(new Node(left, right, missing,
split_feature, threshold));
tree_.emplace(std::make_pair(id, node_ptr));
}
}
};
// Split-node line: "<id>:[f<feature><<threshold>] yes=<id>,no=<id>,missing=<id>".
// Matched with regex_search, so any text after the missing id is ignored.
boost::regex Tree::internal_expr_{"(\\d+):\\[f(\\d+)<(.+)\\] *yes=(\\d+),no=(\\d+),missing=(\\d+)"};
// Leaf line: "<id>:leaf=<value>". The value capture stops at the first comma
// or whitespace so that trailing fields (e.g. ",cover=...") or a stray '\r'
// are not handed to the float conversion, which would otherwise throw.
boost::regex Tree::leaf_expr_{"(\\d+):leaf=([^,\\s]+)"};
typedef std::shared_ptr<Tree> TreePtr;
class GBDT {
private:
std::vector<TreePtr> trees_;
float base_score_;
public:
GBDT(const char* filename, float base_score=0.5) {
base_score_ = base_score;
std::string line;
std::ifstream infile(filename);
// omit first line
std::getline(infile, line);
TreePtr tree_ptr;
while (std::getline(infile, line)) {
if (line.find("booster") != std::string::npos) {
if (tree_ptr == nullptr)
continue;
trees_.emplace_back(tree_ptr);
tree_ptr.reset(new Tree());
continue;
}
if (tree_ptr == nullptr)
tree_ptr.reset(new Tree());
tree_ptr->add_line(line);
}
if (tree_ptr != nullptr)
trees_.emplace_back(tree_ptr);
}
float predict(std::vector<float> &inst, float missing=NAN) {
float score = base_score_;
int num = 0;
for (auto &tree_ptr: trees_) {
score += tree_ptr->predict(inst, missing);
num += 1;
}
return score;
}
};
int main(int argc, const char *argv[])
{
std::string path = "./channel_answer_rate_model.raw";
std::shared_ptr<GBDT> gbdt_ptr(new GBDT(path.c_str()));
std::vector<float> inst{0.0, 65.000000, 63.000000, 0.954545, 31.000000, 1.000000, 0.031250, 0.094889, 0.821576, 0.133360, 0.091890, 0.000000, 0.079810, 0.038600, 0.996000, 0.004150, 0.442620, 0.378290, 0.150160, 0.015770, 0.009010, 0.008910, 3549.000000, 53.772727, 31.000000, 0.000000, 2.000000, 22.659020, 37.173530, 12.391177, 3.000000, 2.637781, 5.160167, 2.580083, 2.000000, -0.032491, 0.307508, 2.000000, 0.451931, 0.835198, 1580.000000, 117.000000, 0.074051, 69.000000, 0.589744, 34372.000000, 293.777778, 0.003530, 0.003413, 0.003720, 0.002524, 0.002546, 0.001198, 0.269883, 0.287920, 0.007814, 0.000000, 0.000000, 0.008881, 0.002032, 0.000000, 0.000000, 0.000000};
std::cout << gbdt_ptr->predict(inst) << std::endl;
return 0;
}
@harishvk27
Copy link

Why is the model loaded from the raw (text dump) format and not from the binary model format?
I want to load the model from the binary format.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment