Last active
March 15, 2023 19:35
-
-
Save ShenZhouHong/7ac88b1e5d8666b65b423b8f8cf433e2 to your computer and use it in GitHub Desktop.
Utility function which graphs a pairwise comparison grid from pd.DataFrame inputs.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Literal, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def pairplot(
    df: pd.DataFrame,
    target: pd.Series,
    names: Union[None, list[str]] = None,
    size: Union[None, tuple[int, int]] = None,
    cmap: mpl.colors.LinearSegmentedColormap = plt.cm.coolwarm,
    hist_y: Literal["avg", "std", "all"] = "avg",
    hist_bins: int = 20
) -> mpl.figure.Figure:
    """
    Plot a pairwise comparison grid for the features in `df`.

    The grid has one row and one column per feature. Off-diagonal cells are
    scatter plots of one feature against another, coloured by `target` using
    `cmap`. Diagonal cells are histograms of a single feature, optionally split
    into groups according to the value of `target`.

    Parameters:
        df: Training data, one column per feature.
        target: Labels, one per row of `df`. Converted to float64.
        names: Feature names used in titles and axis labels. When omitted (or
            empty), the column labels of `df` are used.
        size: Figure size in inches as a square `(width, height)` tuple. When
            omitted, 5 inches are allotted per feature in each direction.
        cmap: Colormap for the scatter plots and the grouped histograms.
        hist_y: How the diagonal histograms are grouped by `target`:
            "all" - a single grey histogram of every sample;
            "avg" - two groups, target <= mean and target > mean;
            "std" - three groups, target <= mean - std, within one standard
                    deviation of the mean, and target > mean + std.
        hist_bins: Number of bins for the histograms.

    Returns:
        The Figure holding the grid. It is closed in pyplot before being
        returned, so it is not displayed implicitly by pyplot.

    Raises:
        AssertionError: If an argument has the wrong type, `names` or `size`
            have the wrong length, `size` is not square, or `target` does not
            have one entry per row of `df`.
    """
    assert isinstance(df, pd.DataFrame), f"Error: df is {type(df)}, expected pd.DataFrame."
    assert isinstance(target, pd.Series), f"Error: target is {type(target)}, expected pd.Series."
    assert hist_y in ("avg", "std", "all"), f"Error: hist_y must be one of 'avg', 'std', or 'all', received {hist_y}."
    assert isinstance(hist_bins, int), f"Error: hist_bins is {type(hist_bins)}, expected int."
    # The boolean masks built from `target` are used to index rows of `df`,
    # so the two must line up; fail early with a clear message otherwise.
    assert len(target) == df.shape[0], f"Error: target has length of {len(target)}, expected {df.shape[0]}."

    # Infer feature names from DataFrame if not provided by user
    if not names:
        names = list(df)
    else:
        assert isinstance(names, list), f"Error: names is {type(names)}, expected list[str]."
        assert len(names) == df.shape[1], f"Error: names list has length of {len(names)}, expected {df.shape[1]}."

    # The primary variables being plotted: a matrix of features and a vector of targets
    X = df.to_numpy()
    y = np.asarray(target.to_list(), dtype=np.float64)
    features = X.shape[1]

    # Infer size of the plot from number of features if not provided by user
    if not size:
        size = (features * 5, features * 5)
    else:
        assert isinstance(size, tuple), f"Error: size is {type(size)}, expected tuple[int, int]."
        assert len(size) == 2, f"Error: size tuple must contain exactly 2 items, currently has {len(size)}"
        assert size[0] == size[1], f"Error: size tuple must be square, is currently {size[0], size[1]}"

    # The histogram grouping depends only on `y`, so the row selectors and
    # their colours are computed once rather than once per diagonal cell.
    if hist_y == "all":
        # No separate groupings: a single selector that keeps every row.
        row_groups = [slice(None)]
        hist_colors = '#B6B6B4'  # Light grey
    elif hist_y == "avg":
        # Two brackets: y <= mean and y > mean.
        y_mean = np.mean(y)
        row_groups = [y <= y_mean, y > y_mean]
        hist_colors = [cmap(0), cmap(1.0)]
    else:
        # Three disjoint brackets: below, within, and above one standard
        # deviation of the mean. Each selector picks rows of X, so every
        # group holds feature values (never target values).
        y_mean = np.mean(y)
        y_std = np.std(y)
        lower_bound = y_mean - y_std
        upper_bound = y_mean + y_std
        row_groups = [
            y <= lower_bound,
            (y > lower_bound) & (y <= upper_bound),
            y > upper_bound,
        ]
        # Infer the colour intervals from the number of groups.
        color_intervals = np.linspace(0, 1, len(row_groups) + 1)
        hist_colors = cmap(color_intervals[:-1])

    # Create the figure object and add the title
    fig = plt.figure(figsize=size)
    fig.suptitle(f"Pairwise Comparison Grid for {features} by {features} Features", fontsize=16, y=1.01)

    # Main loop: one subplot per (row, column) pair of features
    for subplot in range(features * features):
        ax = fig.add_subplot(features, features, subplot + 1)
        # Subplots are numbered row-major: the row gives the y feature, the column the x feature.
        fy, fx = divmod(subplot, features)

        if fx != fy:
            # Off-diagonal: scatter of feature fx against feature fy, coloured by target
            ax.scatter(X[:, fx], X[:, fy], c=y, cmap=cmap)
            ax.set_title(f"#{subplot + 1}: {names[fx]} vs {names[fy]}")
            ax.set_xlabel(f"{names[fx]}")
            ax.set_ylabel(f"{names[fy]}")
        else:
            # Diagonal: histogram of feature fx, split into the target groups
            dataset = [X[rows, fx] for rows in row_groups]
            ax.hist(dataset, bins=hist_bins, color=hist_colors)
            ax.set_title(f"#{subplot + 1}: Frequency of {names[fx]}")
            ax.set_xlabel(f"{names[fx]}")
            ax.set_ylabel("Frequency")

    # Once all features are plotted, close this specific figure in pyplot and return it.
    fig.tight_layout()
    plt.close(fig)
    return fig
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment