Skip to content

Instantly share code, notes, and snippets.

@beans15
Last active January 22, 2020 13:11
Show Gist options
  • Save beans15/e5d17bf80117917ee4f9e887e99ffe5e to your computer and use it in GitHub Desktop.
コルーチンを使ったscikit-learn Estimatorのアダプター (An adapter that wraps a coroutine as a scikit-learn estimator)
from dataclasses import dataclass
from typing import Any, Callable, Generator

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.exceptions import NotFittedError
@dataclass
class SimpleEstimator(BaseEstimator):
    """Wrap a training coroutine as a scikit-learn estimator.

    ``func(x, y)`` is called by :meth:`fit` and must return a generator that
    trains on ``(x, y)``, yields one intermediate value (for example the
    fitted model), and then, each time a feature matrix is sent into it,
    yields the prediction for that matrix.

    Attributes:
        func: Generator function implementing train-then-predict.
        intermediate_value: First value yielded by ``func`` during the most
            recent successful :meth:`fit`; ``None`` before fitting.
    """

    func: Callable[[np.ndarray, np.ndarray], Generator[Any, np.ndarray, np.ndarray]]
    # Deliberately un-annotated: dataclass therefore does not add these to
    # __init__, so ``func`` stays the only constructor parameter.
    intermediate_value = None
    _gen = None  # running coroutine created by fit(); None until fitted

    def fit(self, x, y):
        """Start a fresh coroutine on ``(x, y)`` and run it to its first yield.

        Args:
            x: Training features, passed straight to ``func``.
            y: Training targets, passed straight to ``func``.

        Returns:
            self, so calls can be chained.

        Raises:
            RuntimeError: if the generator finishes without yielding.
        """
        gen = self.func(x, y)
        try:
            first = next(gen)
        except StopIteration as err:
            # A bare StopIteration escaping a method is easily mistaken for
            # normal end-of-iteration by callers; report a clear error.
            raise RuntimeError(
                f"{type(self).__name__}.func must yield once after training"
            ) from err
        # Only commit state once the coroutine has been primed successfully.
        self._gen = gen
        self.intermediate_value = first
        return self

    def predict(self, x, y=None):
        """Send ``x`` to the fitted coroutine and return what it yields.

        Args:
            x: Features to predict on.
            y: Unused.

        Raises:
            NotFittedError: if called before :meth:`fit`.
            RuntimeError: if the coroutine has stopped yielding predictions.
        """
        if self._gen is None:
            raise NotFittedError(
                f"This {type(self).__name__} instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )
        try:
            return self._gen.send(x)
        except StopIteration as err:
            raise RuntimeError(
                f"{type(self).__name__}.func stopped instead of yielding a prediction"
            ) from err

    def __sklearn_is_fitted__(self):
        """Tell ``sklearn.utils.validation.check_is_fitted`` whether fit ran."""
        return self._gen is not None
class SimpleClassifier(ClassifierMixin, SimpleEstimator):
    """Classifier variant of SimpleEstimator.

    Scoring and the classifier estimator type come from ClassifierMixin.
    """

    # Mixin listed first: scikit-learn requires mixins to the left of
    # BaseEstimator so that their tags/score win in the MRO.
    pass
class SimpleRegressor(RegressorMixin, SimpleEstimator):
    """Regressor variant of SimpleEstimator.

    Scoring and the regressor estimator type come from RegressorMixin.
    """

    # Mixin listed first: scikit-learn requires mixins to the left of
    # BaseEstimator so that their tags/score win in the MRO.
    pass
### Usage ###
if __name__ == "__main__":
    # Run the demo only when executed as a script, so importing this module
    # for the adapter classes does not import statsmodels or fit a model.
    import statsmodels.api as sm

    @SimpleRegressor
    def estimator(x, y):
        """Fit an L1-penalised Poisson GLM, then answer prediction requests."""
        family = sm.families.Poisson()
        m = sm.GLM(y, x, family)
        # L1_wt=1.0 puts the whole penalty on the L1 term (lasso-style).
        result = m.fit_regularized(L1_wt=1.0, alpha=1.0)
        # First yield: fit() stores this as estimator.intermediate_value.
        x_val = yield result
        # Each predict(x) sends x in and receives the model's prediction.
        while True:
            x_val = yield result.predict(x_val)

    # Placeholder data so the example runs as-is: synthetic Poisson counts.
    # Replace with your own train/test split.
    rng = np.random.default_rng(0)
    x = rng.normal(size=(200, 3))
    y = rng.poisson(np.exp(x @ np.array([0.5, -0.3, 0.2])))
    x_train, x_test = x[:150], x[150:]
    y_train, y_test = y[:150], y[150:]

    estimator.fit(x_train, y_train)
    print(f"Train score: {estimator.score(x_train, y_train):.4f}")
    print(f"Test score: {estimator.score(x_test, y_test):.4f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment