Skip to content

Instantly share code, notes, and snippets.

@mintaow
Last active July 20, 2022 17:26
Show Gist options
  • Save mintaow/5dd5688855a9c931831be06da9267442 to your computer and use it in GitHub Desktop.
Save mintaow/5dd5688855a9c931831be06da9267442 to your computer and use it in GitHub Desktop.
Viz 2: Scatter Plot with Fitted Trendlines
# Matplotlib version: 3.5.2 (My Local Jupyter Notebook)
# Seaborn version: 0.11.2 (My Local Jupyter Notebook)
# Python version: 3.7.4 (My Local Jupyter Notebook)
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from datetime import datetime
from scipy.optimize import curve_fit
import sympy as sym
%matplotlib inline
sns.set(style="white",context="talk")
def func_exp(x, a, b, c):
    """Exponential model: ``a * exp(b * x) + c``.

    Input:
        x: scalar or array-like. Points at which to evaluate the model.
        a: number. Scale applied to the exponential term.
        b: number. Rate multiplying ``x`` inside the exponential.
        c: number. Constant offset.
    Output:
        The model value(s), computed with ``np.exp``.
    """
    exponent = b * x
    scaled_growth = a * np.exp(exponent)
    return scaled_growth + c
def func_log(x, a, b, c):
    """Logarithmic model: ``a * log(b * x) + c`` (natural log).

    Input:
        x: scalar or array-like. Points at which to evaluate the model.
        a: number. Scale applied to the logarithm.
        b: number. Factor multiplying ``x`` inside the logarithm.
        c: number. Constant offset.
    Output:
        The model value(s), computed with ``np.log``.

    Note: for ``b * x > 0`` this equals ``a * log(x) + (a * log(b) + c)``,
    so ``b`` and ``c`` are not independently identifiable when fitting.
    """
    log_argument = b * x
    scaled_log = a * np.log(log_argument)
    return scaled_log + c
def func_poly2(x, a, b, c):
    """Second-degree polynomial model: ``a*x + b*x**2 + c``.

    Input:
        x: scalar or array-like. Points at which to evaluate the model.
        a: number. Coefficient of the linear term.
        b: number. Coefficient of the quadratic term.
        c: number. Constant term.
    Output:
        The model value(s).
    """
    linear_term = a * x
    quadratic_term = b * x**2
    # Summation order matches the single-expression form.
    return linear_term + quadratic_term + c
def func_poly3(x, a, b, c, d):
    """Third-degree polynomial model: ``a*x + b*x**2 + c*x**3 + d``.

    Input:
        x: scalar or array-like. Points at which to evaluate the model.
        a: number. Coefficient of the linear term.
        b: number. Coefficient of the quadratic term.
        c: number. Coefficient of the cubic term.
        d: number. Constant term.
    Output:
        The model value(s).
    """
    linear_term = a * x
    quadratic_term = b * x**2
    cubic_term = c * x**3
    # Summation order matches the single-expression form.
    return linear_term + quadratic_term + cubic_term + d
def scatter_plot_with_fitted_line(df, col_x, col_y, fitted_func = func_poly3, title='TITLE', xlabel = 'XLABEL', ylabel = 'YLABEL', figsize=(14, 6)):
    """
    Draw a scatter plot of two variables together with a fitted trendline.

    The trendline parameters are estimated with scipy's ``curve_fit`` and the
    fitted formula is shown in the legend.

    Input:
        df: pandas dataframe. The dataset.
        col_x: string. The variable name from df used for the x-axis.
        col_y: string. The variable name from df used for the y-axis.
        fitted_func: function. The model to fit, called as
            ``fitted_func(x, *params)``.
        title: string. The title of the figure.
        xlabel: string. The title of the x-axis.
        ylabel: string. The title of the y-axis.
        figsize: tuple. The width and height of the figure.
    Output:
        fig: figure object.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    # Scatter plot
    sns.scatterplot(
        x=col_x,
        y=col_y,
        data=df,
        ax=ax,
        size=10,
        legend=None,
        linewidth=0,  # get rid of the outline around each marker
        alpha=0.6,  # transparency, so density shows as darker color
        color='seagreen'
    )

    # Fit on complete rows only: curve_fit rejects NaN input.
    x_obs = df[col_x]
    y_obs = df[col_y]
    complete = x_obs.notna() & y_obs.notna()
    x_obs = x_obs[complete]
    y_obs = y_obs[complete]

    # Get the fitted parameters
    popt, _ = curve_fit(fitted_func, x_obs, y_obs)

    # Use sympy to generate the LaTeX syntax of the fitted function. This only
    # works when fitted_func can be evaluated on a sympy symbol; models built
    # on numpy ufuncs (np.exp, np.log) cannot, so fall back to a plain-text
    # label instead of failing after the fit has already succeeded.
    rounded_params = [round(i, 4) for i in popt]
    try:
        xs = sym.Symbol('x')
        tex = sym.latex(fitted_func(xs, *rounded_params)).replace('$', '')
        trend_label = r'Fitted Trendline: $f(x)= %s$' % (tex)
    except (TypeError, AttributeError):
        func_name = getattr(fitted_func, '__name__', 'f')
        param_text = ', '.join(str(p) for p in rounded_params)
        trend_label = 'Fitted Trendline: %s(x; %s)' % (func_name, param_text)

    # Fitted trendline
    sns.lineplot(
        x=x_obs,
        y=fitted_func(x_obs, *popt),
        ax=ax,
        label=trend_label,
        color='black',
        alpha=0.5
    )
    # The trendline is the first Line2D on the axes (the scatter markers are
    # not in ax.lines), so this makes the trendline dashed.
    ax.lines[0].set_linestyle("--")

    ax.set_title(title, fontsize=18, pad=10)
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.grid(linestyle="--", alpha=0.4)
    plt.show()
    return fig
# ====================================================================================
# Illustration
# Load seaborn's example 'tips' dataset and derive the tip as a proportion of
# the total bill (a unitless ratio, not a dollar amount).
tips = sns.load_dataset('tips')
tips['tip_perc'] = tips.tip / tips.total_bill

fig = scatter_plot_with_fitted_line(
    df=tips,
    col_x='total_bill',
    col_y='tip_perc',
    fitted_func=func_poly3,
    title="Exploring Tip Proportion and Total Bill: More Expensive Meals Correlate Smaller Tipping?",
    xlabel="Total Bill Amount ($)",
    # The y variable is tip / total_bill, so label it as a proportion.
    ylabel="Tip Proportion of Total Bill",
    figsize=(14, 6)
)

fig.savefig(
    fname="../scatter_plot_with_fitted_line_tips.png",  # path & filename for the output
    dpi=300,  # high-resolution output (the keyword is `dpi`)
    bbox_inches='tight'  # keeps the x-axis label and ticks from being cut off
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment