Skip to content

Instantly share code, notes, and snippets.

@marcosfelt
Created March 12, 2024 22:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcosfelt/acb71ebbc3b22d5e73142c828b12c959 to your computer and use it in GitHub Desktop.
Save marcosfelt/acb71ebbc3b22d5e73142c828b12c959 to your computer and use it in GitHub Desktop.
from typing import Dict, List, Literal, Union

import numpy as np
from scipy.stats import linregress, spearmanr
from sklearn import metrics
# Metrics computed by default by ``calculate_metrics``. "spearman" is also
# supported but must be requested explicitly.
SCORE_NAMES = ["mae", "mse", "rmse", "mape", "r2", "maxe", "expl_var"]


def calculate_metrics(
    y_true, y_pred, scores: List[str] = SCORE_NAMES
) -> Dict[str, Union[float, List[float]]]:
    """Calculate regression metrics comparing ``y_pred`` against ``y_true``.

    Args:
        y_true: Ground-truth targets (array-like).
        y_pred: Predicted targets, same shape as ``y_true``.
        scores: Names of the metrics to compute, in output order. Supported:
            "mae", "mse", "rmse", "mape", "r2", "maxe", "expl_var",
            "spearman". Defaults to ``SCORE_NAMES``. The sequence is only
            read, never mutated, so sharing the module-level default is safe.

    Returns:
        Mapping from each requested score name to its value. Most values are
        scalars; "maxe" is a scalar for 1-D targets and a list holding one
        max error per output column for 2-D targets.

    Raises:
        ValueError: If ``scores`` contains an unsupported name.
    """

    def _get_score(score):
        # Metric functions are looked up lazily per branch, so only the
        # requested metrics are ever evaluated.
        if score == "mae":
            return metrics.mean_absolute_error(y_true, y_pred)
        if score == "mse":
            return metrics.mean_squared_error(y_true, y_pred)
        if score == "rmse":
            return metrics.mean_squared_error(y_true, y_pred) ** (1 / 2)
        if score == "mape":
            return metrics.mean_absolute_percentage_error(y_true, y_pred)
        if score == "r2":
            # NOTE(review): the original called an undefined ``rsquared``
            # helper, which always raised NameError. The otherwise-unused
            # ``linregress`` import suggests it was the squared Pearson r of a
            # linear fit; confirm this is the intended definition. It differs
            # from sklearn's coefficient of determination, used as fallback.
            try:
                return linregress(y_true, y_pred).rvalue ** 2
            except ValueError:
                # linregress raises ValueError e.g. when all y_true values are
                # identical or the input is empty.
                return metrics.r2_score(y_true, y_pred)
        if score == "maxe":
            true_arr = np.asarray(y_true)
            pred_arr = np.asarray(y_pred)
            # sklearn's max_error does not support multi-output targets, so
            # 2-D data is scored one output column at a time.
            if true_arr.ndim == 1:
                return metrics.max_error(true_arr, pred_arr)
            return [
                metrics.max_error(true_arr[:, i], pred_arr[:, i])
                for i in range(true_arr.shape[1])
            ]
        if score == "expl_var":
            return metrics.explained_variance_score(y_true, y_pred)
        if score == "spearman":
            return spearmanr(y_true, y_pred).correlation  # type: ignore
        # Previously an unknown name silently produced None; fail loudly so
        # typos in ``scores`` are caught.
        raise ValueError(
            f"Unknown score name {score!r}; supported: "
            f"{SCORE_NAMES + ['spearman']}"
        )

    return {s: _get_score(s) for s in scores}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment