@data-hound
Last active January 22, 2021 17:55
SciKeras Tutorial - 4: Input Transformer and MIMOEstimator
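```python
def input_reshaper(X):
    # Each row of X holds 784 pixel values followed by 10 extra columns.
    # Returns one array per model input: an (n, 28, 28, 1) image tensor
    # and the (n, 10) block made of the last 10 columns.
    return [X[:, :-10].reshape(X.shape[0], 28, 28, 1), X[:, -10:]]
```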
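To check what the feature encoder hands to Keras, the sketch below runs `input_reshaper` through scikit-learn's `FunctionTransformer` on a toy batch. It assumes, as the slicing implies, that each row of `X` is a flattened 28x28 image (784 values) followed by 10 one-hot label columns; the random pixels and chosen labels are illustrative only.

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


def input_reshaper(X):
    # Same logic as above: image part first, last-10-columns part second.
    return [X[:, :-10].reshape(X.shape[0], 28, 28, 1), X[:, -10:]]


# Toy batch: 4 rows of 784 pixels + 10 one-hot label columns (assumed layout).
rng = np.random.default_rng(0)
pixels = rng.random((4, 784))
labels = np.eye(10)[[3, 1, 4, 1]]
X = np.hstack([pixels, labels])

# FunctionTransformer is stateless, so fit_transform just applies the function.
encoder = FunctionTransformer(func=input_reshaper)
images, label_input = encoder.fit_transform(X)

print(images.shape)       # (4, 28, 28, 1)
print(label_input.shape)  # (4, 10)
```

Because the transformer returns a list, each element is fed to the matching input layer of a two-input Keras model.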
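```python
from scikeras.wrappers import BaseWrapper
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import FunctionTransformer

# MultiOutputTransformer is not defined in this snippet; it must be
# available from elsewhere in the tutorial before this class is used.


class MIMOEstimator(BaseWrapper):
    @property
    def feature_encoder(self):
        # Splits and reshapes X into the list of inputs the Keras model expects.
        return FunctionTransformer(func=input_reshaper)

    @staticmethod
    def scorer(y_true, y_pred, **kwargs) -> float:
        # SciKeras calls scorer(y_true, y_pred, ...) positionally.
        # First 10 columns: class ("caps") block; the rest: reconstruction.
        y_caps, y_recons = y_true[:, :10], y_true[:, 10:]
        y_pred_caps, y_pred_recons = y_pred[:, :10], y_pred[:, 10:]
        # Only classification accuracy is scored; reconstructions are unused.
        return accuracy_score(y_caps, y_pred_caps)

    @property
    def target_encoder(self):
        # Splits y into one target per model output (and back again in predict).
        return MultiOutputTransformer()
```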
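The scorer only looks at the first 10 columns of each target row. The sketch below re-implements it as a plain function (`caps_accuracy` is a hypothetical name) so it can be tried without building a model; the layout of 10 one-hot class columns followed by 784 reconstruction columns is an assumption read off the slicing. Note that `accuracy_score` applied to one-hot (indicator) arrays computes subset accuracy: a row counts as correct only if all 10 columns match.

```python
import numpy as np
from sklearn.metrics import accuracy_score


def caps_accuracy(y_true, y_pred, **kwargs) -> float:
    # Hypothetical stand-in for MIMOEstimator.scorer: the first 10 columns are
    # the one-hot class block, everything after is the reconstruction block.
    y_true_caps = y_true[:, :10]
    y_pred_caps = y_pred[:, :10]
    # Reconstruction columns are ignored; only classification is scored.
    return accuracy_score(y_true_caps, y_pred_caps)


rng = np.random.default_rng(0)
true_classes = np.array([0, 5, 9])
pred_classes = np.array([0, 5, 2])  # last sample misclassified

# Different random "reconstructions" show that these columns do not matter.
y_true = np.hstack([np.eye(10, dtype=int)[true_classes], rng.random((3, 784))])
y_pred = np.hstack([np.eye(10, dtype=int)[pred_classes], rng.random((3, 784))])

print(caps_accuracy(y_true, y_pred))  # 2 of 3 rows match
```

When every row is exactly one-hot, this subset accuracy equals ordinary accuracy on the `argmax` class indices.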
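`MultiOutputTransformer` is not defined in this snippet. For orientation only, here is a hypothetical transformer (`CapsReconTargetTransformer` is an invented name, not the tutorial's class and not a SciKeras API) showing a split/concatenate pattern under which the scorer's column slicing lines up: `transform` splits `y` into the two Keras targets, and `inverse_transform` turns the class scores back into one-hot rows and stacks them in front of the reconstructions. The 10/784 column split is assumed, and any SciKeras-specific hooks the real class may define are left out.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class CapsReconTargetTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical target encoder: 10 one-hot class columns + reconstruction."""

    n_classes = 10  # assumed width of the class block

    def fit(self, y):
        # Stateless: nothing to learn from y.
        return self

    def transform(self, y):
        # One array per Keras output: class targets, reconstruction targets.
        return [y[:, :self.n_classes], y[:, self.n_classes:]]

    def inverse_transform(self, y_pred):
        # y_pred is assumed to be [class_scores, reconstructions] from Keras.
        class_scores, recons = y_pred
        one_hot = np.eye(self.n_classes)[np.argmax(class_scores, axis=1)]
        # Class block first, so scorer-style slicing [:, :10] recovers it.
        return np.column_stack([one_hot, recons])


# Usage with toy data (2 samples, 784 reconstruction columns of zeros).
y = np.hstack([np.eye(10)[[2, 7]], np.zeros((2, 784))])
enc = CapsReconTargetTransformer().fit(y)
caps_target, recon_target = enc.transform(y)

scores = np.full((2, 10), 0.01)
scores[0, 2] = scores[1, 7] = 0.9  # highest score at the true class
restored = enc.inverse_transform([scores, recon_target])
print(restored.shape)  # (2, 794)
```

Converting scores to one-hot rows in `inverse_transform` is one way to keep `accuracy_score` happy, since it raises an error when indicator targets are compared against continuous scores.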