# Source: GitHub gist asdfgeoff/741de7942ff81a8986023bffb963d6df
# Author: @asdfgeoff (gist last active October 24, 2019 12:56)
import datetime as dt
from typing import List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
from scipy.stats import norm
# Set matplotlib's global figure background colour to white.
plt.rcParams['figure.facecolor'] = 'white'
def get_colors(interval_start: float, interval_end: float) -> Tuple[str, str]:
    """Determine chart colors based on overlap of interval with zero.

    Args:
        interval_start: Lower bound of the interval.
        interval_end: Upper bound of the interval.

    Returns:
        A ``(fill_color, edge_color)`` pair of matplotlib colour names:
        greens when the whole interval lies strictly above zero, reds when
        it lies strictly below zero, and greys otherwise (the interval
        touches or straddles zero).
    """
    # Strictly positive interval.
    if interval_start > 0:
        return 'darkseagreen', 'darkgreen'
    # Strictly negative interval.
    if interval_end < 0:
        return 'darksalmon', 'darkred'
    # Zero is inside (or on the boundary of) the interval.
    return 'lightgray', 'gray'
def plot_single_group(ax, sub_df: pd.DataFrame) -> Tuple[float, float]:
    """Plot each row of a DataFrame as an interval bar on the same mpl axis.

    Args:
        ax: Axis object to draw on (the code calls ``barh``, ``axvline``,
            ``set_yticks``/``set_yticklabels``, ``get_ylim``/``set_ylim`` and
            uses its ``xaxis``/``yaxis`` attributes).
        sub_df: One row per bar, with columns ``uplift``, ``std_err`` and
            ``alpha``. If the index labels are tuples, the second element is
            used as the tick label.

    Returns:
        ``(x_min, x_max)``: the smallest interval start minus 0.01 and the
        largest interval end plus 0.01 over all rows, each clamped so that
        the returned range always includes zero.
    """
    ytick_labels = []
    x_min, x_max = 0, 0

    # Iterate over each row in group, reversing order since mpl plots from bottom up
    for position, (dim, row) in enumerate(sub_df.iloc[::-1].iterrows()):
        if isinstance(dim, tuple):
            dim = dim[1]

        # Two-sided critical value of the standard normal for this row's alpha.
        # `norm.ppf` avoids rebuilding a frozen distribution on every iteration.
        z = norm.ppf(1 - row.alpha / 2)
        half_width = z * row.std_err
        interval_start = row.uplift - half_width
        interval_end = row.uplift + half_width

        # Conditional coloring based on significance of result
        fill_color, edge_color = get_colors(interval_start, interval_end)

        # Two adjacent half-width bars: the shared edge between them falls at
        # `uplift`, which marks the point estimate inside the interval.
        ax.barh(position, [half_width, half_width],
                left=[interval_start, interval_start + half_width],
                height=0.8,
                color=fill_color,
                edgecolor=edge_color,
                linewidth=0.8,
                zorder=3)

        ytick_labels.append(dim)
        x_min = min(x_min, interval_start - 0.01)
        x_max = max(x_max, interval_end + 0.01)

    # Axis-specific formatting
    ax.xaxis.grid(True, alpha=0.4)
    ax.xaxis.set_ticks_position('none')
    ax.axvline(0.00, color='black', linewidth=1.1, zorder=2)
    ax.yaxis.tick_right()
    ax.set_yticks(np.arange(len(sub_df)))
    ax.set_yticklabels(ytick_labels)

    # Widen the current y-limits by 0.4 on each side.
    y_min, y_max = ax.get_ylim()
    ax.set_ylim(y_min - 0.4, y_max + 0.4)
    ax.yaxis.set_ticks_position('none')

    return x_min, x_max
def plot_experiment_results(df: pd.DataFrame, title: str = None, sample_size: int = None, combine_axes: bool = False) -> None:
    """Plot a (possibly MultiIndex) DataFrame on one or more matplotlib axes.

    Args:
        df (pd.DataFrame): DataFrame with a one- or two-level index
            representing dimensions or KPIs, and following cols: uplift,
            std_err, alpha. With a two-level index, each top-level group is
            drawn on its own axis.
        title (str): Title displayed above plot. Nothing is drawn when None.
        sample_size (int): Used to add contextual information to bottom
            corner of plot. Omitted when falsy.
        combine_axes (bool): If true and the plot has more than one axis,
            collapse axes together into one visible axis.

    Raises:
        ValueError: If the index of ``df`` has more than two levels.
    """
    plt.rcParams['figure.facecolor'] = 'white'

    n_levels = len(df.index.names)
    if n_levels > 2:
        raise ValueError(f'df index may have at most 2 levels, got {n_levels}')

    # One (label, frame) pair per axis. Deriving the subplot count from the
    # same groupby that is iterated keeps the two from ever disagreeing.
    if n_levels == 2:
        groups = list(df.groupby(level=0, sort=False))
    else:
        groups = [(None, df)]

    # Make an axis for each group of MultiIndex DataFrame input.
    # squeeze=False always yields a 2-D array, so indexing works even when
    # there is only a single row of subplots.
    fig, axes = plt.subplots(nrows=len(groups),
                             ncols=1,
                             sharex=True,
                             squeeze=False,
                             figsize=(6, 0.5 * df.shape[0] + 0.2), dpi=100)
    axes = axes[:, 0]

    x_mins, x_maxs = [], []
    for ax, (group, results) in zip(axes, groups):
        group_min, group_max = plot_single_group(ax, results)
        x_mins.append(group_min)
        x_maxs.append(group_max)
        if n_levels == 2:
            ax.set_ylabel(group)
    x_min = min(x_mins)
    x_max = max(x_maxs)

    ax = axes[-1]  # final axis is the target of the downstream formatting calls

    if combine_axes and len(axes) > 1:
        fig.subplots_adjust(hspace=0)
        # Hide the spines on every edge shared by two stacked axes.
        for upper in axes[:-1]:
            upper.spines['bottom'].set_visible(False)
        for lower in axes[1:]:
            lower.spines['top'].set_visible(False)

    ax.set_xlim(x_min, x_max)
    # Tick spacing grows by 0.01 for every 0.10 of x-range covered.
    x_tick_width = (1 + np.floor((x_max - x_min) / 0.10)) / 100
    loc = plticker.MultipleLocator(base=x_tick_width)  # this locator puts ticks at regular intervals
    ax.xaxis.set_major_locator(loc)
    # A formatter (rather than fixed label strings) keeps labels in step with
    # whatever ticks the locator produces.
    ax.xaxis.set_major_formatter(plticker.PercentFormatter(xmax=1, decimals=0))
    ax.set_xlabel('Uplift (relative)')

    # Add title, sample size, and timestamp labels to plot
    if title is not None:
        fig.text(0.5, 0.95 - 0.025 * n_levels, title, size='x-large', horizontalalignment='center')

    # Footer labels are positioned in the final axis' coordinate system,
    # below its lower edge.
    vertical_offset = - (0.1 + 0.2 * n_levels)
    timestamp_text = dt.datetime.now().strftime('Analyzed: %Y-%m-%d')
    fig.text(1, vertical_offset,
             timestamp_text,
             size='small', color='grey',
             ha='right', wrap=True, transform=ax.transAxes)

    if sample_size:
        sample_size_text = f'Sample size: {int(sample_size/1000)}K'
        fig.text(0, vertical_offset,
                 sample_size_text,
                 size='small', color='grey',
                 ha='left', wrap=True, transform=ax.transAxes)
if __name__ == '__main__':
    # `uniform` is needed only for the demo data; the module-level import
    # brings in `norm` alone, so import it here to avoid a NameError.
    from scipy.stats import uniform

    # Seed numpy's global RNG so the example data is reproducible.
    np.random.seed(24)

    # Single source of truth for the significance level, so the plot title
    # cannot drift from the value actually used to size the intervals.
    example_alpha = 0.05

    example_data = (pd.DataFrame({'uplift': (uniform().rvs(12) - 0.50) / 30,
                                  'std_err': uniform(0, 0.01).rvs(12),
                                  'dimension': ['Group A'] * 4 + ['Group B'] * 4 + ['Group C'] * 4,
                                  'metric': ['Received rate', 'Open rate', 'Click rate', 'Purchase rate'] * 3,
                                  'alpha': example_alpha})
                    .set_index(['metric', 'dimension'])
                    .sort_index()
                    .reindex(['Received rate', 'Open rate', 'Click rate', 'Purchase rate'], level=0))

    plot_experiment_results(df=example_data,
                            title=f'Example email campaign (α={example_alpha:.2f})',
                            sample_size=123456,
                            combine_axes=False)

    # Without this, running the file as a script never displays the figure.
    plt.show()
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")