Skip to content

Instantly share code, notes, and snippets.

@marcosfelt
Created March 10, 2023 17:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save marcosfelt/ad26e53edd82fd4e3aa3bcd731a48a6b to your computer and use it in GitHub Desktop.
Save marcosfelt/ad26e53edd82fd4e3aa3bcd731a48a6b to your computer and use it in GitHub Desktop.
from typing import Dict, List, Optional
import matplotlib as mpl
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import numpy as np
def parity_plot(
    y,
    yhat,
    ax: Optional[Axes] = None,
    include_parity: bool = True,
    scores: Optional[Dict[str, float]] = None,
    label: Optional[str] = None,
    quantity_name: Optional[str] = None,
    alpha: float = 0.5,
) -> Axes:
    """Make a parity plot of predicted versus measured values.

    Parameters
    ----------
    y : array-like
        Measured values (plotted on the x-axis).
    yhat : array-like
        Predicted values (plotted on the y-axis).
    ax : Axes, optional
        Matplotlib axes to draw on. If None, a new figure with a single
        subplot is created.
    include_parity : bool, optional
        Whether to draw a dashed ``y = x`` parity line spanning the finite
        range of ``y`` and ``yhat``, by default True.
    scores : Dict[str, float], optional
        Scores to append to the scatter label, e.g. ``{"r2": 0.93}``. The keys
        "r2", "mae", "rmse" and "mse" get display names; other keys are shown
        as-is. When given, a legend is drawn. By default None.
    label : str, optional
        Label for the scatter plot, by default None. Also used as the axis
        quantity name when ``quantity_name`` is not given.
    quantity_name : str, optional
        Name of the quantity being plotted, used in the axis labels.
    alpha : float, optional
        Transparency of the scatter points, by default 0.5.

    Returns
    -------
    Axes
        The matplotlib axes the plot was drawn on.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)

    # Build the scatter label. Only wrap the scores in parentheses when there
    # is a real label to attach them to, so we never render "None (...)".
    score_display = {"r2": "$R^2$", "mae": "MAE", "rmse": "RMSE", "mse": "MSE"}
    score_label = ", ".join(
        f"{score_display.get(score_name, score_name)}={score:.02f}"
        for score_name, score in (scores or {}).items()
    )
    if score_label and label:
        full_label = f"{label} ({score_label})"
    elif score_label:
        full_label = score_label
    else:
        full_label = label

    # Scatter plot
    ax.scatter(y, yhat, label=full_label, alpha=alpha)

    # Parity line
    if include_parity:
        # Flatten both inputs so shapes scatter accepts (e.g. (n,) vs (n, 1))
        # also work here, and drop NaN/inf so a single bad value does not
        # turn the line limits into NaN and silently suppress the line.
        values = np.concatenate(
            [
                np.ravel(np.asarray(y, dtype=float)),
                np.ravel(np.asarray(yhat, dtype=float)),
            ]
        )
        values = values[np.isfinite(values)]
        if values.size > 0:
            lo, hi = values.min(), values.max()
            ax.plot([lo, hi], [lo, hi], "k--")

    # Formatting: fall back to the scatter label for the axis quantity name.
    quantity_name = quantity_name or label or ""
    ax.set_ylabel(f"Predicted {quantity_name}")
    ax.set_xlabel(f"Measured {quantity_name}")
    ax.tick_params(direction="in")
    if scores is not None:
        ax.legend()
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment