Skip to content

Instantly share code, notes, and snippets.

@Ab1992ao
Created May 17, 2021 10:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ab1992ao/873227b0834ebe43c95b4b5fe029eb95 to your computer and use it in GitHub Desktop.
Save Ab1992ao/873227b0834ebe43c95b4b5fe029eb95 to your computer and use it in GitHub Desktop.
Evaluate the toxic-task head with the multitask (mltsk) pipeline.
class AucCallback(Callback):
    """Keras callback that reports ROC AUC on a held-out dataset.

    The evaluation runs once before training starts and again at the end of
    every epoch. Each time, the evaluated model predicts on ``dataset[0]``,
    and the predictions are scored against ``dataset[1]`` with
    ``roc_auc_score``.

    The best score seen so far is kept in ``self.best``. Whenever the score
    strictly improves, a "new best" line is printed and, if ``savepath`` is
    set, the evaluated model's weights are saved there. The score is also
    stored in the epoch ``logs`` dict under ``"val_" + name``.

    Args:
        dataset: Indexable pair ``(x, y)``. ``x`` is passed to
            ``model.predict`` and ``y`` is used as ``y_true`` for
            ``roc_auc_score``.
        call_model: Model to evaluate. When ``None``, the callback falls back
            to the model Keras attaches to it via ``set_model``
            (``self.model``).
        savepath: If truthy, ``save_weights(savepath)`` is called on the
            evaluated model every time a new best AUC is reached.
        name: Metric name used in printed messages and in the logs key.
        batch_size: Batch size passed to ``predict``. The default of 64
            matches the previously hard-coded value.
    """

    def __init__(self, dataset, call_model=None, savepath=None, name="AUC",
                 batch_size=64):
        # Initialise the base Callback first so its own attributes are in
        # place before ours are set.
        super().__init__()
        self.call_model = call_model
        self.dataset = dataset
        # AUC is never below 0, so any real score beats this initial value.
        self.best = 0
        self.name = name
        self.savepath = savepath
        self.batch_size = batch_size

    def _eval_model(self):
        """Return ``call_model`` if given, else the Keras-attached model.

        Raises:
            ValueError: if neither ``call_model`` nor ``self.model`` is set.
        """
        model = self.call_model
        if model is None:
            # Keras attaches the training model through ``set_model``;
            # getattr keeps this safe if the base class has not set it.
            model = getattr(self, "model", None)
        if model is None:
            raise ValueError(
                "{}: no model to evaluate; pass call_model or attach the "
                "callback to a model (e.g. via fit(callbacks=...))".format(
                    type(self).__name__))
        return model

    def on_epoch_end(self, epoch, logs=None):
        """Score the model, track the best AUC, and optionally save weights.

        Args:
            epoch: Epoch index supplied by Keras. It is not used here;
                ``on_train_begin`` passes ``None``.
            logs: Epoch logs dict. ``"val_" + name`` is written into it.
                A throwaway dict is used when ``None``.
        """
        if logs is None:
            logs = {}
        model = self._eval_model()
        x_val, y_val = self.dataset[0], self.dataset[1]
        predictions = model.predict(x_val, batch_size=self.batch_size)
        coef = roc_auc_score(y_val, predictions)
        if coef > self.best:
            self.best = coef
            print("*** New best: {} = {}".format(self.name, coef))
            if self.savepath:
                model.save_weights(self.savepath)
        else:
            print("{} = {}".format(self.name, coef))
        logs['val_' + self.name] = coef

    def on_train_begin(self, logs=None):
        """Record a baseline AUC before the first epoch.

        Because ``self.best`` starts at 0, this baseline normally becomes the
        first "best". With ``savepath`` set, the pre-training weights are
        therefore saved.
        """
        # The train-begin logs are deliberately not forwarded, so the
        # baseline score is not written into them.
        self.on_epoch_end(None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment