Last active: June 11, 2022 23:10
-
-
Save mintaow/4a80393d18910bf97206fd864d0b89d0 to your computer and use it in GitHub Desktop.
Viz 4: Categorical Bar Plot Series
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Matplotlib version: 3.5.2 (My Local Jupyter Notebook) | |
# Seaborn version: 0.11.2 (My Local Jupyter Notebook) | |
# Python version: 3.7.4 (My Local Jupyter Notebook) | |
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
%matplotlib inline | |
sns.set(style="white",context="talk") | |
def cat_bar_plot(df, col_cat, col_y, suptitle, figsize=(14, 6)):
    """Plot observation counts and the mean of a metric for each category, side by side.

    The figure has two panels:

    * left  -- a count plot of how many rows fall into each category of
      ``col_cat``;
    * right -- a bar plot of the mean of ``col_y`` per category, drawn
      without error bars.

    Both panels use the same "summer_r" palette, sized to the number of
    distinct categories, so the two panels use matching colours.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; must contain the columns ``col_cat`` and ``col_y``.
    col_cat : str
        Column name of the key categorical variable (x-axis of both panels).
    col_y : str
        Column name of the dependent variable (key metric) whose per-category
        mean is shown in the right panel.
    suptitle : str
        Overall title placed above both panels.
    figsize : tuple of (float, float), default (14, 6)
        Width and height of the figure in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. It is also displayed via ``plt.show()`` before returning,
        and is returned so the caller can save it (e.g. ``fig.savefig``).
    """
    # Build the palette once and share it so both panels colour the
    # categories identically.
    palette = sns.color_palette("summer_r", df[col_cat].nunique())

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize)

    # Left panel: sample size per category, so the means on the right can be
    # judged against how many observations support them.
    sns.countplot(x=col_cat, data=df, ax=axes[0], palette=palette)
    axes[0].set_title(f'Number of obs. by {col_cat}', fontsize=16, pad=5)
    axes[0].grid(linestyle="--", alpha=0.4)

    # Right panel: per-category mean of col_y. ci=None (rather than ci=0)
    # tells seaborn to skip bootstrapping entirely and draw no error bar;
    # ci=0 still bootstraps and draws a zero-width interval.
    sns.barplot(
        x=col_cat,
        y=col_y,
        data=df,
        ci=None,
        estimator=np.mean,
        ax=axes[1],
        palette=palette,
    )
    axes[1].set_title(f'{col_y} by {col_cat}', fontsize=16, pad=5)
    axes[1].grid(linestyle="--", alpha=0.4)

    fig.suptitle(suptitle, fontsize=20)
    plt.show()
    return fig
# ====================================================================================
# Illustration
# Load seaborn's bundled example "tips" dataset.
tips = sns.load_dataset("tips")

# Tip expressed as a fraction of the total bill (e.g. 0.15 means 15%),
# despite the "_perc" suffix in the column name.
tips['tip_perc'] = tips.tip / tips.total_bill

fig = cat_bar_plot(
    df=tips,
    col_cat='sex',
    col_y='tip_perc',
    suptitle="Examining Tipping Percentage Difference between Male and Female Customers",
)

fig.savefig(
    fname="../categorical_bar_plot_series.png",  # path & filename for the output
    dpi=300,  # dots per inch: render a high-resolution image
    bbox_inches='tight',  # default savefig can clip the x-axis label and ticks; 'tight' avoids that
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.