Skip to content

Instantly share code, notes, and snippets.

@ResidentMario
Created March 14, 2019 22:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ResidentMario/36a49a132a74465047948d4ac66077e1 to your computer and use it in GitHub Desktop.
import numpy as np
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.base import BaseEstimator
class KerasBatchClassifier(KerasClassifier, BaseEstimator):
    """Scikit-learn style Keras classifier trained from a batch generator.

    Instead of fitting on in-memory ``X``/``y`` arrays, :meth:`fit` trains a
    freshly built model with ``model.fit_generator`` on the generator passed
    as the ``train_generator`` keyword argument at construction time.

    NOTE(review): ``sklearn.base.clone`` rebuilds estimators from
    ``get_params()``. This class does not override ``get_params``, so the
    inherited Keras implementation is used. That implementation presumably
    reports the builder under ``build_fn`` (not ``model``) and does not know
    about ``fit_kwargs``, which would break cloning (e.g. in ``GridSearchCV``).
    Verify before relying on it.
    """

    def __init__(self, model, **kwargs):
        """Store the model builder and the generator-fitting options.

        Args:
            model: Callable that builds and returns a compiled Keras model.
                It is handed to the Keras wrapper, which exposes it as
                ``self.build_fn``.
            **kwargs: Keyword arguments for ``model.fit_generator``. Must
                include ``train_generator``, the training batch generator.
                That generator must expose a ``class_indices`` mapping.
        """
        super().__init__(model)
        self.fit_kwargs = kwargs
        self._estimator_type = 'classifier'

    def fit(self, *args, **kwargs):
        """Build a fresh model and train it on ``train_generator``.

        Positional and keyword arguments (e.g. the ``X``/``y`` that
        scikit-learn passes) are accepted for API compatibility but are
        ignored. All training data comes from the generator.

        Returns:
            The value returned by ``model.fit_generator`` (the training
            history).

        Raises:
            KeyError: If ``train_generator`` was not supplied to ``__init__``.
        """
        # Pop from a per-call copy so that self.fit_kwargs is left intact.
        # Popping from it directly made every fit() after the first fail
        # with KeyError (refits, cross-validation folds, grid search).
        fit_kwargs = dict(self.fit_kwargs)
        train_generator = fit_kwargs.pop('train_generator')

        # taken from keras.wrappers.scikit_learn.KerasClassifier.fit
        self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
        # Class labels are the integer indices 0..n_classes-1, one per entry
        # in the generator's class_indices mapping.
        self.classes_ = np.arange(len(train_generator.class_indices))
        self.__history = self.model.fit_generator(train_generator, **fit_kwargs)
        return self.__history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment