Skip to content

Instantly share code, notes, and snippets.

@david-a-parry
Last active August 17, 2020 15:54
Show Gist options
  • Save david-a-parry/350fdcea39d248a808b696b151beced0 to your computer and use it in GitHub Desktop.
Save david-a-parry/350fdcea39d248a808b696b151beced0 to your computer and use it in GitHub Desktop.
import seaborn as sns
from pandas.api.types import is_float_dtype, is_integer_dtype
import matplotlib.pyplot as plt
%matplotlib inline
def add_quantile_lines(df, ax, x, y, median_width=0.4, quantile_width=0.25,
                       mean_and_sd=False, mean_and_sem=False, color='k',
                       line_width=3, alpha=0.8, zorder=9):
    """Overlay summary-statistic lines on each x category of a plot.

    For every x tick of ``ax``, the rows of ``df`` whose ``x`` column equals
    the tick label are selected and four lines are drawn at that tick: a
    horizontal centre line, a horizontal lower line, a horizontal upper line
    and a vertical line joining lower and upper.

    By default the centre is the median and the lower/upper lines are the
    0.25 and 0.75 quantiles of ``y``. With ``mean_and_sd`` or
    ``mean_and_sem`` the centre is the mean and the lower/upper lines are
    the mean minus/plus ``Series.std()`` or ``Series.sem()`` respectively.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the ``x`` and ``y`` columns.
    ax : object
        Axes to draw on; must provide ``get_xticks()``,
        ``get_xticklabels()`` and ``plot()``.
    x, y : str
        Column names of the grouping variable and the plotted value.
    median_width : float
        Total width (in x data units) of the centre line.
    quantile_width : float
        Total width (in x data units) of the lower and upper lines.
    mean_and_sd, mean_and_sem : bool
        Select mean +/- SD or mean +/- SEM instead of median and quartiles.
        Mutually exclusive.
    color
        Colour passed to every ``ax.plot`` call.
    line_width, alpha, zorder
        Passed to every ``ax.plot`` call as ``lw``, ``alpha`` and ``zorder``.

    Returns
    -------
    The ``ax`` object that was passed in.

    Raises
    ------
    ValueError
        If both ``mean_and_sd`` and ``mean_and_sem`` are true.
    """
    if mean_and_sd and mean_and_sem:
        raise ValueError("mean_and_sd and mean_and_sem arguments are " +
                         "mutually exclusive")

    # Tick labels are text, so numeric x columns need the label converted
    # back to a number before it can be compared with the column values.
    # The column dtype does not change per tick, so decide once.
    if is_float_dtype(df[x]):
        convert = float
    elif is_integer_dtype(df[x]):
        convert = int
    else:
        convert = None

    # Honour the caller's colour (previously every line was forced to 'k').
    line_style = dict(zorder=zorder, alpha=alpha, lw=line_width, color=color)

    for tick, text in zip(ax.get_xticks(), ax.get_xticklabels()):
        x_val = text.get_text()
        if convert is not None:
            x_val = convert(x_val)

        # Select this category's values once and reuse them for every stat.
        values = df.loc[df[x] == x_val, y]
        if mean_and_sd or mean_and_sem:
            mid_val = values.mean()
            err = values.std() if mean_and_sd else values.sem()
            first_val = mid_val - err
            third_val = mid_val + err
        else:
            mid_val = values.median()
            first_val = values.quantile(q=0.25)
            third_val = values.quantile(q=0.75)

        # Horizontal lines across the column, centred on the tick.
        ax.plot([tick - median_width / 2, tick + median_width / 2],
                [mid_val, mid_val], **line_style)
        ax.plot([tick - quantile_width / 2, tick + quantile_width / 2],
                [first_val, first_val], **line_style)
        ax.plot([tick - quantile_width / 2, tick + quantile_width / 2],
                [third_val, third_val], **line_style)
        # Vertical line joining the lower and upper values.
        ax.plot([tick, tick], [first_val, third_val], **line_style)
    return ax
def superplot(df, x, y, replicate, figsize=(6, 4),
              palette=None, mean_palette=None, size=5, mean_size=12,
              alpha=1.0, mean_alpha=1.0, show_legend=False):
    """Draw a two-layer swarm plot: raw points plus per-replicate means.

    A new figure of size ``figsize`` is created. The first layer is a swarm
    plot of every row of ``df`` (``y`` against ``x``), coloured by
    ``replicate``. The second layer, drawn on the same axes with a black
    edge and ``zorder=10``, is a swarm plot of the mean of ``y`` for each
    (``x``, ``replicate``) combination.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the ``x``, ``y`` and ``replicate`` columns.
    x, y, replicate : str
        Column names for the category axis, the value axis and the hue.
    figsize : tuple
        Size passed to ``plt.figure``.
    palette, size, alpha
        Palette, marker size and opacity of the raw-point layer.
    mean_palette, mean_size, mean_alpha
        Palette, marker size and opacity of the mean layer.
    show_legend : bool
        If true, show a legend titled ``replicate``; otherwise remove it.

    Returns
    -------
    The axes returned by the first ``sns.swarmplot`` call.
    """
    replicate_means = (
        df.groupby([x, replicate], as_index=False).agg({y: "mean"})
    )

    plt.figure(figsize=figsize)

    # Arguments common to both swarm layers.
    shared = dict(x=x, y=y, hue=replicate, dodge=False)

    # Layer 1: every observation, no marker outline.
    ax = sns.swarmplot(data=df, size=size, alpha=alpha, linewidth=0,
                       palette=palette, **shared)
    # Layer 2: one outlined marker per (x, replicate) mean, drawn on top.
    sns.swarmplot(data=replicate_means, ax=ax, size=mean_size,
                  alpha=mean_alpha, edgecolor='k', linewidth=2,
                  palette=mean_palette, zorder=10, **shared)

    if not show_legend:
        ax.legend().remove()
        return ax

    # Only the first n entries (n = number of distinct replicate values) are
    # kept; presumably the second layer contributes a duplicate set of hue
    # entries -- verify against the seaborn version in use.
    handles, labels = ax.get_legend_handles_labels()
    n_replicates = df[replicate].nunique(dropna=False)
    ax.legend(handles[:n_replicates], labels[:n_replicates], title=replicate)
    return ax
# Demo setup: white-grid seaborn theme and seaborn's "tips" example dataset.
sns.set_style("whitegrid")
tips = sns.load_dataset("tips")

# Example 1: ordinary swarm plot with median / quartile lines on top.
ax = sns.swarmplot(data=tips, x="time", y="total_bill")
add_quantile_lines(df=tips, ax=ax, x="time", y="total_bill", alpha=0.65)

# Example 2: superplot on its own, with the replicate legend shown.
ax = superplot(df=tips, x="smoker", y="total_bill", replicate="day",
               show_legend=True)

# Example 3: superplot with median / quartile lines.
ax = superplot(df=tips, x="smoker", y="total_bill", replicate="day")
add_quantile_lines(df=tips, ax=ax, x="smoker", y="total_bill")

# Example 4: superplot with mean +/- SD lines.
ax = superplot(df=tips, x="smoker", y="total_bill", replicate="day")
add_quantile_lines(df=tips, ax=ax, x="smoker", y="total_bill",
                   mean_and_sd=True)

# Example 5: superplot with mean +/- SEM lines.
ax = superplot(df=tips, x="smoker", y="total_bill", replicate="day")
add_quantile_lines(df=tips, ax=ax, x="smoker", y="total_bill",
                   mean_and_sem=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment