Skip to content

Instantly share code, notes, and snippets.

@ilivans
Created June 10, 2022 22:11
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save ilivans/ce0d5c766feb1d0b0c431874f182a76a to your computer and use it in GitHub Desktop.
Save ilivans/ce0d5c766feb1d0b0c431874f182a76a to your computer and use it in GitHub Desktop.
Convert XGBoost model into Solr LTR MultipleAdditiveTreesModel
import json
from xgboost import XGBModel
def dump_xgbmodel(xgb_model: XGBModel) -> dict:
    """
    Dump XGBModel instance as a Solr LTR MultipleAdditiveTreesModel model.

    Solr LTR MART model specification:
    - https://solr.apache.org/docs/8_11_0/solr-ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html
    - https://solr.apache.org/guide/8_11/learning-to-rank.html#model-evolution

    Every tree of the booster is emitted with a fixed weight of "1.0" and its
    root converted by ``_dump_xgb_tree``. The "features" list is built from the
    keys returned by ``booster.get_score(importance_type='weight')``.

    :param xgb_model: Trained XGBModel; it can be either one of XGBClassifier, XGBRegressor, XGBRanker.
    :return: Solr LTR MART model; the "name" key should be overwritten by a unique model name.
    :raises Exception: whatever ``xgb_model.get_booster()`` raises; per the XGBoost
        documentation this happens when the model has not been fitted.
    """
    # Use the public accessor rather than the private ``_Booster`` attribute,
    # so an untrained model is reported by XGBoost itself instead of failing
    # later with an AttributeError on None.
    booster = xgb_model.get_booster()

    # NOTE(review): presumably get_score() only lists features that occur in
    # at least one split, so unused features are omitted -- confirm this is
    # what the target Solr feature store expects.
    features = [
        {"name": feature_name}
        for feature_name in booster.get_score(importance_type='weight')
    ]

    # get_dump() yields one JSON document per tree; parse each and convert it.
    trees = [
        {
            "weight": "1.0",
            "root": _dump_xgb_tree(json.loads(tree)),
        }
        for tree in booster.get_dump(dump_format='json')
    ]

    return {
        "class": "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
        "name": "multipleadditivetreesmodel",
        "features": features,
        "params": {
            "trees": trees,
        },
    }
def _dump_xgb_tree(tree: dict) -> dict:
"""
Convert XGB binary tree dump into Solr format.
See a unit test below for an example.
:param tree: XGB formatted tree.
:return: Solr formatted tree.
"""
if "leaf" in tree:
return {"value": str(tree["leaf"])}
return {
"feature": tree["split"],
"threshold": str(tree["split_condition"]),
"left": _dump_xgb_tree(tree["children"][0]),
"right": _dump_xgb_tree(tree["children"][1])
}
def test_dump_xgb_tree():
    """Check that a two-level XGBoost tree dump converts to the Solr layout."""
    # XGBoost side, assembled bottom-up: the nested split on feature_B first,
    # then the root split on feature_A whose left child is a plain leaf.
    xgb_inner_split = {
        "nodeid": 2,
        "depth": 1,
        "split": "feature_B",
        "split_condition": 0.00859241374,
        "yes": 3,
        "no": 4,
        "missing": 3,
        "children": [
            {"nodeid": 3, "leaf": 0.0222222228},
            {"nodeid": 4, "leaf": -0.142454728},
        ],
    }
    xgb_root = {
        "nodeid": 0,
        "depth": 0,
        "split": "feature_A",
        "split_condition": 0.0042712898,
        "yes": 1,
        "no": 2,
        "missing": 1,
        "children": [
            {"nodeid": 1, "leaf": 0.0222222228},
            xgb_inner_split,
        ],
    }

    # Expected Solr side: numbers are strings, children are "left"/"right".
    solr_inner_split = {
        "feature": "feature_B",
        "threshold": "0.00859241374",
        "left": {"value": "0.0222222228"},
        "right": {"value": "-0.142454728"},
    }
    solr_root = {
        "feature": "feature_A",
        "threshold": "0.0042712898",
        "left": {"value": "0.0222222228"},
        "right": solr_inner_split,
    }

    converted = _dump_xgb_tree(xgb_root)
    assert converted == solr_root
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment