Created July 8, 2022 10:15
-
-
Save Mr-Geekman/414bde90bc81bc3524d08285593b8684 to your computer and use it in GitHub Desktop.
[ETNA] Binary classification example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
from typing import Optional | |
import numpy as np | |
import pandas as pd | |
from etna.models.base import BaseAdapter | |
from etna.models.base import MultiSegmentModel | |
from sklearn.base import ClassifierMixin | |
class _SklearnBinaryProbAdapter(BaseAdapter):
    """Adapter exposing a scikit-learn classifier through ETNA's ``BaseAdapter`` API.

    ``fit`` trains the wrapped classifier on the regressor columns of the given
    dataframe, using its ``"target"`` column as the label. ``predict`` returns
    column 1 of ``predict_proba``, which scikit-learn orders by
    ``classifier.classes_``. With ``0``/``1`` labels, that is the probability
    of class ``1``.
    """

    def __init__(self, classifier: ClassifierMixin):
        """Store the classifier to wrap.

        Parameters
        ----------
        classifier:
            scikit-learn classifier. It must implement ``predict_proba``.
        """
        self.model = classifier
        # Set by a successful ``fit``; ``None`` means the adapter is not fitted yet.
        self.regressor_columns: Optional[List[str]] = None

    @staticmethod
    def _prepare_features(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
        """Select ``columns`` from ``df`` and convert every one of them to numeric.

        Raises
        ------
        ValueError:
            If some selected column cannot be converted by ``pd.to_numeric``.
        """
        try:
            return df[columns].apply(pd.to_numeric)
        except ValueError as err:
            # Chain the pandas error so the offending value stays visible in the traceback.
            raise ValueError("Only convertible to numeric features are accepted!") from err

    def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_SklearnBinaryProbAdapter":
        """Fit the classifier on ``df[regressors]`` with ``df["target"]`` as the label.

        Parameters
        ----------
        df:
            Dataframe with the regressor columns and a ``"target"`` column.
        regressors:
            Names of the columns used as features.

        Returns
        -------
        :
            The fitted adapter (``self``).

        Raises
        ------
        ValueError:
            If some regressor column is not convertible to numeric.
        """
        # Copy so that later mutation of the caller's list cannot change the fitted schema.
        columns = list(regressors)
        features = self._prepare_features(df, columns)
        self.model.fit(features, df["target"])
        # Record the columns only after training succeeded, so a failed fit
        # does not leave the adapter looking fitted.
        self.regressor_columns = columns
        return self

    def predict(self, df: pd.DataFrame) -> np.ndarray:
        """Return ``predict_proba`` column 1 for every row of ``df``.

        Parameters
        ----------
        df:
            Dataframe containing the regressor columns seen during ``fit``.

        Returns
        -------
        :
            One-dimensional array with one probability per row of ``df``.

        Raises
        ------
        ValueError:
            If the adapter has not been fitted, or a feature is not numeric-convertible.
        """
        if self.regressor_columns is None:
            raise ValueError("Model is not fitted! Call fit before predict.")
        features = self._prepare_features(df, self.regressor_columns)
        # Column 1 corresponds to ``self.model.classes_[1]``. This needs at least two
        # classes seen during fit; otherwise predict_proba has a single column.
        return self.model.predict_proba(features)[:, 1]

    def get_model(self) -> ClassifierMixin:
        """Return the wrapped scikit-learn classifier."""
        return self.model
class SklearnMultiSegmentBinaryProbModel(MultiSegmentModel):
    """ETNA ``MultiSegmentModel`` backed by a scikit-learn binary classifier.

    The classifier is wrapped in ``_SklearnBinaryProbAdapter`` and passed to
    ``MultiSegmentModel`` as its base model. The forecasts are therefore the
    adapter's ``predict_proba`` column-1 values, not class labels.

    NOTE(review): per the ETNA ``MultiSegmentModel`` design, a single classifier
    is presumably fitted on the data of all segments together. Confirm against
    the ETNA version in use.
    """

    def __init__(self, classifier: ClassifierMixin):
        """Wrap ``classifier`` and register it as the shared base model.

        Parameters
        ----------
        classifier:
            scikit-learn classifier implementing ``predict_proba``.
        """
        adapter = _SklearnBinaryProbAdapter(classifier=classifier)
        super().__init__(base_model=adapter)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.