Skip to content

Instantly share code, notes, and snippets.

@Ab1992ao
Created May 17, 2021 10:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ab1992ao/e3ea080d36d2bf2d0c1ddc17aa4b9e99 to your computer and use it in GitHub Desktop.
Save Ab1992ao/e3ea080d36d2bf2d0c1ddc17aa4b9e99 to your computer and use it in GitHub Desktop.
Evaluate an NER model in a multitask pipeline.
class TagCallback(Callback):
    """Keras callback that reports NER token accuracy at the end of each epoch.

    Predictions on ``dataset[0]`` are scored against ``dataset[1]`` with
    ``cat_accuracy``. The score is printed (together with the per-tag report
    whenever it beats the best score seen so far) and written into ``logs``
    under the key ``'val_' + name``. The same evaluation also runs once before
    training starts.

    Args:
        dataset: ``(inputs, targets)`` pair. ``inputs`` is fed to ``predict``
            and ``targets`` is passed to ``cat_accuracy`` as ``y_true``.
        call_model: Model used for prediction. In a multitask pipeline this is
            presumably the NER sub-model (NOTE(review): confirm with caller).
            When ``None``, the model Keras attached to this callback
            (``self.model``) is used instead of crashing.
        name: Metric name used in printed output and in the ``logs`` key.
        batch_size: Batch size forwarded to ``predict`` (default 64, as before).
    """

    def __init__(self, dataset, call_model=None, name="NER_ACC", batch_size=64):
        # Initialise the Keras base class first so our attributes are set on a
        # fully constructed callback.
        super(TagCallback, self).__init__()
        self.call_model = call_model
        self.dataset = dataset
        self.best = 0  # best score seen so far; only strictly higher counts as new best
        self.name = name
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        """Score the model on ``self.dataset`` and record the result.

        Args:
            epoch: Epoch index supplied by Keras (``None`` when invoked from
                ``on_train_begin``). It is not used.
            logs: Keras logs dict. ``'val_' + name`` is added to it.
        """
        if logs is None:
            logs = {}
        # Previously a None call_model raised AttributeError; fall back to the
        # model Keras sets on the callback via set_model().
        model = self.call_model if self.call_model is not None else self.model
        m_pr = model.predict(self.dataset[0], batch_size=self.batch_size)
        coef, conf_mtrx = cat_accuracy(self.dataset[1], m_pr)
        if coef > self.best:
            self.best = coef
            print("*** New best: {} = {}".format(self.name, coef))
            print('----conf matrix---- \n', conf_mtrx)
        else:
            print("{} = {}".format(self.name, coef))
        logs['val_' + self.name] = coef

    def on_train_begin(self, logs=None):
        """Run one evaluation before training to print a baseline score."""
        self.on_epoch_end(None)
def cat_accuracy(y_true, y_pred, tag_to_index=None):
    """Compute token-level tag accuracy and a per-tag classification report.

    Both arrays are argmax-ed over their last axis, so they are expected to
    hold one-hot gold tags and per-tag prediction scores respectively. The
    usual layout is ``(batch, seq_len, n_tags)``; this is assumed from the
    ``axis=-1`` argmax (NOTE(review): confirm against the caller).

    Args:
        y_true: One-hot (or score) array of gold tags.
        y_pred: Predicted per-tag scores, same shape as ``y_true``.
        tag_to_index: Mapping of tag name -> integer index. Defaults to the
            module-level ``tag2idx`` so existing two-argument calls still work.

    Returns:
        ``(score, report)``. ``score`` is the fraction of positions whose
        argmax tag matches, rounded to 4 decimals. ``report`` is the text from
        ``sklearn.metrics.classification_report`` (a per-tag
        precision/recall/F1 table, not a confusion matrix).

    Note:
        Every position is counted, including any padding positions, so the
        score may be inflated for padded batches (NOTE(review): confirm
        whether padding should be excluded).
    """
    if tag_to_index is None:
        tag_to_index = tag2idx
    amx_true = np.argmax(y_true, axis=-1)
    amx_pred = np.argmax(y_pred, axis=-1)
    # Mean over every position. For a 2-D (batch, seq_len) result with equal
    # row lengths this equals the former mean of per-sequence accuracies. It
    # also handles a single 1-D sequence, where len(y_true[0]) was the tag
    # count rather than the sequence length.
    score = np.round(np.mean(amx_true == amx_pred), 4)
    # Order tags by index and pass labels explicitly. Without labels=,
    # classification_report fails when some tags never occur in this data, and
    # rows are mislabelled if the dict's order is not index order.
    ordered = sorted(tag_to_index.items(), key=lambda item: item[1])
    report = classification_report(
        np.ravel(amx_true),
        np.ravel(amx_pred),
        labels=[idx for _, idx in ordered],
        target_names=[tag for tag, _ in ordered],
    )
    return score, report
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment