Last active
July 20, 2022 17:26
-
-
Save mintaow/5dd5688855a9c931831be06da9267442 to your computer and use it in GitHub Desktop.
Viz 2: Scatter Plot with Fitted Trendlines
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Matplotlib version: 3.5.2 (My Local Jupyter Notebook) | |
# Seaborn version: 0.11.2 (My Local Jupyter Notebook) | |
# Python version: 3.7.4 (My Local Jupyter Notebook) | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as ticker | |
from datetime import datetime | |
from scipy.optimize import curve_fit | |
import sympy as sym | |
%matplotlib inline | |
sns.set(style="white",context="talk") | |
def func_exp(x, a, b, c):
    """Exponential model ``a * exp(b * x) + c``.

    Input:
        x: scalar or numpy array. Points at which the model is evaluated.
        a: scale applied to the exponential term.
        b: rate inside the exponent.
        c: constant offset.
    Output:
        The model value(s), with the same shape as ``x``.
    """
    growth = np.exp(b * x)
    return a * growth + c
def func_log(x, a, b, c):
    """Logarithmic model ``a * log(b * x) + c`` (natural logarithm).

    Input:
        x: scalar or numpy array. Points at which the model is evaluated.
        a: scale applied to the logarithm.
        b: factor applied to ``x`` inside the logarithm.
        c: constant offset.
    Output:
        The model value(s), with the same shape as ``x``.

    Note: ``b * x`` must be positive for a finite real result. Because
    ``a*log(b*x) + c == a*log(x) + (a*log(b) + c)``, the parameters ``b`` and
    ``c`` are not independently identifiable when this model is fitted.
    """
    scaled = b * x
    return a * np.log(scaled) + c
def func_poly2(x, a, b, c):
    """Quadratic model ``a*x + b*x**2 + c``.

    Note the parameter order: ``a`` is the linear coefficient, ``b`` the
    quadratic coefficient and ``c`` the constant term.

    Input:
        x: scalar or numpy array. Points at which the model is evaluated.
        a, b, c: polynomial coefficients as described above.
    Output:
        The model value(s), with the same shape as ``x``.
    """
    linear = a * x
    quadratic = b * x**2
    return linear + quadratic + c
def func_poly3(x, a, b, c, d):
    """Cubic model ``a*x + b*x**2 + c*x**3 + d``.

    Note the parameter order: ``a``, ``b`` and ``c`` are the coefficients of
    the linear, quadratic and cubic terms; ``d`` is the constant term.

    Input:
        x: scalar or numpy array. Points at which the model is evaluated.
        a, b, c, d: polynomial coefficients as described above.
    Output:
        The model value(s), with the same shape as ``x``.
    """
    linear = a * x
    quadratic = b * x**2
    cubic = c * x**3
    return linear + quadratic + cubic + d
def _trendline_label(fitted_func, popt):
    """Build the legend text for the fitted curve from its parameters."""
    # Plain Python floats keep the sympy arithmetic independent of numpy scalars.
    rounded = [float(round(p, 4)) for p in popt]
    try:
        # Use sympy to generate the LaTeX syntax of the function
        tex = sym.latex(fitted_func(sym.Symbol('x'), *rounded)).replace('$', '')
    except (TypeError, AttributeError):
        # Models built on numpy ufuncs (np.exp, np.log) cannot be evaluated on
        # a sympy Symbol; fall back to a plain-text label instead of crashing.
        name = getattr(fitted_func, '__name__', 'f')
        params = ', '.join('%g' % p for p in rounded)
        return 'Fitted Trendline: %s(x; %s)' % (name, params)
    return r'Fitted Trendline: $f(x)= %s$' % (tex)


def scatter_plot_with_fitted_line(df, col_x, col_y, fitted_func = func_poly3, title='TITLE', xlabel = 'XLABEL', ylabel = 'YLABEL', figsize=(14, 6)):
    """
    Draw a scatter plot of two columns of a dataframe with a fitted trendline.

    The trendline parameters are estimated with scipy's ``curve_fit`` using
    only the rows where both columns are finite. The fitted formula is shown
    in the legend (as LaTeX when sympy can evaluate ``fitted_func``
    symbolically, otherwise as the function name with its parameter values).

    Input:
        df: pandas dataframe. The dataset.
        col_x: string. The column of df plotted on the x-axis.
        col_y: string. The column of df plotted on the y-axis.
        fitted_func: function. The model to fit, called as
            ``fitted_func(x, *params)``.
        title: string. The title of the figure.
        xlabel: string. The title of the x-axis.
        ylabel: string. The title of the y-axis.
        figsize: tuple. The width and height of the figure.
    Output:
        fig: figure object. Note that ``plt.show()`` is called before the
            figure is returned.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    # Scatter Plot
    sns.scatterplot(
        x=col_x,
        y=col_y,
        data=df,
        ax=ax,
        # NOTE(review): `size` is seaborn's semantic size variable; a fixed
        # marker size is normally given via `s=`. Kept as in the original so
        # the rendered markers do not change.
        size=10,
        legend=None,
        linewidth=0,  # get rid of the edge drawn around each marker
        alpha=0.6,  # transparency, so density shows as darker color
        color='seagreen'
    )

    # Fit only on rows where both values are finite: curve_fit raises on NaN/inf.
    x_all = df[col_x].to_numpy(dtype=float)
    y_all = df[col_y].to_numpy(dtype=float)
    finite = np.isfinite(x_all) & np.isfinite(y_all)
    x_fit = x_all[finite]
    y_fit = y_all[finite]

    # Get the fitted parameters
    popt, _ = curve_fit(fitted_func, x_fit, y_fit)

    # Fitted Trendline, evaluated at the sorted unique x values. Plotting it
    # directly avoids seaborn's per-x aggregation and confidence-interval
    # computation, which are meaningless for a deterministic curve.
    x_curve = np.unique(x_fit)
    ax.plot(
        x_curve,
        fitted_func(x_curve, *popt),
        label=_trendline_label(fitted_func, popt),
        color='black',
        alpha=0.5,
        linestyle="--"
    )
    ax.legend()

    ax.set_title(title, fontsize=18, pad=10)
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.grid(linestyle="--", alpha=0.4)
    plt.show()
    return fig
# ====================================================================================
# Illustration
# Load the data
tips = sns.load_dataset('tips')
# Tip expressed as a fraction of the total bill (a unitless ratio, not dollars)
tips['tip_perc'] = tips.tip / tips.total_bill
fig = scatter_plot_with_fitted_line(
    df = tips,
    col_x = 'total_bill',
    col_y = 'tip_perc',
    fitted_func = func_poly3,
    title = "Exploring Tip Proportion and Total Bill: More Expensive Meals Correlate Smaller Tipping?",
    xlabel = "Total Bill Amount ($)",
    # The y variable is tip / total_bill, so it is labelled as a proportion
    ylabel = "Tip as a Proportion of Total Bill",
    figsize = (14,6)
)
fig.savefig(
    fname = "../scatter_plot_with_fitted_line_tips.png",  # path&filename for the output
    dpi = 300,  # make it a high-resolution graph
    bbox_inches='tight'  # sometimes default savefig method cuts off the x-axis and x-ticks. This param avoids that
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment