Created
March 10, 2023 17:46
-
-
Save marcosfelt/ad26e53edd82fd4e3aa3bcd731a48a6b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, List, Optional | |
import matplotlib as mpl | |
from matplotlib.axes import Axes | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def parity_plot(
    y,
    yhat,
    ax: Optional[Axes] = None,
    include_parity: bool = True,
    scores: Optional[Dict[str, float]] = None,
    label: Optional[str] = None,
    quantity_name: Optional[str] = None,
    alpha: float = 0.5,
) -> Axes:
    """Make a parity plot of predicted versus measured values.

    Parameters
    ----------
    y : array-like
        Measured values (plotted on the x axis).
    yhat : array-like
        Predicted values (plotted on the y axis).
    ax : Axes, optional
        Matplotlib axes to draw on. When None, a new figure with a single
        subplot is created.
    include_parity : bool, optional
        Whether to draw the dashed ``y = x`` parity line, by default True.
        The line spans the finite range of ``y`` and ``yhat`` combined and is
        skipped when there are no finite values.
    scores : Dict[str, float], optional
        Scores to display in the legend entry, by default None. The keys
        ``"r2"``, ``"mae"``, ``"rmse"`` and ``"mse"`` are shown with a
        prettier display name; any other key is shown as given.
    label : str, optional
        Label for the scatter plot, by default None. Also used for the axis
        labels when ``quantity_name`` is not given.
    quantity_name : str, optional
        Name of the quantity being plotted. Used for axis labels.
    alpha : float, optional
        Transparency of the scatter plot, by default 0.5.

    Returns
    -------
    Axes
        The matplotlib axes object that was drawn on.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    # Build the legend entry: "<label> (<score>=<value>, ...)".
    score_display = {"r2": "$R^2$", "mae": "MAE", "rmse": "RMSE", "mse": "MSE"}
    if scores:
        score_label = ", ".join(
            f"{score_display.get(score_name, score_name)}={score:.2f}"
            for score_name, score in scores.items()
        )
        # Without a label, show only the scores rather than "None (...)".
        full_label = f"{label} ({score_label})" if label else score_label
    else:
        full_label = label

    # Scatter plot
    ax.scatter(y, yhat, label=full_label, alpha=alpha)

    # Parity line
    if include_parity:
        # Flatten both inputs so their shapes need not be stackable, and
        # ignore NaN/inf so a single bad point cannot erase the line.
        combined = np.concatenate(
            (
                np.ravel(np.asarray(y, dtype=float)),
                np.ravel(np.asarray(yhat, dtype=float)),
            )
        )
        finite = combined[np.isfinite(combined)]
        if finite.size:
            low, high = finite.min(), finite.max()
            ax.plot([low, high], [low, high], "k--")

    # Formatting: fall back to the scatter label for the axis labels.
    quantity_name = quantity_name or label or ""
    ax.set_ylabel(f"Predicted {quantity_name}".rstrip())
    ax.set_xlabel(f"Measured {quantity_name}".rstrip())
    ax.tick_params(direction="in")
    if scores is not None:
        ax.legend()
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment