Skip to content

Instantly share code, notes, and snippets.

@amankharwal
Created August 21, 2020 05:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amankharwal/1f111dd4382abea4a492b5e18f4eb4b7 to your computer and use it in GitHub Desktop.
Save amankharwal/1f111dd4382abea4a492b5e18f4eb4b7 to your computer and use it in GitHub Desktop.
from kutilities.callbacks import MetricsCallback, PlottingCallback
from sklearn.metrics import f1_score, precision_score, recall_score
from keras.callbacks import ModelCheckpoint, TensorBoard
def _emotion_micro(score_fn):
    """Wrap an sklearn scorer so it micro-averages over happy/sad/angry only.

    ``score_fn`` is an sklearn metric such as ``f1_score`` that accepts
    ``average`` and ``labels`` keyword arguments. The returned callable takes
    ``(y_test, y_pred)`` and forwards them with ``average='micro'`` and
    ``labels`` restricted to the happy, sad and angry ids. Classes not in that
    list are left out of the average (presumably the "others" class — TODO
    confirm against ``emotion2label``).

    ``emotion2label`` is looked up each time the metric is called, not when
    this helper runs. This matches the original inline lambdas, so it only
    needs to exist by the time the metrics are evaluated.
    """
    def _metric(y_test, y_pred):
        labels = [emotion2label[name] for name in ('happy', 'sad', 'angry')]
        return score_fn(y_test, y_pred, average='micro', labels=labels)
    return _metric


metrics = {
    "f1_e": _emotion_micro(f1_score),
    "precision_e": _emotion_micro(precision_score),
    # The key's spelling ("recoll") is kept as-is. It is the name this metric
    # is registered under for MetricsCallback and printed in the evaluation
    # loop, so renaming it could break anything that refers to that name.
    "recoll_e": _emotion_micro(recall_score),
}
# Evaluation sets for MetricsCallback. Each entry is
# [the three message inputs (first, second, third), labels as a numpy array].
_datasets = {
    "dev": [
        [message_first_message_dev, message_second_message_dev, message_third_message_dev],
        np.array(labels_categorical_dev),
    ],
    "val": [
        [message_first_message_val, message_second_message_val, message_third_message_val],
        np.array(labels_categorical_val),
    ],
}
metrics_callback = MetricsCallback(datasets=_datasets, metrics=metrics)
# Score the trained model on the dev split: print each custom metric, then
# sklearn's per-class classification report.
from sklearn.metrics import classification_report

y_pred = model.predict([message_first_message_dev, message_second_message_dev, message_third_message_dev])

# Reduce each row of per-class values to a class id once, instead of
# re-running argmax for every metric. The rows are presumably one-hot labels
# and class scores — TODO confirm. np.asarray mirrors the np.array(...)
# wrapping used when building _datasets, so a plain list of label rows
# also works.
y_true_ids = np.asarray(labels_categorical_dev).argmax(axis=1)
y_pred_ids = y_pred.argmax(axis=1)

for title, metric in metrics.items():
    print(title, metric(y_true_ids, y_pred_ids))
print(classification_report(y_true_ids, y_pred_ids))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment