Skip to content

Instantly share code, notes, and snippets.

@conormm
Last active April 12, 2021 07:39
Show Gist options
  • Save conormm/0808922aa0f9b6581277340a5e2377f1 to your computer and use it in GitHub Desktop.
Save conormm/0808922aa0f9b6581277340a5e2377f1 to your computer and use it in GitHub Desktop.
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder
class TreeEmbeddingLogisticRegression(BaseEstimator, ClassifierMixin):
    """Fit a logistic regression model on gradient-boosted tree embeddings.

    A ``GradientBoostingClassifier`` is fitted first. Each sample is then
    represented by the leaf index it reaches in every tree (``gbm.apply``).
    Those indices are one-hot encoded, and an L1-penalised
    ``LogisticRegression`` (``liblinear`` solver) is fitted on the encoding.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to ``GradientBoostingClassifier``. They are also
        what ``get_params`` / ``set_params`` expose, so ``sklearn.base.clone``
        (and therefore ``cross_val_score`` / ``GridSearchCV``) keeps them.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.gbm = GradientBoostingClassifier(**kwargs)
        self.lr = LogisticRegression(penalty="l1", solver="liblinear")
        self.bin = OneHotEncoder()

    def get_params(self, deep=True):
        """Return the keyword arguments forwarded to the boosting model.

        ``BaseEstimator.get_params`` ignores ``**kwargs`` in ``__init__``.
        Without this override, ``clone`` silently fell back to a default
        ``GradientBoostingClassifier``. ``deep`` is accepted for API
        compatibility; nested estimators are not exposed.
        """
        return dict(self.kwargs)

    def set_params(self, **params):
        """Update the boosting-model keyword arguments and return ``self``."""
        # Let the GBM validate the names first, so an invalid key raises
        # before self.kwargs is modified.
        self.gbm.set_params(**params)
        self.kwargs.update(params)
        return self

    def _embed(self, X, fit=False):
        """Return the one-hot encoding of X's per-tree leaf indices."""
        leaves = self.gbm.apply(X)
        # Flatten apply()'s output to 2-D so each tree column becomes one
        # categorical feature. Use the result's own row count rather than
        # X.shape, so plain lists / array-likes are accepted too.
        leaves = leaves.reshape(leaves.shape[0], -1)
        if fit:
            return self.bin.fit_transform(leaves)
        return self.bin.transform(leaves)

    def fit(self, X, y=None):
        """Fit the boosting model, the leaf encoder and the logistic regression.

        Parameters
        ----------
        X : array-like
            Training features.
        y : array-like
            Target labels (required despite the ``None`` default, which is
            kept for backward compatibility).

        Returns
        -------
        self
        """
        self.gbm.fit(X, y)
        X_emb = self._embed(X, fit=True)
        self.lr.fit(X_emb, y)
        # Expose classes_ as sklearn classifiers conventionally do; some
        # sklearn utilities read it from the fitted estimator.
        self.classes_ = self.lr.classes_
        return self

    def predict(self, X, y=None, with_tree=False):
        """Predict class labels for X.

        Parameters
        ----------
        X : array-like
            Samples to predict.
        y : ignored
            Unused; kept for backward compatibility.
        with_tree : bool, default False
            If True, predict with the boosting model directly instead of the
            logistic regression on tree embeddings.
        """
        if with_tree:
            return self.gbm.predict(X)
        return self.lr.predict(self._embed(X))

    def predict_proba(self, X, y=None, with_tree=False):
        """Predict class probabilities for X.

        Parameters
        ----------
        X : array-like
            Samples to predict.
        y : ignored
            Unused; kept for backward compatibility.
        with_tree : bool, default False
            If True, use the boosting model's probabilities instead of the
            logistic regression on tree embeddings.
        """
        if with_tree:
            return self.gbm.predict_proba(X)
        return self.lr.predict_proba(self._embed(X))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment