Skip to content

Instantly share code, notes, and snippets.

@arose13
Created August 20, 2019 18:50
Show Gist options
  • Save arose13/362f5a1ca95ce475ef45a9b61e18c5ea to your computer and use it in GitHub Desktop.
Save arose13/362f5a1ca95ce475ef45a9b61e18c5ea to your computer and use it in GitHub Desktop.
Computing the mean of a particular model, conditional on some categorical variable
import pandas as pd
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.preprocessing import OneHotEncoder
from sklearn.exceptions import NotFittedError
class StratifiedDummyRegressor(BaseEstimator, RegressorMixin):
    """
    An extremely scalable dummy regression model for computing the mean for each group specified by a column.

    The grouping column is one-hot encoded and the sparse least-squares problem
    ``y = X m`` is solved with ``scipy.sparse.linalg.lsqr``; the resulting
    coefficient vector holds one value per category, which is what ``predict``
    returns for every row belonging to that category.

    Author's reported benchmark:
    Single core 3.4Ghz
    (+1e8 rows, +1e4 cardinality) in < 1 minute

    Attributes set by ``fit``:
        coef_: solution vector returned by ``lsqr`` (one entry per one-hot column).
        results_: dict of the ``lsqr`` diagnostics keyed by name
            (``istop``, ``itn``, ``r1norm``, ``r2norm``, ``anorm``, ``acond``,
            ``arnorm``, ``xnorm``).
    """

    def __init__(self, stratified_col=None):
        """
        :param stratified_col: name (str) of the DataFrame column whose values define the groups.
        """
        self.stratified_col = stratified_col  # type: str
        self.preprocessor = OneHotEncoder(categories='auto')
        # Both stay None until fit() succeeds; predict() uses coef_ as the "is fitted" flag.
        self.coef_ = None
        self.results_ = None

    def _solve_means(self, x_one_hot, y):
        """
        Sparse solution for m that minimizes the difference from y = Xm
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html

        Stores the solution in ``self.coef_`` and the solver diagnostics in ``self.results_``.

        :param x_one_hot: one-hot encoded design matrix produced by the preprocessor.
        :param y: target values, one per row of ``x_one_hot``.
        """
        from scipy.sparse.linalg import lsqr

        diagnostic_names = ['istop', 'itn', 'r1norm', 'r2norm', 'anorm', 'acond', 'arnorm', 'xnorm']

        # lsqr returns the solution first, followed by its diagnostics. Unpack into locals so
        # results_ is never left holding a raw list if something fails half way through.
        solution, *diagnostics = lsqr(x_one_hot, y)

        self.coef_ = solution
        # zip() stops at the shortest input, so any trailing value lsqr returns beyond the
        # eight named diagnostics is intentionally not stored (same as before).
        self.results_ = dict(zip(diagnostic_names, diagnostics))

    def fit(self, X, y):
        """
        Fit the encoder on the stratification column and solve for the per-category coefficients.

        :param X: pd.DataFrame containing the column named by ``stratified_col``.
        :param y: target values aligned with the rows of ``X``.
        :return: self
        :raises NotImplementedError: if ``X`` is not a pd.DataFrame.
        """
        if not isinstance(X, pd.DataFrame):
            raise NotImplementedError('X must be a pd.DataFrame with named cols')

        # Double brackets keep the selection 2-D (n_rows, 1), the shape the encoder is fed here.
        x = self.preprocessor.fit_transform(X[[self.stratified_col]].values)
        self._solve_means(x, y)
        return self

    def predict(self, X):
        """
        Return the fitted coefficient of each row's category.

        :param X: pd.DataFrame containing the column named by ``stratified_col``.
        :return: result of ``one_hot(X) @ coef_`` (one prediction per row).
        :raises NotFittedError: if called before ``fit``.
        :raises NotImplementedError: if ``X`` is not a pd.DataFrame (same contract as ``fit``).
        """
        if self.coef_ is None:
            raise NotFittedError('Call fit() first')
        if not isinstance(X, pd.DataFrame):
            # Mirror fit(): fail with the same explicit message rather than an obscure
            # indexing error from the column lookup below.
            raise NotImplementedError('X must be a pd.DataFrame with named cols')

        # NOTE(review): the encoder is built with default unknown-category handling, so
        # categories not seen during fit are presumably rejected by transform - confirm.
        x = self.preprocessor.transform(X[[self.stratified_col]].values)
        return x @ self.coef_

    def score(self, X, y, sample_weight=None):
        """
        Return the R^2 of the predictions for ``X`` against ``y``.

        :param X: pd.DataFrame containing the column named by ``stratified_col``.
        :param y: true target values.
        :param sample_weight: optional per-sample weights forwarded to ``r2_score``.
        :return: R^2 score (float).
        """
        from sklearn.metrics import r2_score

        # Forward sample_weight; previously it was accepted but silently ignored.
        return r2_score(y, self.predict(X), sample_weight=sample_weight)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment