Last active
January 24, 2023 12:30
-
-
Save vpekar/a9eee5fe8b8c3e35b03ae309d0d8c984 to your computer and use it in GitHub Desktop.
A Python implementation of Macroaveraged MAE and RMSE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Macroaveraged MAE and RMSE ([Baccianella et al. 2009](http://nmis.isti.cnr.it/sebastiani/Publications/ISDA09.pdf)) for evaluation of ordinal classifiers.
"""
import numpy as np | |
def groupby_labels(y, yhat):
    """Group predictions by their true label.

    Based on https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function

    Parameters
    ----------
    y : array_like
        True labels. Flattened to 1-D.
    yhat : array_like
        Predicted labels, one per element of ``y``. Flattened to 1-D.

    Returns
    -------
    labels : ndarray
        Sorted unique values of ``y`` (keeping ``y``'s dtype).
    grouped_preds : list of ndarray
        ``grouped_preds[i]`` holds the predictions (keeping ``yhat``'s dtype)
        for the samples whose true label is ``labels[i]``, in their original
        order. Empty when ``y`` is empty.

    Raises
    ------
    ValueError
        If ``y`` and ``yhat`` do not have the same number of elements.
    """
    # Sort the two inputs separately instead of stacking them, so that neither
    # is upcast to the other's dtype and column vectors are handled correctly.
    y = np.asarray(y).ravel()
    yhat = np.asarray(yhat).ravel()
    if y.shape != yhat.shape:
        raise ValueError(
            "y and yhat must have the same number of elements, "
            f"got {y.size} and {yhat.size}"
        )

    # A stable sort keeps each group's predictions in their original order.
    order = np.argsort(y, kind="stable")
    sorted_y = y[order]
    sorted_yhat = yhat[order]

    # return_index gives the first position of each label in sorted_y.
    # Splitting at those positions yields one chunk per label plus a leading
    # empty chunk (the first position is always 0), which is dropped.
    labels, starts = np.unique(sorted_y, return_index=True)
    grouped_preds = np.split(sorted_yhat, starts)[1:]
    return labels, grouped_preds
def mae_macro(y, yhat):
    """Macroaveraged mean absolute error (MAE^M).

    The absolute error is averaged separately over the samples of each true
    label present in ``y``. The per-label means are then averaged with equal
    weight, so rare labels count as much as frequent ones.

    Parameters
    ----------
    y : array_like
        True (ordinal) labels.
    yhat : array_like
        Predicted labels, one per element of ``y``.

    Returns
    -------
    numpy.float64
        The macroaveraged MAE. NaN (with a NumPy "Mean of empty slice"
        RuntimeWarning) when ``y`` is empty.
    """
    labels, preds = groupby_labels(y, yhat)
    # Subtract in float64: with unsigned or narrow integer dtypes,
    # ``label - pred`` would wrap around (e.g. uint8 1 - 3 == 254).
    per_label_mae = [
        np.abs(np.asarray(pred, dtype=np.float64) - label).mean()
        for label, pred in zip(labels, preds)
    ]
    return np.array(per_label_mae).mean()
def rmse_macro(y, yhat):
    """Macroaveraged root mean squared error (RMSE^M).

    The squared error is averaged separately over the samples of each true
    label present in ``y``. The per-label MSEs are averaged with equal weight,
    and the square root is taken of that average. Note that this takes the
    root after macroaveraging; it does not average per-label RMSEs.

    Parameters
    ----------
    y : array_like
        True (ordinal) labels.
    yhat : array_like
        Predicted labels, one per element of ``y``.

    Returns
    -------
    numpy.float64
        The macroaveraged RMSE. NaN (with a NumPy "Mean of empty slice"
        RuntimeWarning) when ``y`` is empty.
    """
    labels, preds = groupby_labels(y, yhat)
    # Work in float64: in unsigned or narrow integer dtypes the difference
    # wraps around and the square overflows (e.g. uint8 (0 - 20)**2 == 144).
    per_label_mse = [
        np.square(np.asarray(pred, dtype=np.float64) - label).mean()
        for label, pred in zip(labels, preds)
    ]
    return np.sqrt(np.array(per_label_mse).mean())
if __name__ == "__main__":
    # Small demo with three classes. Sample 0 (true 1, predicted 3) and
    # sample 2 (true 3, predicted 2) are the only mispredictions.
    # Expected output: 0.666... (MAE^M), then 1.0 (RMSE^M).
    y = np.array([1, 2, 3, 1, 2])
    yhat = np.array([3, 2, 2, 1, 2])
    for metric in (mae_macro, rmse_macro):
        print(metric(y, yhat))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment