Created April 21, 2021 14:10
-
-
Save gvyshnya/b825471dd9c92c6cd4548410ba7fbd2b to your computer and use it in GitHub Desktop.
Custom ensemble classifier class built on top of LightGBM, XGBoost, and CatBoost.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EnsembleModel:
    """Weighted-average ensemble of LightGBM, XGBoost and CatBoost classifiers.

    The three underlying estimators are always fitted and queried together,
    in the order LightGBM, XGBoost, CatBoost.

    NOTE(review): ``LGBMClassifier``, ``XGBClassifier``, ``CatBoostClassifier``
    and ``np`` are not defined in this snippet; the module is expected to
    import them (lightgbm, xgboost, catboost, numpy) at top level.
    """

    def __init__(self, params):
        """
        LGB + XGB + CatBoost model

        :param params: mapping with keys 'lgb', 'xgb' and 'cat'; each value is
            unpacked as keyword arguments into LGBMClassifier, XGBClassifier
            and CatBoostClassifier respectively
        """
        self.lgb_params = params['lgb']
        self.xgb_params = params['xgb']
        self.cat_params = params['cat']
        self.lgb_model = LGBMClassifier(**self.lgb_params)
        self.xgb_model = XGBClassifier(**self.xgb_params)
        self.cat_model = CatBoostClassifier(**self.cat_params)

    def fit(self, x, y, *args, **kwargs):
        """
        Fit all three models on the same training data.

        Extra positional and keyword arguments are forwarded unchanged to every
        model's ``fit``, so they must be accepted by all three libraries.

        :param x: training features
        :param y: training labels
        :return: tuple of the values returned by the LightGBM, XGBoost and
            CatBoost ``fit`` calls, in that order
        """
        return (self.lgb_model.fit(x, y, *args, **kwargs),
                self.xgb_model.fit(x, y, *args, **kwargs),
                self.cat_model.fit(x, y, *args, **kwargs))

    @staticmethod
    def _weight_total(weights):
        """Validate the (lgb, xgb, cat) weights and return their sum."""
        if len(weights) != 3:
            raise ValueError(
                f"expected 3 weights (lgb, xgb, cat), got {len(weights)}")
        total = float(sum(weights))
        if total == 0:
            # Normalising by a zero sum would silently produce inf/nan.
            raise ValueError("weights must not sum to zero")
        return total

    def predict(self, x, weights=(1.0, 1.0, 1.0)):
        """
        Generate model predictions

        The labels predicted by each model are combined as a weighted average
        (normalised by the sum of the weights) and rounded to the nearest
        integer with ``np.rint`` (exact .5 ties round to the even value).
        Because the labels are multiplied and averaged, they must be numeric;
        the result is a weighted majority vote only for 0/1 labels.

        NOTE(review): assumes all three ``predict`` calls return arrays of the
        same shape (e.g. ``(n_samples,)``) — confirm for the CatBoost version
        in use, since mismatched shapes would broadcast silently.

        :param x: data
        :param weights: relative weights on model predictions, in the order
            (lgb, xgb, cat); must have three entries and a non-zero sum
        :return: array with predictions (float dtype, integral values)
        :raises ValueError: if ``weights`` does not have exactly three entries
            or sums to zero
        """
        total = self._weight_total(weights)
        blended = (weights[0] * self.lgb_model.predict(x) +
                   weights[1] * self.xgb_model.predict(x) +
                   weights[2] * self.cat_model.predict(x))
        # Divide by the weight sum (not a fixed 3) so weights act as relative
        # importances, e.g. (1, 0, 0) reproduces the LightGBM labels.
        return np.rint(blended / total)

    def predict_proba(self, x, weights=(1.0, 1.0, 1.0)):
        """
        Generate model class label probability predictions

        Returns the weighted average (normalised by the sum of the weights) of
        the three models' ``predict_proba`` outputs. No rounding is applied,
        so the result stays a probability matrix: with non-negative weights,
        each row sums to 1 whenever each model's rows do.

        NOTE(review): assumes the three models order their probability columns
        identically (e.g. by sorted class label) — verify before relying on it.

        :param x: data
        :param weights: relative weights on model predictions, in the order
            (lgb, xgb, cat); must have three entries and a non-zero sum
        :return: array with blended class probabilities
        :raises ValueError: if ``weights`` does not have exactly three entries
            or sums to zero
        """
        total = self._weight_total(weights)
        blended = (weights[0] * self.lgb_model.predict_proba(x) +
                   weights[1] * self.xgb_model.predict_proba(x) +
                   weights[2] * self.cat_model.predict_proba(x))
        return blended / total
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment