Faster DecisionTreeRegressor for use with NGBoost
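Requires NumPy, Numba, and scikit-learn; ngboost itself is only needed if you plug the tree in as a base learner (see the usage sketch at the end of the file).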
import numpy as np
from numba import njit, prange
from sklearn.base import BaseEstimator, RegressorMixin


@njit(inline='always')
def weighted_variance_from_sums(sum_w, sum_wy, sum_wy_sq):
    # Weighted variance: var_w = (sum_wy_sq / sum_w) - (sum_wy / sum_w)**2
    if sum_w <= 1e-14:
        return 0.0
    mean_w = sum_wy / sum_w
    return (sum_wy_sq / sum_w) - mean_w * mean_w
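
# The shifted-moment identity Var_w(y) = E_w[y^2] - E_w[y]^2 is what makes the
# split search below fast: once the three running sums (sum_w, sum_wy,
# sum_wy_sq) are available as prefix sums over the sorted feature values, the
# left/right variance of any candidate split can be read off in O(1) instead
# of re-scanning the node's samples for each threshold.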
@njit(parallel=True)
def find_best_split(X, y, sample_weight, current_samples):
    n_samples = len(current_samples)
    n_features = X.shape[1]

    # Compute weighted sums for the current node
    sum_w = 0.0
    sum_wy = 0.0
    sum_wy_sq = 0.0
    for i in range(n_samples):
        idx = current_samples[i]
        w = sample_weight[idx]
        yw = y[idx]
        sum_w += w
        sum_wy += w * yw
        sum_wy_sq += w * yw * yw

    current_var = weighted_variance_from_sums(sum_w, sum_wy, sum_wy_sq)
    if current_var <= 1e-14:
        # No need to split
        return -1, 0.0, -np.inf

    # We'll store the best result found by each feature
    best_features = np.full(n_features, -1, dtype=np.int32)
    best_thresholds = np.full(n_features, 0.0, dtype=np.float64)
    best_improvements = np.full(n_features, -np.inf, dtype=np.float64)

    # Parallel loop over features
    for feature in prange(n_features):
        feature_values = X[current_samples, feature]
        sorted_idx = np.argsort(feature_values)

        # Precompute prefix sums for this feature
        prefix_w = np.zeros(n_samples + 1, dtype=np.float64)
        prefix_wy = np.zeros(n_samples + 1, dtype=np.float64)
        prefix_wy_sq = np.zeros(n_samples + 1, dtype=np.float64)
        for i in range(n_samples):
            idx = current_samples[sorted_idx[i]]
            w = sample_weight[idx]
            yw = y[idx]
            prefix_w[i + 1] = prefix_w[i] + w
            prefix_wy[i + 1] = prefix_wy[i] + w * yw
            prefix_wy_sq[i + 1] = prefix_wy_sq[i] + w * yw * yw

        sorted_x = feature_values[sorted_idx]
        local_best_improvement = -np.inf
        local_best_threshold = 0.0
        for i in range(1, n_samples):
            # Only consider a split if it separates distinct values
            if sorted_x[i] == sorted_x[i - 1]:
                continue
            sum_w_left = prefix_w[i]
            sum_wy_left = prefix_wy[i]
            sum_wy_sq_left = prefix_wy_sq[i]
            sum_w_right = sum_w - sum_w_left
            sum_wy_right = sum_wy - sum_wy_left
            sum_wy_sq_right = sum_wy_sq - sum_wy_sq_left
            var_left = weighted_variance_from_sums(sum_w_left, sum_wy_left, sum_wy_sq_left)
            var_right = weighted_variance_from_sums(sum_w_right, sum_wy_right, sum_wy_sq_right)
            # Compute weighted improvement
            improvement = current_var - ((sum_w_left * var_left + sum_w_right * var_right) / sum_w)
            if improvement > local_best_improvement:
                local_best_improvement = improvement
                local_best_threshold = 0.5 * (sorted_x[i - 1] + sorted_x[i])

        best_features[feature] = feature
        best_thresholds[feature] = local_best_threshold
        best_improvements[feature] = local_best_improvement

    # Reduce results from all features
    best_improvement = -np.inf
    best_feature = -1
    best_threshold = 0.0
    for f in range(n_features):
        if best_improvements[f] > best_improvement:
            best_improvement = best_improvements[f]
            best_threshold = best_thresholds[f]
            best_feature = best_features[f]

    return best_feature, best_threshold, best_improvement
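
# build_tree stores the tree in flat arrays with implicit heap indexing:
# node i's children live at 2*i + 1 and 2*i + 2, so a tree of depth
# max_depth needs at most 2**(max_depth + 1) - 1 slots. Recursion is
# replaced by an explicit stack of (node, sample range, depth) entries,
# which keeps the whole builder Numba nopython-compatible.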
@njit
def build_tree(X, y, sample_weight, max_depth, min_samples_split):
    """Build the decision tree using iterative DFS with weighted samples."""
    max_nodes = 2**(max_depth + 1) - 1
    n_samples = X.shape[0]

    # Tree structure arrays
    is_leaf = np.zeros(max_nodes, dtype=np.int32)
    features = np.full(max_nodes, -1, dtype=np.int32)
    thresholds = np.full(max_nodes, np.inf, dtype=np.float64)
    values = np.zeros(max_nodes, dtype=np.float64)
    left_children = np.full(max_nodes, -1, dtype=np.int32)
    right_children = np.full(max_nodes, -1, dtype=np.int32)

    # Arrays for node stack
    node_stack = np.zeros(max_nodes, dtype=np.int32)
    sample_start = np.zeros(max_nodes, dtype=np.int32)
    sample_end = np.zeros(max_nodes, dtype=np.int32)
    depth_stack = np.zeros(max_nodes, dtype=np.int32)

    # Sample indices array
    sample_indices = np.arange(n_samples)

    # Initialize root
    stack_size = 1
    node_stack[0] = 0
    sample_start[0] = 0
    sample_end[0] = n_samples
    depth_stack[0] = 0

    while stack_size > 0:
        # Pop from stack
        stack_size -= 1
        node_idx = node_stack[stack_size]
        start_idx = sample_start[stack_size]
        end_idx = sample_end[stack_size]
        depth = depth_stack[stack_size]

        n_node_samples = end_idx - start_idx
        current_samples = sample_indices[start_idx:end_idx]

        # Compute weighted sums at node
        sum_w = 0.0
        sum_wy = 0.0
        for i in range(n_node_samples):
            idx = current_samples[i]
            w = sample_weight[idx]
            yw = y[idx]
            sum_w += w
            sum_wy += w * yw

        # Check stopping criteria
        if (depth >= max_depth) or (n_node_samples < min_samples_split):
            # Leaf node prediction = weighted mean
            values[node_idx] = sum_wy / sum_w if sum_w > 0 else 0.0
            is_leaf[node_idx] = 1
            continue

        # Find best split
        feature, threshold, improvement = find_best_split(X, y, sample_weight, current_samples)
        if feature == -1 or improvement <= 0:
            # No split
            values[node_idx] = sum_wy / sum_w if sum_w > 0 else 0.0
            is_leaf[node_idx] = 1
            continue

        # Partition samples in place: indices with X[:, feature] <= threshold
        # are swapped to the front of this node's slice (quicksort-style), so
        # each child reuses a contiguous range of sample_indices and no
        # per-node index arrays are allocated.
        features[node_idx] = feature
        thresholds[node_idx] = threshold
        partition_pos = start_idx
        for i in range(start_idx, end_idx):
            if X[sample_indices[i], feature] <= threshold:
                tmp = sample_indices[i]
                sample_indices[i] = sample_indices[partition_pos]
                sample_indices[partition_pos] = tmp
                partition_pos += 1

        left_idx = 2 * node_idx + 1
        right_idx = 2 * node_idx + 2
        left_children[node_idx] = left_idx
        right_children[node_idx] = right_idx

        # Push right child (pushed first so the left child is popped next)
        node_stack[stack_size] = right_idx
        sample_start[stack_size] = partition_pos
        sample_end[stack_size] = end_idx
        depth_stack[stack_size] = depth + 1
        stack_size += 1

        # Push left child
        node_stack[stack_size] = left_idx
        sample_start[stack_size] = start_idx
        sample_end[stack_size] = partition_pos
        depth_stack[stack_size] = depth + 1
        stack_size += 1

    return is_leaf, features, thresholds, values, left_children, right_children
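
# Each prange iteration below walks one sample from the root to a leaf and
# writes a single distinct slot of `predictions`, so there is no shared
# mutable state and the loop parallelizes cleanly across samples.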
@njit(parallel=True)
def predict_all(X, is_leaf, features, thresholds, values, left_children, right_children):
    """Make predictions for all samples in parallel."""
    n_samples = X.shape[0]
    predictions = np.empty(n_samples, dtype=np.float64)
    for i in prange(n_samples):
        node_idx = 0
        while not is_leaf[node_idx]:
            if X[i, features[node_idx]] <= thresholds[node_idx]:
                node_idx = left_children[node_idx]
            else:
                node_idx = right_children[node_idx]
        predictions[i] = values[node_idx]
    return predictions
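
# Subclassing BaseEstimator/RegressorMixin supplies get_params/set_params and
# a default score method, so sklearn-style meta-estimators such as NGBoost
# can clone this tree and refit it repeatedly as a base learner; only
# max_depth and min_samples_split are exposed as hyperparameters.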
class NumbaDecisionTreeRegressor(BaseEstimator, RegressorMixin):
    """A decision tree regressor implemented with Numba and prefix-sum optimization, with sample weights."""

    def __init__(self, max_depth=5, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.is_leaf = None
        self.features = None
        self.thresholds = None
        self.values = None
        self.left_children = None
        self.right_children = None

    def fit(self, X, y, sample_weight=None):
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        if sample_weight is None:
            # If no sample_weight is given, use equal weights.
            sample_weight = np.ones(len(y), dtype=np.float64)
        else:
            sample_weight = np.asarray(sample_weight, dtype=np.float64)
        (self.is_leaf,
         self.features,
         self.thresholds,
         self.values,
         self.left_children,
         self.right_children) = build_tree(X, y, sample_weight, self.max_depth, self.min_samples_split)
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=np.float64)
        return predict_all(X, self.is_leaf, self.features, self.thresholds, self.values,
                           self.left_children, self.right_children)
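

# ---------------------------------------------------------------------------
# Usage sketch. The synthetic data and the NGBoost hook-up below are
# illustrative: they assume ngboost is installed and that NGBRegressor's
# `Base` parameter accepts an unfitted sklearn-style regressor (true for
# recent ngboost releases); check against the version you have installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X_demo = rng.standard_normal((1000, 5))
    y_demo = 2.0 * X_demo[:, 0] + np.sin(X_demo[:, 1]) + 0.1 * rng.standard_normal(1000)

    # Standalone fit/predict
    tree = NumbaDecisionTreeRegressor(max_depth=5)
    tree.fit(X_demo, y_demo)
    preds = tree.predict(X_demo)
    print("tree train MSE:", np.mean((preds - y_demo) ** 2))

    # As an NGBoost base learner
    try:
        from ngboost import NGBRegressor

        ngb = NGBRegressor(Base=NumbaDecisionTreeRegressor(max_depth=4),
                           n_estimators=100, verbose=False)
        ngb.fit(X_demo, y_demo)
        print("ngboost train MSE:", np.mean((ngb.predict(X_demo) - y_demo) ** 2))
    except ImportError:
        print("ngboost not installed; skipping the NGBoost example")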