Skip to content

Instantly share code, notes, and snippets.

Last active March 15, 2023 19:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ShenZhouHong/7ac88b1e5d8666b65b423b8f8cf433e2 to your computer and use it in GitHub Desktop.
Save ShenZhouHong/7ac88b1e5d8666b65b423b8f8cf433e2 to your computer and use it in GitHub Desktop.
Utility function which graphs a pairwise comparison grid from pd.DataFrame inputs.
from typing import Literal, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def _split_by_target(
    values: np.ndarray,
    y: np.ndarray,
    hist_y: str,
) -> list[np.ndarray]:
    """Partition one feature column into groups chosen by the target values.

    Args:
        values: 1-D array holding a single feature, one entry per sample.
        y: 1-D array of target values aligned with `values`.
        hist_y: Grouping mode. "all" returns the column untouched, "avg"
            splits at the mean of `y`, and "std" splits into below / within /
            above one standard deviation around the mean of `y`.

    Returns:
        A list of 1-D arrays, one per group, ordered from low to high target.

    Raises:
        ValueError: If `hist_y` is not one of the supported modes.
    """
    if hist_y == "all":
        return [values]

    centre = np.mean(y)
    if hist_y == "avg":
        return [values[y <= centre], values[y > centre]]

    if hist_y != "std":
        raise ValueError(f"Unsupported hist_y mode: {hist_y!r}")

    spread = np.std(y)
    lower_bound = centre - spread
    upper_bound = centre + spread
    # Strict inequalities on the outer groups keep the three brackets
    # disjoint: a sample sitting exactly on a bound belongs to the middle one.
    return [
        values[y < lower_bound],
        values[(y >= lower_bound) & (y <= upper_bound)],
        values[y > upper_bound],
    ]


def pairplot(
    df: pd.DataFrame,
    target: pd.Series,
    names: Union[None, list[str]] = None,
    size: Union[None, tuple[int, int]] = None,
    cmap: Union[None, "mpl.colors.Colormap"] = None,
    hist_y: Literal["avg", "std", "all"] = "avg",
    hist_bins: int = 20,
) -> "mpl.figure.Figure":
    """Plot a pairwise comparison grid for a feature DataFrame.

    Every off-diagonal cell is a scatter plot of one feature against another,
    coloured by `target`. Every diagonal cell is a histogram of a single
    feature, optionally split into groups derived from `target`.

    Args:
        df: Training data, one column per feature.
        target: Labels, one per row of `df`. Must be convertible to float.
        names: Feature names used in subplot titles. When omitted (or empty)
            they are taken from the column labels of `df`.
        size: Figure size in inches as a square `(width, height)` tuple. When
            omitted it is `(5 * n_features, 5 * n_features)`.
        cmap: Colormap used for the scatter plots and the grouped histograms.
            When omitted, Matplotlib's configured default colormap is used.
        hist_y: How the diagonal histograms are grouped: "all" draws one grey
            histogram, "avg" splits samples at the mean of the target, and
            "std" splits them into below / within / above one standard
            deviation of the target mean.
        hist_bins: Number of bins for the histograms.

    Returns:
        The Matplotlib figure containing the grid.

    Raises:
        AssertionError: If any argument fails validation.
    """
    # Validation runs before any plotting so that a bad call never leaves a
    # half-built figure behind.
    assert isinstance(df, pd.DataFrame), f"Error: df is {type(df)}, expected pd.DataFrame."
    assert isinstance(target, pd.Series), f"Error: target is {type(target)}, expected pd.Series."
    assert hist_y in ("avg", "std", "all"), f"Error: hist_y must be one of 'avg', 'std', or 'all', received {hist_y}."
    assert isinstance(hist_bins, int), f"Error: hist_bins is {type(hist_bins)}, expected int."

    # Infer feature names from the DataFrame if not provided by the user.
    feature_names = names if names else list(df)
    assert isinstance(feature_names, list), f"Error: names is {type(feature_names)}, expected list[str]."
    assert len(feature_names) == df.shape[1], f"Error: names list has length of {len(feature_names)}, expected {df.shape[1]}."

    # The primary variables being plotted: a feature matrix and a target vector.
    X = df.to_numpy()
    y = np.asarray(target.to_list(), dtype=np.float64)
    assert y.shape[0] == X.shape[0], f"Error: target has {y.shape[0]} entries, expected {X.shape[0]}."

    features = X.shape[1]
    assert features > 0, "Error: df must contain at least one column."

    # Infer the figure size from the number of features if not provided.
    plot_size = size if size else (features * 5, features * 5)
    assert isinstance(plot_size, tuple), f"Error: size is {type(plot_size)}, expected tuple[int, int]."
    assert len(plot_size) == 2, f"Error: size tuple must contain exactly 2 items, currently has {len(plot_size)}"
    assert plot_size[0] == plot_size[1], f"Error: size tuple must be square, is currently {plot_size[0], plot_size[1]}"

    if cmap is None:
        # With no argument, get_cmap returns the rcParams default colormap.
        cmap = plt.get_cmap()

    fig = plt.figure(figsize=plot_size)
    fig.suptitle(f"Pairwise Comparison Grid for {features} by {features} Features", fontsize=16, y=1.01)

    # Subplots are numbered row-major, so the row is the y feature and the
    # column is the x feature.
    for subplot in range(features * features):
        ax = fig.add_subplot(features, features, subplot + 1)
        fy, fx = divmod(subplot, features)

        if fx != fy:
            ax.scatter(X[:, fx], X[:, fy], c=y, cmap=cmap)
            ax.set_title(f"#{subplot + 1}: {feature_names[fx]} vs {feature_names[fy]}")
            continue

        # Diagonal cell: histogram of a single feature.
        groups = _split_by_target(X[:, fx], y, hist_y)
        if len(groups) == 1:
            group_colors = ["#B6B6B4"]  # Light grey for the ungrouped case.
        else:
            # Spread the groups evenly across the whole colormap so the lowest
            # and highest target groups get its two extreme colours.
            group_colors = [cmap(position) for position in np.linspace(0.0, 1.0, len(groups))]
        ax.hist(groups, bins=hist_bins, color=group_colors)
        ax.set_title(f"#{subplot + 1}: Frequency of {feature_names[fx]}")

    # The figure stays registered with pyplot; callers that no longer need it
    # can release it with plt.close(fig).
    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment