Model transformers: serializing scikit-learn and XGBoost tree models into the Elasticsearch _ml/inference model format.
# ESModel.py
from typing import List


def add_if_exists(d: dict, k: str, v) -> dict:
    """Set d[k] = v only when v is not None, and return d."""
    if v is not None:
        d[k] = v
    return d


class ESModel:
    def __init__(self,
                 feature_names: List[str],
                 target_type: str = None,
                 classification_labels: List[str] = None):
        self.target_type = target_type
        self.feature_names = feature_names
        self.classification_labels = classification_labels

    def to_dict(self):
        d = dict()
        add_if_exists(d, "target_type", self.target_type)
        add_if_exists(d, "feature_names", self.feature_names)
        add_if_exists(d, "classification_labels", self.classification_labels)
        return d


class TreeNode:
    def __init__(self,
                 node_idx: int,
                 default_left: bool = None,
                 decision_type: str = None,
                 left_child: int = None,
                 right_child: int = None,
                 split_feature: int = None,
                 threshold: float = None,
                 leaf_value: float = None):
        self.node_idx = node_idx
        self.decision_type = decision_type
        self.left_child = left_child
        self.right_child = right_child
        self.split_feature = split_feature
        self.threshold = threshold
        self.leaf_value = leaf_value
        self.default_left = default_left

    def to_dict(self):
        d = dict()
        add_if_exists(d, 'node_index', self.node_idx)
        add_if_exists(d, 'decision_type', self.decision_type)
        if self.leaf_value is None:
            add_if_exists(d, 'left_child', self.left_child)
            add_if_exists(d, 'right_child', self.right_child)
            add_if_exists(d, 'split_feature', self.split_feature)
            add_if_exists(d, 'threshold', self.threshold)
        else:
            add_if_exists(d, 'leaf_value', self.leaf_value)
        return d


class Tree(ESModel):
    def __init__(self,
                 feature_names: List[str],
                 target_type: str = None,
                 tree_structure: List[TreeNode] = None,
                 classification_labels: List[str] = None):
        super().__init__(
            feature_names=feature_names,
            target_type=target_type,
            classification_labels=classification_labels
        )
        if target_type == 'regression' and classification_labels:
            raise ValueError("regression does not support classification_labels")
        # default to None rather than [] to avoid the mutable-default-argument pitfall
        self.tree_structure = tree_structure or []

    def to_dict(self):
        d = super().to_dict()
        add_if_exists(d, 'tree_structure', [t.to_dict() for t in self.tree_structure])
        return {'tree': d}


class Ensemble(ESModel):
    def __init__(self,
                 feature_names: List[str],
                 trained_models: List[ESModel],
                 output_aggregator: dict,
                 target_type: str = None,
                 classification_labels: List[str] = None,
                 classification_weights: List[float] = None):
        super().__init__(feature_names=feature_names,
                         target_type=target_type,
                         classification_labels=classification_labels)
        self.trained_models = trained_models
        self.classification_weights = classification_weights
        self.output_aggregator = output_aggregator

    def to_dict(self):
        d = super().to_dict()
        trained_models = None
        if self.trained_models:
            trained_models = [t.to_dict() for t in self.trained_models]
        add_if_exists(d, 'trained_models', trained_models)
        add_if_exists(d, 'classification_weights', self.classification_weights)
        add_if_exists(d, 'aggregate_output', self.output_aggregator)
        return {'ensemble': d}
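
To see the JSON shape these classes serialize to, here is a minimal sketch (illustrative, not part of the original gist) that builds a one-split regression tree by hand and prints its dictionary form:

# example_tree.py (illustrative sketch)
import json
from ESModel import Tree, TreeNode

stump = Tree(feature_names=["f0"],
             target_type="regression",
             tree_structure=[
                 TreeNode(0, decision_type="lte", left_child=1, right_child=2,
                          split_feature=0, threshold=0.5),
                 TreeNode(1, leaf_value=-1.0),
                 TreeNode(2, leaf_value=1.0)])
print(json.dumps(stump.to_dict(), indent=2))
# -> {"tree": {"target_type": "regression", "feature_names": ["f0"],
#              "tree_structure": [{"node_index": 0, ...}, ...]}}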

# demo.py
from sklearn import datasets
# explicit imports for the estimators used below (they are also re-exported by ModelTransformers)
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from xgboost import XGBClassifier, XGBRegressor
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import NotFoundError
from ModelTransformers import *
from ESModel import ESModel
from Utils import serialize_and_compress_model

classification_data = datasets.make_classification(n_features=5)
regression_data = datasets.make_regression(n_features=5)
print(classification_data[0][0:5])
print(regression_data[0][0:5])

es = Elasticsearch()


def delete_model(elastic: Elasticsearch, name):
    try:
        elastic.transport.perform_request("DELETE", "/_ml/inference/" + name)
    except NotFoundError:
        print("not found {0}".format(name))


def put_model(elastic: Elasticsearch, model: ESModel, name):
    m = serialize_and_compress_model(model)  # base64-encoded, gzipped trained_model JSON
    elastic.transport.perform_request(
        "PUT", "/_ml/inference/" + name,
        body={
            "input": {
                "field_names": ["f0", "f1", "f2", "f3", "f4"]
            },
            "compressed_definition": m
        }
    )
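
# For reference (sketch derived from the code above), put_model is equivalent
# to the REST call:
#   PUT _ml/inference/<model_id>
#   {
#     "input": {"field_names": ["f0", "f1", "f2", "f3", "f4"]},
#     "compressed_definition": "<base64-encoded gzipped trained_model JSON>"
#   }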


def infer_model(elastic: Elasticsearch, name, config):
    return elastic.transport.perform_request(
        "POST",
        "/_ingest/pipeline/_simulate",
        body={
            "pipeline": {
                "processors": [
                    {"inference": {
                        "model_id": name,
                        "inference_config": config,
                        # empty: no remapping needed, the doc fields already use the model's feature names
                        "field_mappings": {}
                    }}
                ]
            },
            "docs": [
                {
                    "_source": {
                        "f0": 0.1,
                        "f1": 0.2,
                        "f2": 0.3,
                        "f3": -0.5,
                        "f4": 1.0
                    }
                },
                {
                    "_source": {
                        "f0": 1.6,
                        "f1": 2.1,
                        "f2": -10,
                        "f3": 50,
                        "f4": -1.0
                    }
                }
            ]
        })
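
# The simulate response mirrors the submitted docs: each entry wraps the
# enriched document, which is why predictions are read below via
#   d['doc']['_source']['ml']['inference']['predicted_value']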


delete_model(es, "sklearn_class_model_1")
print("\n\nSKLEARN CLASSIFIER")
random_forest_classifier = RandomForestClassifier()
random_forest_classifier.fit(classification_data[0], classification_data[1])
print(random_forest_classifier.predict([[0.1, 0.2, 0.3, -0.5, 1.0],
                                        [1.6, 2.1, -10, 50, -1.0]]))
put_model(es, SKLearnForestClassifierTransformer(random_forest_classifier,
                                                 ["f0", "f1", "f2", "f3", "f4"]).transform(),
          "sklearn_class_model_1")
inference = infer_model(es, "sklearn_class_model_1", {"classification": {}})
print(inference)
print([d['doc']['_source']['ml']['inference']['predicted_value'] for d in inference['docs']])
print("\n\nSKLEARN REGRESSOR")
delete_model(es, "sklearn_reg_model_1")
random_forest_regressor = RandomForestRegressor()
random_forest_regressor.fit(regression_data[0], regression_data[1])
print(random_forest_regressor.predict([[0.1, 0.2, 0.3, -0.5, 1.0],
[1.6, 2.1, -10, 50, -1.0]]))
put_model(es, SKLearnForestRegressorTransformer(random_forest_regressor,
["f0", "f1", "f2", "f3", "f4"]).transform(), "sklearn_reg_model_1")
inference = infer_model(es, "sklearn_reg_model_1", {"regression": {}})
print(inference)
print([d['doc']['_source']['ml']['inference']['predicted_value'] for d in inference['docs']])
print("\n\nXGBOOST CLASSIFIIER")
delete_model(es, "xgboost_class")
xgboost_classifier = XGBClassifier()
xgboost_classifier.fit(classification_data[0], classification_data[1])
print(xgboost_classifier.predict([[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]))
print(xgboost_classifier.predict_proba([[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]))
put_model(es, SKLearnXGBoostClassifierTransformer(xgboost_classifier,
["f0", "f1", "f2", "f3", "f4"]).transform(), "xgboost_class")
inference = infer_model(es, "xgboost_class", {"classification": {"num_top_classes": 2}})
print(inference)
print([d['doc']['_source']['ml']['inference']['predicted_value'] for d in inference['docs']])
print("\n\nXGBOOST REGRESSOR")
delete_model(es, "xgboost_reg")
xgboost_regressor = XGBRegressor()
xgboost_regressor.fit(regression_data[0], regression_data[1])
print(xgboost_regressor.predict([[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]))
put_model(es, SKLearnXGBoostRegressorTransformer(xgboost_regressor,
["f0", "f1", "f2", "f3", "f4"]).transform(), "xgboost_reg")
inference = infer_model(es, "xgboost_reg", {"regression": {}})
print(inference)
print([d['doc']['_source']['ml']['inference']['predicted_value'] for d in inference['docs']])
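
# Sanity check (illustrative, not in the original gist): the values returned by
# the inference processor should closely match the local model's predictions.
import numpy as np
es_preds = [d['doc']['_source']['ml']['inference']['predicted_value']
            for d in inference['docs']]
local_preds = xgboost_regressor.predict([[0.1, 0.2, 0.3, -0.5, 1.0],
                                         [1.6, 2.1, -10, 50, -1.0]])
print(np.allclose(es_preds, local_preds, atol=1e-4))  # expected: True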

# ModelTransformers.py
import numpy as np
from typing import List, Union
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.utils.validation import check_is_fitted
from xgboost import Booster, XGBRegressor, XGBClassifier
from ESModel import Tree, TreeNode, Ensemble


class ModelTransformer:
    def __init__(self,
                 model,
                 feature_names: List[str],
                 classification_labels: List[str] = None,
                 classification_weights: List[float] = None
                 ):
        self.feature_names = feature_names
        self.model = model
        self.classification_labels = classification_labels
        self.classification_weights = classification_weights

    def is_supported(self):
        return isinstance(self.model, (DecisionTreeClassifier,
                                       DecisionTreeRegressor,
                                       RandomForestRegressor,
                                       RandomForestClassifier,
                                       XGBClassifier,
                                       XGBRegressor,
                                       Booster))


class SKLearnTransformer(ModelTransformer):
    """
    Base class for SKLearn transformers.
    Warning: do not use this class directly; use a derived class instead.
    """

    def __init__(self,
                 model,
                 feature_names: List[str],
                 classification_labels: List[str] = None,
                 classification_weights: List[float] = None
                 ):
        """
        Base class for SKLearn transformations
        :param model: sklearn trained model
        :param feature_names: The feature names for the model
        :param classification_labels: Optional classification labels (if not encoded in the model)
        :param classification_weights: Optional classification weights
        """
        super().__init__(model, feature_names, classification_labels, classification_weights)
        self.node_decision_type = "lte"

    def build_tree_node(self, node_index: int, node_data, value) -> TreeNode:
        """
        Builds a TreeNode from the given sklearn tree node definition.
        Node decision types default to "lte" to match SKLearn's behavior.
        :param node_index: The node index
        :param node_data: A row of the sklearn tree's structured node array
        :param value: Opaque node value (i.e. leaf/node values) from the tree state
        :return: TreeNode object
        """
        if value.shape[0] != 1:
            raise ValueError("unexpected multiple values returned from leaf node '{0}'".format(node_index))
        if node_data[0] == -1:  # is leaf node
            if value.shape[1] == 1:  # classification requires more than one value, so assume regression
                leaf_value = float(value[0][0])
            else:
                # the classification value, which is the index of the largest value
                leaf_value = int(np.argmax(value))
            return TreeNode(node_index, decision_type=self.node_decision_type, leaf_value=leaf_value)
        else:
            return TreeNode(node_index,
                            decision_type=self.node_decision_type,
                            left_child=int(node_data[0]),
                            right_child=int(node_data[1]),
                            split_feature=int(node_data[2]),
                            threshold=float(node_data[3]))
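
# For reference, sklearn's tree state is a structured array whose rows are
# roughly (left_child, right_child, feature, threshold, ...); e.g. for a
# depth-1 stump:
#   nodes[0] = (1, 2, 0, 0.5, ...)      # split on feature 0 at 0.5
#   nodes[1] = (-1, -1, -2, -2.0, ...)  # left_child == -1 marks a leaf
# and tree_state["values"][i] holds the per-node value array consumed above.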


class SKLearnDecisionTreeTransformer(SKLearnTransformer):
    """
    Class for transforming SKLearn decision tree models into the Tree model format supported by Elasticsearch.
    """

    def __init__(self,
                 model: Union[DecisionTreeRegressor, DecisionTreeClassifier],
                 feature_names: List[str],
                 classification_labels: List[str] = None):
        """
        Transforms a Decision Tree model (Regressor|Classifier) into an ES-supported Tree format
        :param model: fitted decision tree model
        :param feature_names: model feature names
        :param classification_labels: Optional classification labels
        """
        super().__init__(model, feature_names, classification_labels)

    def transform(self) -> Tree:
        """
        Transform the provided model into an ES-supported Tree object
        :return: Tree object for ES storage and use
        """
        target_type = "regression" if isinstance(self.model, DecisionTreeRegressor) else "classification"
        check_is_fitted(self.model, ["tree_"])
        tree_classes = None
        if self.classification_labels:
            tree_classes = self.classification_labels
        if isinstance(self.model, DecisionTreeClassifier):
            check_is_fitted(self.model, ["classes_"])
            if tree_classes is None:
                tree_classes = [str(c) for c in self.model.classes_]
        nodes = list()
        tree_state = self.model.tree_.__getstate__()
        for i in range(len(tree_state["nodes"])):
            nodes.append(self.build_tree_node(i, tree_state["nodes"][i], tree_state["values"][i]))
        return Tree(self.feature_names,
                    target_type,
                    nodes,
                    tree_classes)
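
# Usage sketch (illustrative): transform a single fitted decision tree into the
# ES JSON shape. X and y are placeholders for training data.
#
#   from sklearn.tree import DecisionTreeRegressor
#   dt = DecisionTreeRegressor(max_depth=2).fit(X, y)
#   SKLearnDecisionTreeTransformer(dt, ["f0", "f1"]).transform().to_dict()
#   # -> {"tree": {...}}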


class SKLearnForestTransformer(SKLearnTransformer):
    """
    Base class for transforming SKLearn forest models into the Ensemble model format supported by Elasticsearch.
    Warning: do not use this class directly; use a derived class instead.
    """

    def __init__(self,
                 model: Union[RandomForestClassifier,
                              RandomForestRegressor],
                 feature_names: List[str],
                 classification_labels: List[str] = None,
                 classification_weights: List[float] = None
                 ):
        super().__init__(model, feature_names, classification_labels, classification_weights)

    def build_aggregator_output(self) -> dict:
        raise NotImplementedError("build_aggregator_output must be implemented")

    def determine_target_type(self) -> str:
        raise NotImplementedError("determine_target_type must be implemented")

    def transform(self) -> Ensemble:
        check_is_fitted(self.model, ["estimators_"])
        estimators = self.model.estimators_
        ensemble_classes = None
        if self.classification_labels:
            ensemble_classes = self.classification_labels
        if isinstance(self.model, RandomForestClassifier):
            check_is_fitted(self.model, ["classes_"])
            if ensemble_classes is None:
                ensemble_classes = [str(c) for c in self.model.classes_]
        ensemble_models = [SKLearnDecisionTreeTransformer(m,
                                                          self.feature_names).transform() for m in estimators]
        return Ensemble(self.feature_names,
                        ensemble_models,
                        self.build_aggregator_output(),
                        target_type=self.determine_target_type(),
                        classification_labels=ensemble_classes,
                        classification_weights=self.classification_weights)


class SKLearnForestRegressorTransformer(SKLearnForestTransformer):
    """
    Class for transforming RandomForestRegressor models into an ensemble model supported by Elasticsearch
    """

    def __init__(self,
                 model: RandomForestRegressor,
                 feature_names: List[str]
                 ):
        super().__init__(model, feature_names)

    def build_aggregator_output(self) -> dict:
        return {
            "weighted_sum": {"weights": [1.0 / len(self.model.estimators_)] * len(self.model.estimators_)}
        }

    def determine_target_type(self) -> str:
        return "regression"


class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
    """
    Class for transforming RandomForestClassifier models into an ensemble model supported by Elasticsearch
    """

    def __init__(self,
                 model: RandomForestClassifier,
                 feature_names: List[str],
                 classification_labels: List[str] = None,
                 ):
        super().__init__(model, feature_names, classification_labels)

    def build_aggregator_output(self) -> dict:
        return {"weighted_mode": {"num_classes": len(self.model.classes_)}}

    def determine_target_type(self) -> str:
        return "classification"


class XGBoostForestTransformer(ModelTransformer):
    """
    Base class for transforming XGBoost models into ensemble models supported by Elasticsearch.
    Warning: do not use this class directly; use a derived class instead.
    """

    def __init__(self,
                 model: Booster,
                 feature_names: List[str],
                 base_score: float = 0.5,
                 objective: str = "reg:squarederror",
                 classification_labels: List[str] = None,
                 classification_weights: List[float] = None
                 ):
        super().__init__(model, feature_names, classification_labels, classification_weights)
        self.node_decision_type = "lt"
        self.base_score = base_score
        self.objective = objective

    def get_feature_id(self, feature_id: str) -> int:
        if feature_id[0] == "f":
            try:
                return int(feature_id[1:])
            except ValueError:
                raise RuntimeError("Unable to interpret '{0}'".format(feature_id))
        else:
            try:
                return int(feature_id)
            except ValueError:
                raise RuntimeError("Unable to interpret '{0}'".format(feature_id))

    def extract_node_id(self, node_id: str, curr_tree: int) -> int:
        t_id, n_id = node_id.split("-")
        if t_id is None or n_id is None:
            raise RuntimeError(
                "cannot determine node index or tree from '{0}' for tree {1}".format(node_id, curr_tree))
        try:
            t_id = int(t_id)
            n_id = int(n_id)
            if t_id != curr_tree:
                raise RuntimeError("extracted tree id {0} does not match current tree {1}".format(t_id, curr_tree))
            return n_id
        except ValueError:
            raise RuntimeError(
                "cannot determine node index or tree from '{0}' for tree {1}".format(node_id, curr_tree))

    def build_tree_node(self, row, curr_tree: int) -> TreeNode:
        node_index = row["Node"]
        if row["Feature"] == "Leaf":
            return TreeNode(node_idx=node_index, leaf_value=float(row["Gain"]))
        else:
            return TreeNode(node_idx=node_index,
                            decision_type=self.node_decision_type,
                            left_child=self.extract_node_id(row["Yes"], curr_tree),
                            right_child=self.extract_node_id(row["No"], curr_tree),
                            threshold=float(row["Split"]),
                            split_feature=self.get_feature_id(row["Feature"]))

    def build_tree(self, nodes: List[TreeNode]) -> Tree:
        return Tree(feature_names=self.feature_names,
                    tree_structure=nodes)

    def build_base_score_stump(self) -> Tree:
        return Tree(feature_names=self.feature_names,
                    tree_structure=[TreeNode(0, leaf_value=self.base_score)])

    def build_forest(self) -> List[Tree]:
        """
        Builds the forest of trees as described by XGBoost into a format
        supported by Elasticsearch
        :return: A list of Tree objects
        """
        if self.model.booster not in {'dart', 'gbtree'}:
            raise ValueError("booster must exist and be of type dart or gbtree")
        tree_table = self.model.trees_to_dataframe()
        transformed_trees = list()
        curr_tree = None
        tree_nodes = list()
        for _, row in tree_table.iterrows():
            if row["Tree"] != curr_tree:
                if len(tree_nodes) > 0:
                    transformed_trees.append(self.build_tree(tree_nodes))
                curr_tree = row["Tree"]
                tree_nodes = list()
            tree_nodes.append(self.build_tree_node(row, curr_tree))
        # add the last tree
        if len(tree_nodes) > 0:
            transformed_trees.append(self.build_tree(tree_nodes))
        # add a stump, as XGBoost adds the base_score to its regression outputs
        if self.objective.startswith("reg"):
            transformed_trees.append(self.build_base_score_stump())
        return transformed_trees
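
    # For reference, trees_to_dataframe() yields one row per node with columns
    # (Tree, Node, ID, Feature, Split, Yes, No, Missing, Gain, Cover), e.g.:
    #   0 | 0 | "0-0" | "f2"   | 0.13 | "0-1" | "0-2" | "0-1" | ...  | ...
    #   0 | 1 | "0-1" | "Leaf" | NaN  | NaN   | NaN   | NaN   | 0.42 | ...
    # On leaf rows "Feature" is "Leaf" and "Gain" carries the leaf value, which
    # is why build_tree_node reads row["Gain"] for leaves.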

    def build_aggregator_output(self) -> dict:
        raise NotImplementedError("build_aggregator_output must be implemented")

    def determine_target_type(self) -> str:
        raise NotImplementedError("determine_target_type must be implemented")

    def is_objective_supported(self) -> bool:
        return False

    def transform(self) -> Ensemble:
        if self.model.booster not in {'dart', 'gbtree'}:
            raise ValueError("booster must exist and be of type dart or gbtree")
        if not self.is_objective_supported():
            raise ValueError("Unsupported objective '{0}'".format(self.objective))
        forest = self.build_forest()
        return Ensemble(feature_names=self.feature_names,
                        trained_models=forest,
                        output_aggregator=self.build_aggregator_output(),
                        classification_labels=self.classification_labels,
                        classification_weights=self.classification_weights,
                        target_type=self.determine_target_type())


class SKLearnXGBoostRegressorTransformer(XGBoostForestTransformer):
    def __init__(self,
                 model: XGBRegressor,
                 feature_names: List[str]):
        super().__init__(model.get_booster(),
                         feature_names,
                         model.base_score,
                         model.objective)

    def determine_target_type(self) -> str:
        return "regression"

    def is_objective_supported(self) -> bool:
        return self.objective in {'reg:squarederror',
                                  'reg:linear',
                                  'reg:squaredlogerror',
                                  'reg:logistic'}

    def build_aggregator_output(self) -> dict:
        return {"weighted_sum": {}}


class SKLearnXGBoostClassifierTransformer(XGBoostForestTransformer):
    def __init__(self,
                 model: XGBClassifier,
                 feature_names: List[str],
                 classification_labels: List[str] = None):
        super().__init__(model.get_booster(),
                         feature_names,
                         model.base_score,
                         model.objective,
                         classification_labels)

    def determine_target_type(self) -> str:
        return "classification"

    def is_objective_supported(self) -> bool:
        return self.objective in {'binary:logistic', 'binary:hinge'}

    def build_aggregator_output(self) -> dict:
        return {"logistic_regression": {}}

# requirements.txt
appnope==0.1.0
attrs==19.3.0
backcall==0.1.0
bleach==3.1.0
decorator==4.4.1
defusedxml==0.6.0
elasticsearch==7.5.1
entrypoints==0.3
importlib-metadata==1.5.0
ipykernel==5.1.4
ipython==7.12.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.16.0
Jinja2==2.11.1
joblib==0.14.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==5.3.4
jupyter-console==6.1.0
jupyter-core==4.6.1
MarkupSafe==1.1.1
mistune==0.8.4
nbconvert==5.6.1
nbformat==5.0.4
notebook==6.0.3
numpy==1.18.1
pandas==1.0.0
pandocfilters==1.4.2
parso==0.6.1
pexpect==4.8.0
pickleshare==0.7.5
prometheus-client==0.7.1
prompt-toolkit==3.0.3
ptyprocess==0.6.0
Pygments==2.5.2
pyrsistent==0.15.7
python-dateutil==2.8.1
pytz==2019.3
pyzmq==18.1.1
qtconsole==4.6.0
scikit-learn==0.22.1
scipy==1.4.1
Send2Trash==1.5.0
six==1.14.0
sklearn==0.0
terminado==0.8.3
testpath==0.4.4
tornado==6.0.3
traitlets==4.3.3
urllib3==1.25.8
wcwidth==0.1.8
webencodings==0.5.1
widgetsnbextension==3.5.1
xgboost==0.90
zipp==2.1.0

# Utils.py
import gzip
import base64
import json

from ESModel import ESModel


def serialize_and_compress_model(model: ESModel) -> str:
    json_string = json.dumps({'trained_model': model.to_dict()})
    # decode so callers get a true str (base64.b64encode returns bytes)
    return base64.b64encode(gzip.compress(bytes(json_string, 'utf-8'))).decode('utf-8')
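
# Round-trip sanity check (illustrative, not part of the original gist): the
# compressed definition decodes back to the original JSON document.
#
#   payload = serialize_and_compress_model(some_model)  # some_model: any ESModel
#   original = json.loads(gzip.decompress(base64.b64decode(payload)))
#   assert "trained_model" in original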