Skip to content

Instantly share code, notes, and snippets.

@mintaow
Last active June 11, 2022 23:11
Show Gist options
  • Save mintaow/cbc62ce869163675e56ee453bcc6015d to your computer and use it in GitHub Desktop.
Save mintaow/cbc62ce869163675e56ee453bcc6015d to your computer and use it in GitHub Desktop.
# Matplotlib version: 3.5.2 (My Local Jupyter Notebook)
# Seaborn version: 0.11.2 (My Local Jupyter Notebook)
# Python version: 3.7.4 (My Local Jupyter Notebook)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set(style="white",context="talk")
def dist_plot(data, bw_adjust, clip, bins, title='TITLE', xlabel = 'XLABEL', ylabel = 'Kernel Density', figsize=(14, 6)):
    """
    Visualize the distribution of a numeric variable as a density-normalized
    histogram with a KDE curve drawn on top.

    The two layers are drawn with ``Axes.hist`` and ``sns.kdeplot`` rather than
    the deprecated ``sns.distplot`` (deprecated since seaborn 0.11), so the
    function does not emit a deprecation warning and does not depend on an API
    that is scheduled for removal.

    Input:
        data: pandas series (any 1-D numeric array-like is accepted). The variable.
            Missing values are dropped before plotting.
        bw_adjust: float. KDE parameter. Factor that multiplicatively scales the
            bandwidth chosen by the default bandwidth method. Increasing it makes
            the curve smoother.
        clip: tuple. KDE parameter. Do not evaluate the kernel density outside of
            these limits.
        bins: int. Number of bars (passed straight to ``Axes.hist``).
        title: string. Title of the graph.
        xlabel: string. The title of the x-axis.
        ylabel: string. The title of the y-axis. Note that the y-axis is a
            density, not a probability: bar heights are scaled so that the bar
            areas sum to 1, so individual values can exceed 1 when bins are
            narrower than 1.
        figsize: tuple. The width and height of the figure.
    Output:
        fig: figure object.
    """
    # Coerce to a float array and drop NaNs once, so the histogram and the KDE
    # are computed from exactly the same observations.
    values = pd.Series(data, dtype=float).dropna().to_numpy()

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    # Histogram layer. density=True puts the bars on the same y-scale as the KDE.
    ax.hist(
        values,
        bins=bins,
        density=True,
        label='XXX',
        histtype='bar',
        color='grey',
        edgecolor='black',
        linewidth=1,
        alpha=0.2,
    )

    # KDE layer, drawn after the histogram so the curve sits on top of the bars.
    sns.kdeplot(
        x=values,
        ax=ax,
        color='royalblue',
        bw_adjust=bw_adjust,  # Important parameter!
        # Factor, multiplied by the smoothing bandwidth, that determines how far
        # the evaluation grid extends past the extreme datapoints.
        # When set to 0, truncate the curve at the data limits.
        cut=0,  # Important parameter!
        clip=clip,  # Important parameter!
    )

    # Set the labels last: kdeplot writes its own default axis labels.
    ax.set_title(title, fontsize=18, pad=10)
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.grid(linestyle="--", alpha=0.4)

    plt.show()
    return fig
# ====================================================================================
# Illustration
# Load seaborn's example "tips" dataset.
tips = sns.load_dataset('tips')

# Tip expressed as a fraction of the total bill.
tips['tip_perc'] = tips['tip'] / tips['total_bill']

fig = dist_plot(
    data=tips['tip_perc'],
    bw_adjust=0.6,
    clip=(0.0, 1.0),
    bins=20,
    title='Distribution of Tipping Percentage',
    xlabel='Proportion',
    ylabel='Kernel Density',
    figsize=(14, 6),
)

fig.savefig(
    fname="../dist_plot_tips.png",  # path & filename for the output
    dpi=300,  # high-resolution output (the keyword is `dpi`; `dsi` is not a savefig option)
    bbox_inches='tight',  # the default bounding box can cut off the x-axis label and ticks; 'tight' avoids that
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment