Skip to content

Instantly share code, notes, and snippets.

@tcbegley
Created March 1, 2021 19:22
Show Gist options
  • Save tcbegley/c9119131fa5324034186fa311f41284a to your computer and use it in GitHub Desktop.
Save tcbegley/c9119131fa5324034186fa311f41284a to your computer and use it in GitHub Desktop.
"""
Fairlearn expects scikit-learn models. This submodule has some shims to work
around that issue.
"""
import numpy as np
import tensorflow as tf
class KerasWrapper:
    """Expose a Keras binary classifier through a scikit-learn-like interface.

    The wrapper provides ``fit`` / ``predict`` / ``predict_proba`` and custom
    pickling hooks: on pickling the Keras model is written to ``model_path``
    and dropped from the pickled state; on unpickling it is loaded back from
    that path.

    NOTE(review): every instance that uses the default ``model_path`` saves to
    the same file, so pickling several wrappers will overwrite each other's
    saved model -- pass distinct paths if that matters.
    """

    def __init__(self, model, model_path="/tmp/corrected-model.h5"):
        """Store the model and the path used to persist it when pickling.

        Args:
            model: Object offering the Keras model methods used below
                (``compile``, ``fit``, ``predict``, ``save``).
            model_path: File path the model is saved to in ``__getstate__``
                and loaded from in ``__setstate__``.
        """
        self.model = model
        self.model_path = model_path

    def fit(self, X, y, sample_weight=None):
        """Compile the model and train it on ``X`` / ``y``.

        The model is (re)compiled on every call with the Adam optimizer,
        binary cross-entropy loss and the binary-accuracy metric, then
        trained for up to 100 epochs with early stopping (patience 2) on a
        20% validation split.

        Args:
            X: Training features, forwarded to ``model.fit``.
            y: Training targets, forwarded to ``model.fit``.
            sample_weight: Optional per-sample weights, forwarded to
                ``model.fit``.

        Returns:
            This wrapper, following the scikit-learn estimator convention.
        """
        self.model.compile(
            optimizer="adam",
            loss="binary_crossentropy",
            metrics=["binary_accuracy"],
        )
        self.model.fit(
            X,
            y,
            sample_weight=sample_weight,
            epochs=100,
            callbacks=[tf.keras.callbacks.EarlyStopping(patience=2)],
            validation_split=0.2,
            verbose=0,
        )
        return self

    def predict(self, X):
        """Return 0/1 labels by thresholding ``predict_proba(X)`` at 0.5."""
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # documented replacement and yields the same integer dtype.
        return (self.predict_proba(X) >= 0.5).astype(int)

    def predict_proba(self, X):
        """Return the model's predictions for ``X`` as a flat 1-D array."""
        return self.model.predict(X, steps=1).flatten()

    def __getstate__(self):
        """Return picklable state, saving the model to disk as a side effect."""
        # Save the model to disk so ``__setstate__`` can reload it.
        self.model.save(self.model_path)
        # Copy the instance state, leaving out the unpicklable model.
        state = self.__dict__.copy()
        del state["model"]
        return state

    def __setstate__(self, state):
        """Restore instance attributes and reload the model from disk."""
        self.__dict__.update(state)
        # Loaded uncompiled; ``fit`` compiles the model before training.
        self.model = tf.keras.models.load_model(self.model_path, compile=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment