Skip to content

Instantly share code, notes, and snippets.

@yupbank
Created January 3, 2019 19:47
Show Gist options
  • Save yupbank/1e0c2f50d5ed571e10559a681e7bb76f to your computer and use it in GitHub Desktop.
Save yupbank/1e0c2f50d5ed571e10559a681e7bb76f to your computer and use it in GitHub Desktop.
import functools
import time

import numpy as np
def timeit(func):
    """Decorate ``func`` so that every call prints its name and elapsed time.

    The wrapper forwards all positional and keyword arguments unchanged and
    returns whatever ``func`` returns. After the call completes it prints
    ``Func: <name>, runtime: <seconds>`` with six decimal places. If ``func``
    raises, the exception propagates and nothing is printed.

    Args:
        func: the callable to time.

    Returns:
        A wrapper with the same call signature, name and docstring as ``func``.
    """
    # wraps() keeps __name__/__doc__ intact so decorated functions stay
    # identifiable in tracebacks and introspection.
    @functools.wraps(func)
    def _timed(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measurement is
        # not disturbed by system clock adjustments.
        start = time.perf_counter()
        res = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print('Func: %s, runtime: %.6f' % (func.__name__, elapsed))
        return res
    return _timed
@timeit
def sklearn_inference(data, clf):
    """Return ``clf.apply(data)``, timed by the ``timeit`` decorator.

    This is the library-provided path; the script compares the output of the
    hand-written ``inference`` against it.
    """
    result = clf.apply(data)
    return result
@timeit
def inference(data, clf):
    """Return, for every row of ``data``, the id of the tree node it ends on.

    Vectorised NumPy traversal of ``clf.tree_``. Every row starts at node 0;
    on each pass, rows that are not yet on a leaf move to the left child when
    ``data[row, feature[node]] <= threshold[node]`` and to the right child
    otherwise. The loop ends once no row can descend further.

    Args:
        data: 2-D NumPy array indexed as ``data[row, column]``.
        clf: fitted estimator exposing ``tree_`` with the arrays ``feature``,
            ``threshold``, ``children_left`` and ``children_right``.

    Returns:
        1-D integer array of length ``data.shape[0]`` holding one node id per
        row (also for an empty input or a tree that consists of a single node).
    """
    tree = clf.tree_
    feature, threshold = tree.feature, tree.threshold
    left, right = tree.children_left, tree.children_right

    # One entry per row from the very start, so the result always has length
    # n_samples even when the loop body never runs.
    node = np.zeros(data.shape[0], dtype=left.dtype)
    while True:
        # A child id of -1 marks a leaf; rows sitting on a leaf are finished.
        active = np.flatnonzero(left[node] != -1)
        if active.size == 0:
            break
        current = node[active]
        # Only split nodes are evaluated here. Leaf entries of `feature` are
        # not real column indices, so using them to index `data` can fail
        # (e.g. when `data` has a single column).
        go_left = data[active, feature[current]] <= threshold[current]
        node[active] = np.where(go_left, left[current], right[current])
    return node
if __name__ == "__main__":
    # Self-check: fit a tree, run both inference paths, and require that they
    # assign every sample to the same node.
    # load_boston was removed in scikit-learn 1.2; load_diabetes is another
    # bundled regression dataset that needs no download.
    from sklearn.datasets import load_diabetes
    from sklearn.tree import DecisionTreeRegressor

    x, y = load_diabetes(return_X_y=True)
    # scikit-learn converts inputs to float32 internally; casting up front
    # makes both code paths compare identical values against the thresholds.
    x = x.astype(np.float32)
    clf = DecisionTreeRegressor(random_state=10)
    clf.fit(x, y)
    leafs_mine = inference(x, clf)
    leafs_sklearn = sklearn_inference(x, clf)
    # Node ids are integers, so require exact equality.
    np.testing.assert_array_equal(leafs_mine, leafs_sklearn)
@yupbank
Copy link
Author

yupbank commented Jan 3, 2019

Still ~3x slower than the Cython version, but it should be much friendlier to parallelization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment