@kmedved
Last active December 13, 2024 13:26
Faster DecisionTreeRegressor for use with NGBoost

import numpy as np
from numba import njit, prange
from sklearn.base import BaseEstimator, RegressorMixin


@njit(inline='always')
def weighted_variance_from_sums(sum_w, sum_wy, sum_wy_sq):
    # Weighted variance: var_w = (sum_wy_sq / sum_w) - (sum_wy / sum_w)**2
    if sum_w <= 1e-14:
        return 0.0
    mean_w = sum_wy / sum_w
    return (sum_wy_sq / sum_w) - mean_w * mean_w
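

# Illustrative sanity check (not part of the tree; the helper name below is
# hypothetical and exists only for this demonstration): the streaming identity
# E_w[y^2] - (E_w[y])^2 used above agrees with numpy's weighted variance.
def _demo_weighted_variance():
    w = np.array([1.0, 2.0, 3.0])
    y = np.array([0.5, 1.5, 2.5])
    mean = np.average(y, weights=w)
    var_np = np.average((y - mean) ** 2, weights=w)
    var_fast = weighted_variance_from_sums(
        w.sum(), (w * y).sum(), (w * y * y).sum())
    assert abs(var_np - var_fast) < 1e-12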


@njit(parallel=True)
def find_best_split(X, y, sample_weight, current_samples):
    n_samples = len(current_samples)
    n_features = X.shape[1]

    # Compute weighted sums for the current node
    sum_w = 0.0
    sum_wy = 0.0
    sum_wy_sq = 0.0
    for i in range(n_samples):
        idx = current_samples[i]
        w = sample_weight[idx]
        yw = y[idx]
        sum_w += w
        sum_wy += w * yw
        sum_wy_sq += w * yw * yw

    current_var = weighted_variance_from_sums(sum_w, sum_wy, sum_wy_sq)
    if current_var <= 1e-14:
        # No need to split
        return -1, 0.0, -np.inf

    # We'll store the best result found by each feature
    best_features = np.full(n_features, -1, dtype=np.int32)
    best_thresholds = np.full(n_features, 0.0, dtype=np.float64)
    best_improvements = np.full(n_features, -np.inf, dtype=np.float64)

    # Parallel loop over features
    for feature in prange(n_features):
        feature_values = X[current_samples, feature]
        sorted_idx = np.argsort(feature_values)

        # Precompute prefix sums for this feature
        prefix_w = np.zeros(n_samples + 1, dtype=np.float64)
        prefix_wy = np.zeros(n_samples + 1, dtype=np.float64)
        prefix_wy_sq = np.zeros(n_samples + 1, dtype=np.float64)
        for i in range(n_samples):
            idx = current_samples[sorted_idx[i]]
            w = sample_weight[idx]
            yw = y[idx]
            prefix_w[i + 1] = prefix_w[i] + w
            prefix_wy[i + 1] = prefix_wy[i] + w * yw
            prefix_wy_sq[i + 1] = prefix_wy_sq[i] + w * yw * yw

        sorted_x = feature_values[sorted_idx]
        local_best_improvement = -np.inf
        local_best_threshold = 0.0
        for i in range(1, n_samples):
            # Only consider a split if it separates distinct values
            if sorted_x[i] == sorted_x[i - 1]:
                continue
            sum_w_left = prefix_w[i]
            sum_wy_left = prefix_wy[i]
            sum_wy_sq_left = prefix_wy_sq[i]
            sum_w_right = sum_w - sum_w_left
            sum_wy_right = sum_wy - sum_wy_left
            sum_wy_sq_right = sum_wy_sq - sum_wy_sq_left
            var_left = weighted_variance_from_sums(sum_w_left, sum_wy_left, sum_wy_sq_left)
            var_right = weighted_variance_from_sums(sum_w_right, sum_wy_right, sum_wy_sq_right)
            # Compute weighted improvement
            improvement = current_var - ((sum_w_left * var_left + sum_w_right * var_right) / sum_w)
            if improvement > local_best_improvement:
                local_best_improvement = improvement
                local_best_threshold = 0.5 * (sorted_x[i - 1] + sorted_x[i])

        best_features[feature] = feature
        best_thresholds[feature] = local_best_threshold
        best_improvements[feature] = local_best_improvement

    # Reduce results from all features
    best_improvement = -np.inf
    best_feature = -1
    best_threshold = 0.0
    for f in range(n_features):
        if best_improvements[f] > best_improvement:
            best_improvement = best_improvements[f]
            best_threshold = best_thresholds[f]
            best_feature = best_features[f]

    return best_feature, best_threshold, best_improvement
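
# Cost note: for a node with n samples, each feature pays one O(n log n)
# argsort plus an O(n) prefix-sum pass, after which every candidate threshold
# is scored in O(1) as left = prefix[i], right = total - prefix[i]. The prange
# loop spreads the per-feature work across threads.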


@njit
def build_tree(X, y, sample_weight, max_depth, min_samples_split):
    """Build the decision tree using iterative DFS with weighted samples."""
    max_nodes = 2 ** (max_depth + 1) - 1
    n_samples = X.shape[0]

    # Tree structure arrays
    is_leaf = np.zeros(max_nodes, dtype=np.int32)
    features = np.full(max_nodes, -1, dtype=np.int32)
    thresholds = np.full(max_nodes, np.inf, dtype=np.float64)
    values = np.zeros(max_nodes, dtype=np.float64)
    left_children = np.full(max_nodes, -1, dtype=np.int32)
    right_children = np.full(max_nodes, -1, dtype=np.int32)

    # Arrays for node stack
    node_stack = np.zeros(max_nodes, dtype=np.int32)
    sample_start = np.zeros(max_nodes, dtype=np.int32)
    sample_end = np.zeros(max_nodes, dtype=np.int32)
    depth_stack = np.zeros(max_nodes, dtype=np.int32)

    # Sample indices array
    sample_indices = np.arange(n_samples)

    # Initialize root
    stack_size = 1
    node_stack[0] = 0
    sample_start[0] = 0
    sample_end[0] = n_samples
    depth_stack[0] = 0

    while stack_size > 0:
        # Pop from stack
        stack_size -= 1
        node_idx = node_stack[stack_size]
        start_idx = sample_start[stack_size]
        end_idx = sample_end[stack_size]
        depth = depth_stack[stack_size]
        n_node_samples = end_idx - start_idx
        current_samples = sample_indices[start_idx:end_idx]

        # Compute weighted sums at node
        sum_w = 0.0
        sum_wy = 0.0
        for i in range(n_node_samples):
            idx = current_samples[i]
            w = sample_weight[idx]
            yw = y[idx]
            sum_w += w
            sum_wy += w * yw

        # Check stopping criteria
        if (depth >= max_depth) or (n_node_samples < min_samples_split):
            # Leaf node prediction = weighted mean
            values[node_idx] = sum_wy / sum_w if sum_w > 0 else 0.0
            is_leaf[node_idx] = 1
            continue

        # Find best split
        feature, threshold, improvement = find_best_split(X, y, sample_weight, current_samples)
        if feature == -1 or improvement <= 0:
            # No split
            values[node_idx] = sum_wy / sum_w if sum_w > 0 else 0.0
            is_leaf[node_idx] = 1
            continue

        # Partition samples
        features[node_idx] = feature
        thresholds[node_idx] = threshold
        partition_pos = start_idx
        for i in range(start_idx, end_idx):
            if X[sample_indices[i], feature] <= threshold:
                tmp = sample_indices[i]
                sample_indices[i] = sample_indices[partition_pos]
                sample_indices[partition_pos] = tmp
                partition_pos += 1

        left_idx = 2 * node_idx + 1
        right_idx = 2 * node_idx + 2
        left_children[node_idx] = left_idx
        right_children[node_idx] = right_idx

        # Push right child
        node_stack[stack_size] = right_idx
        sample_start[stack_size] = partition_pos
        sample_end[stack_size] = end_idx
        depth_stack[stack_size] = depth + 1
        stack_size += 1

        # Push left child
        node_stack[stack_size] = left_idx
        sample_start[stack_size] = start_idx
        sample_end[stack_size] = partition_pos
        depth_stack[stack_size] = depth + 1
        stack_size += 1

    return is_leaf, features, thresholds, values, left_children, right_children
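
# Layout note: nodes live at implicit heap positions, with the children of
# node i at 2*i + 1 and 2*i + 2, so the whole tree fits in flat preallocated
# arrays of size 2**(max_depth + 1) - 1 and needs no pointer chasing.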


@njit(parallel=True)
def predict_all(X, is_leaf, features, thresholds, values, left_children, right_children):
    """Make predictions for all samples in parallel."""
    n_samples = X.shape[0]
    predictions = np.empty(n_samples, dtype=np.float64)
    for i in prange(n_samples):
        node_idx = 0
        while not is_leaf[node_idx]:
            if X[i, features[node_idx]] <= thresholds[node_idx]:
                node_idx = left_children[node_idx]
            else:
                node_idx = right_children[node_idx]
        predictions[i] = values[node_idx]
    return predictions
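
# Note: the first call to predict_all (and to the other @njit functions) pays
# a one-time JIT compilation cost; subsequent calls run the compiled code.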


class NumbaDecisionTreeRegressor(BaseEstimator, RegressorMixin):
    """A decision tree regressor implemented with Numba, using a prefix-sum
    split search and supporting sample weights."""

    def __init__(self, max_depth=5, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.is_leaf = None
        self.features = None
        self.thresholds = None
        self.values = None
        self.left_children = None
        self.right_children = None

    def fit(self, X, y, sample_weight=None):
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        if sample_weight is None:
            # If no sample_weight is given, use equal weights.
            sample_weight = np.ones(len(y), dtype=np.float64)
        else:
            sample_weight = np.asarray(sample_weight, dtype=np.float64)
        (self.is_leaf,
         self.features,
         self.thresholds,
         self.values,
         self.left_children,
         self.right_children) = build_tree(X, y, sample_weight,
                                           self.max_depth, self.min_samples_split)
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=np.float64)
        return predict_all(X, self.is_leaf, self.features, self.thresholds, self.values,
                           self.left_children, self.right_children)
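

# Minimal usage sketch: the class follows the sklearn fit/predict contract and
# accepts sample_weight, so it can be passed to NGBoost's Base parameter in
# place of the default sklearn tree. The data below is synthetic and purely
# illustrative; assumes the `ngboost` package is installed.
if __name__ == "__main__":
    from ngboost import NGBRegressor

    rng = np.random.default_rng(0)
    X_demo = rng.normal(size=(1000, 5))
    y_demo = X_demo[:, 0] - 2.0 * X_demo[:, 1] + 0.1 * rng.normal(size=1000)

    # Standalone tree
    tree = NumbaDecisionTreeRegressor(max_depth=5).fit(X_demo, y_demo)
    print("tree preds:", tree.predict(X_demo[:3]))

    # As the NGBoost base learner
    ngb = NGBRegressor(Base=NumbaDecisionTreeRegressor(max_depth=4), n_estimators=100)
    ngb.fit(X_demo, y_demo)
    print("ngboost preds:", ngb.predict(X_demo[:3]))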