Skip to content

Instantly share code, notes, and snippets.

@tueda
Created December 5, 2023 10:05
Show Gist options
  • Save tueda/6be665af6edcdcf114376e9eb9196c1d to your computer and use it in GitHub Desktop.
Save tueda/6be665af6edcdcf114376e9eb9196c1d to your computer and use it in GitHub Desktop.
def plot_data(data, *, figsize=None, columns=None, col_wrap=5, cat_threshold=5):
    """Draw a grid of per-column overview plots for a DataFrame.

    One subplot is created per selected column, laid out row by row with
    ``col_wrap`` subplots per row on a newly created matplotlib figure.
    Numeric columns, and non-numeric columns with at most ``cat_threshold``
    distinct values, are drawn with ``seaborn.histplot``. Any other column
    gets a blank panel labelled with the column name and showing its number
    of distinct values.

    Args:
        data: DataFrame-like object supporting ``data.columns`` and
            ``data[name]``.
        figsize: ``(width, height)`` passed to ``plt.figure``. When omitted,
            4 units of width per grid column and 3 units of height per grid
            row are used.
        columns: Column names to plot, in order. ``None`` or an empty
            sequence selects every column of ``data``.
        col_wrap: Number of subplots per grid row.
        cat_threshold: Largest number of distinct values for which a
            non-numeric column is still drawn as a histogram.

    Returns:
        None. The plots are drawn on the current matplotlib figure.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    # Explicit None/length test: ``not columns`` raises ValueError for a
    # pandas Index or NumPy array, whose truth value is ambiguous.
    if columns is None or len(columns) == 0:
        columns = data.columns

    # Ceiling division, so a column count that is an exact multiple of
    # col_wrap does not produce a trailing empty row. At least one row is
    # kept so the default figure never has zero height.
    n_rows = max(1, (len(columns) + col_wrap - 1) // col_wrap)

    if not figsize:
        figsize = (col_wrap * 4, n_rows * 3)
    plt.figure(figsize=figsize)

    for i, c in enumerate(columns):
        ax = plt.subplot2grid((n_rows, col_wrap), (i // col_wrap, i % col_wrap))
        series = data[c]

        if pd.api.types.is_numeric_dtype(series):
            sns.histplot(series, ax=ax)
            continue

        # Only computed for non-numeric columns, and only once per column.
        n_unique = len(series.unique())
        if n_unique <= cat_threshold:
            sns.histplot(series, ax=ax)
            continue

        # Too many categories for a readable histogram: show an empty panel
        # with no ticks or tick labels, just the column name and the count.
        ax.tick_params(
            labelbottom=False,
            labelleft=False,
            labelright=False,
            labeltop=False,
            bottom=False,
            left=False,
            right=False,
            top=False,
        )
        ax.set_xlabel(c)
        ax.text(
            0.5,
            0.5,
            f"{n_unique} unique values",
            bbox={"facecolor": "white", "alpha": 1, "edgecolor": "none", "pad": 1},
            ha="center",
            va="center",
            transform=ax.transAxes,  # position in axes-relative coordinates
        )

    plt.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment