Skip to content

Instantly share code, notes, and snippets.

@jnothman
Created October 19, 2015 04:45
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jnothman/566ebde618ec18f2bea6 to your computer and use it in GitHub Desktop.
Save jnothman/566ebde618ec18f2bea6 to your computer and use it in GitHub Desktop.
Generic scikit-learn estimator to cluster data and build predictive models for each cluster.
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.utils import safe_mask
class ModelByCluster(BaseEstimator):
    """Cluster the data, then fit one predictive model per cluster.

    Parameters
    ----------
    clusterer : estimator
        Unfitted clusterer exposing ``fit_predict`` and ``predict``.
        A clone of it is fitted in :meth:`fit`.
    estimator : estimator
        Unfitted predictor exposing ``fit`` and ``predict``. One clone
        is fitted for every cluster found in the training data.

    Attributes
    ----------
    clusterer_ : estimator
        The fitted clone of ``clusterer``.
    cluster_labels_ : ndarray
        Sorted unique cluster labels observed during :meth:`fit`.
    estimators_ : list
        Fitted clones of ``estimator``; ``estimators_[i]`` was trained on
        the samples assigned to ``cluster_labels_[i]``.
    """

    def __init__(self, clusterer, estimator):
        self.clusterer = clusterer
        self.estimator = estimator

    def fit(self, X, y):
        """Fit the clusterer on ``X``, then one estimator per cluster.

        Parameters
        ----------
        X : indexable supporting boolean-mask row selection
            Training samples.
        y : indexable supporting boolean-mask selection
            Targets aligned with the rows of ``X``.

        Returns
        -------
        self
        """
        self.clusterer_ = clone(self.clusterer)
        clusters = np.asarray(self.clusterer_.fit_predict(X))
        # Use the labels that actually occur rather than assuming they are
        # exactly 0..k-1: this copes with empty clusters and with
        # non-contiguous label sets.
        self.cluster_labels_ = np.unique(clusters)
        self.estimators_ = []
        for label in self.cluster_labels_:
            mask = clusters == label
            est = clone(self.estimator)
            est.fit(X[safe_mask(X, mask)], y[safe_mask(y, mask)])
            self.estimators_.append(est)
        return self

    def predict(self, X):
        """Predict each sample with the estimator of its assigned cluster.

        Parameters
        ----------
        X : indexable supporting boolean-mask row selection
            Samples to predict.

        Returns
        -------
        y : ndarray
            Predictions in the same row order as ``X``.

        Raises
        ------
        ValueError
            If the clusterer assigns a sample to a cluster for which no
            estimator was fitted.
        """
        clusters = np.asarray(self.clusterer_.predict(X))
        covered = np.zeros(clusters.shape, dtype=bool)
        y_tmp = []
        idx = []
        for label, est in zip(self.cluster_labels_, self.estimators_):
            mask = clusters == label
            if not mask.any():
                # No sample landed in this cluster; do not ask the
                # estimator to predict on an empty subset.
                continue
            covered |= mask
            idx.append(np.flatnonzero(mask))
            y_tmp.append(est.predict(X[safe_mask(X, mask)]))
        if not covered.all():
            unseen = np.unique(clusters[~covered])
            raise ValueError(
                "Samples were assigned to clusters with no fitted "
                "estimator: %r" % (unseen.tolist(),)
            )
        y_tmp = np.concatenate(y_tmp)
        idx = np.concatenate(idx)
        # Predictions were gathered cluster by cluster; scatter them back
        # to the original sample order.
        y = np.empty_like(y_tmp)
        y[idx] = y_tmp
        return y
@jnothman
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment