Last active
January 22, 2021 17:54
-
-
Save data-hound/ea434feb38e47a8c5e0a4e8b391ebe0a to your computer and use it in GitHub Desktop.
Scikeras Tutorial - 3: Multi Output Transformer for CapsNet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelEncoder, FunctionTransformer, OneHotEncoder
class MultiOutputTransformer(BaseEstimator, TransformerMixin):
    """Split one concatenated target array into the two outputs a CapsNet
    Keras model expects (capsule one-hot labels + reconstruction targets),
    following the SciKeras multi-output transformer protocol."""

    def fit(self, y):
        """Record per-output metadata from the concatenated target array.

        Parameters
        ----------
        y : np.ndarray of shape (n_samples, 10 + n_recon_features)
            The first 10 columns are one-hot capsule labels; the
            remaining columns are reconstruction targets.

        Returns
        -------
        self
        """
        # Separate the two different 'y's into two arrays.
        y_caps, y_recons = y[:, :10], y[:, 10:]
        # The targets already arrive one-hot encoded, so no internal
        # encoders (OneHotEncoder / FunctionTransformer) are needed;
        # SciKeras only requires the metadata recorded below.
        # BUG FIX: the original stored y_caps.shape[0] (the sample
        # count). The number of classes is the number of columns.
        self.n_classes_ = [
            y_caps.shape[1],
            y_recons.shape[1],
        ]
        # Number of outputs the Keras model must produce;
        # SciKeras uses this for automatic error-checking.
        self.n_outputs_expected_ = 2
        return self
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def transform(self, y: np.ndarray) -> List[np.ndarray]:
    """Split the concatenated target array into the model's two outputs.

    Returns a two-element list — the first 10 columns (one-hot capsule
    labels) and the remaining columns (reconstruction targets) — which
    is the format SciKeras feeds to a multi-output Keras model.
    """
    caps_labels = y[:, :10]
    recon_targets = y[:, 10:]
    # No encoders are defined, so the slices pass through untouched.
    return [caps_labels, recon_targets]
def inverse_transform(self, y: List[np.ndarray], return_proba: bool = False) -> np.ndarray:
    """Re-assemble the Keras model's two outputs into one 2-D array.

    Parameters
    ----------
    y : list of two np.ndarray
        ``y[0]`` holds capsule class probabilities of shape
        (n_samples, n_classes); ``y[1]`` holds reconstruction outputs.
    return_proba : bool, default False
        If True, return the raw probabilities stacked column-wise
        instead of hard one-hot capsule predictions.

    Returns
    -------
    np.ndarray of shape (n_samples, n_classes + n_recon_features)
    """
    y_pred_proba = y  # rename for clarity: Keras gives us probabilities
    if return_proba:
        # BUG FIX: np.column_stack takes no `axis` keyword — the
        # original call raised TypeError. Stacking is inherently axis=1.
        return np.column_stack(y_pred_proba)
    # Convert capsule probabilities to one-hot vectors via argmax.
    # BUG FIX: `to_categorical` was never imported anywhere in this
    # file; an identity-matrix row lookup is the NumPy equivalent
    # (float32 matches Keras' to_categorical default dtype).
    n_classes = y_pred_proba[0].shape[1]
    y_pred_caps = np.eye(n_classes, dtype="float32")[np.argmax(y_pred_proba[0], axis=1)]
    # The reconstruction output needs no post-processing.
    y_pred_recons = y_pred_proba[1]
    return np.column_stack([y_pred_caps, y_pred_recons])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_metadata(self): | |
return { | |
"n_classes_": self.n_classes_, | |
"n_outputs_expected_": self.n_outputs_expected_, | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment