Skip to content

Instantly share code, notes, and snippets.

@ShenZhouHong
Last active March 15, 2023 19:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ShenZhouHong/7ac88b1e5d8666b65b423b8f8cf433e2 to your computer and use it in GitHub Desktop.
Save ShenZhouHong/7ac88b1e5d8666b65b423b8f8cf433e2 to your computer and use it in GitHub Desktop.
Utility function which graphs a pairwise comparison grid from pd.DataFrame inputs.
from typing import Literal, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def _target_groups(y, hist_y, cmap):
    """Split the samples into target-based groups for the diagonal histograms.

    Returns a ``(masks, colors)`` pair: ``masks`` is a list of boolean arrays
    over the samples (one per histogram series) and ``colors`` is the matching
    ``color`` argument for ``Axes.hist``.
    """
    if hist_y == "all":
        # No grouping: a single series containing every sample, in light grey.
        return [np.ones(y.shape[0], dtype=bool)], "#B6B6B4"

    mean = np.mean(y)
    if hist_y == "avg":
        # Two groups: targets at or below the mean, and targets above it.
        return [y <= mean, y > mean], [cmap(0.0), cmap(1.0)]

    # hist_y == "std": below, within, and above one standard deviation of the
    # mean. The bounds are assigned to the middle group only, so that no
    # sample is counted in two groups.
    std = np.std(y)
    lower_bound = mean - std
    upper_bound = mean + std
    masks = [
        y < lower_bound,
        (y >= lower_bound) & (y <= upper_bound),
        y > upper_bound,
    ]
    # Evenly spaced colormap positions, one per group (the last point of the
    # linspace is dropped, as in the original colour scheme).
    color_positions = np.linspace(0, 1, len(masks) + 1)[:-1]
    return masks, cmap(color_positions)


def pairplot(
    df: pd.DataFrame,
    target: pd.Series,
    names : Union[None, list[str]] = None,
    size : Union[None, tuple[int, int]] = None,
    cmap : mpl.colors.LinearSegmentedColormap = plt.cm.coolwarm,
    hist_y: Literal["avg", "std", "all"] = "avg",
    hist_bins: int = 20
) -> mpl.figure.Figure:
    """
    Plot a pairwise comparison grid for a feature matrix and its targets.

    For a DataFrame with ``n`` columns an ``n`` by ``n`` grid of subplots is
    drawn. Off-diagonal cells are scatter plots of one feature against
    another, coloured by the target value. Diagonal cells are histograms of a
    single feature, optionally split into groups according to the target.

    Parameters
    ----------
    df : pd.DataFrame
        Feature matrix, one column per feature and one row per sample.
    target : pd.Series
        Target values, one per row of `df`. Values are converted to float64
        and matched to the rows of `df` by position.
    names : list[str], optional
        Feature names used in titles and axis labels. Must have one entry per
        column of `df`. Defaults to the DataFrame's column labels.
    size : tuple[int, int], optional
        Figure size passed to ``plt.figure(figsize=...)``. Must be square.
        Defaults to ``(5 * n, 5 * n)``.
    cmap : matplotlib colormap
        Colormap used for the scatter plots and the histogram groups.
    hist_y : {"avg", "std", "all"}
        How the diagonal histograms are grouped by target value:
        "all" draws one ungrouped histogram; "avg" splits samples into
        target <= mean and target > mean; "std" splits samples into below,
        within, and above one standard deviation around the mean.
    hist_bins : int
        Number of histogram bins.

    Returns
    -------
    matplotlib.figure.Figure
        The finished figure. It is closed in pyplot before being returned, so
        it is not left as the current pyplot figure.

    Raises
    ------
    AssertionError
        If any argument fails validation.
    """
    # NOTE: validation deliberately stays assert-based so that existing
    # callers continue to see AssertionError on bad input.
    assert isinstance(df, pd.DataFrame), f"Error: df is {type(df)}, expected pd.DataFrame."
    assert isinstance(target, pd.Series), f"Error: target is {type(target)}, expected pd.Series."
    assert hist_y in ("avg", "std", "all"), f"Error: hist_y must be one of 'avg', 'std', or 'all', received {hist_y}."
    assert isinstance(hist_bins, int), f"Error: hist_bins is {type(hist_bins)}, expected int."
    assert df.shape[1] > 0, "Error: df must contain at least one feature column."
    assert len(target) == df.shape[0], (
        f"Error: target has length of {len(target)}, expected {df.shape[0]} (one per row of df)."
    )

    # Infer feature names from DataFrame if not provided by user
    if not names:
        names = list(df)
    else:
        assert isinstance(names, list), f"Error: names is {type(names)}, expected list[str]."
        assert len(names) == df.shape[1], f"Error: names list has length of {len(names)}, expected {df.shape[1]}."

    # The data being plotted: a matrix of features and a vector of targets
    X = df.to_numpy()
    y = np.asarray(target.to_list(), dtype=np.float64)
    features = X.shape[1]

    # Infer size of the plot from number of features if not provided by user
    if not size:
        size = (features * 5, features * 5)
    else:
        assert isinstance(size, tuple), f"Error: size is {type(size)}, expected tuple[int, int]."
        assert len(size) == 2, f"Error: size tuple must contain exactly 2 items, currently has {len(size)}"
        assert size[0] == size[1], f"Error: size tuple must be square, is currently {size[0], size[1]}"

    # The histogram grouping depends only on the targets, not on the feature,
    # so it is computed once rather than once per diagonal cell.
    group_masks, group_colors = _target_groups(y, hist_y, cmap)

    fig = plt.figure(figsize=size)
    fig.suptitle(f"Pairwise Comparison Grid for {features} by {features} Features", fontsize=16, y=1.01)

    # One subplot per (row, column) pair of features, filled row by row
    for subplot in range(features * features):
        fy, fx = divmod(subplot, features)
        ax = fig.add_subplot(features, features, subplot + 1)

        if fx != fy:
            # Off-diagonal: scatter of feature fx against feature fy
            ax.scatter(X[:, fx], X[:, fy], c=y, cmap=cmap)
            ax.set_title(f"#{subplot + 1}: {names[fx]} vs {names[fy]}")
            ax.set_xlabel(f"{names[fx]}")
            ax.set_ylabel(f"{names[fy]}")
        else:
            # Diagonal: histogram of feature fx, one series per target group
            dataset = [X[mask, fx] for mask in group_masks]
            ax.hist(dataset, bins=hist_bins, color=group_colors)
            ax.set_title(f"#{subplot + 1}: Frequency of {names[fx]}")
            ax.set_xlabel(f"{names[fx]}")
            ax.set_ylabel("Frequency")

    # Lay out and close this specific figure (not whichever one happens to be
    # current) before handing it back to the caller.
    fig.tight_layout()
    plt.close(fig)
    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment