Last active
October 24, 2019 12:56
-
-
Save asdfgeoff/741de7942ff81a8986023bffb963d6df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime as dt
from typing import List, Tuple

import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import numpy as np
import pandas as pd
from scipy.stats import norm, uniform
# Default every figure created after import to a white background.
plt.rcParams['figure.facecolor'] = 'white'
def get_colors(interval_start: float, interval_end: float) -> Tuple[str, str]:
    """Pick (fill, edge) colour names from where an interval lies relative to zero.

    Args:
        interval_start: Lower bound of the interval.
        interval_end: Upper bound of the interval.

    Returns:
        Green shades when the whole interval is above zero, red shades when
        the whole interval is below zero, and gray shades otherwise (the
        interval touches or straddles zero).
    """
    entirely_positive = interval_start > 0
    if entirely_positive:
        return 'darkseagreen', 'darkgreen'

    entirely_negative = interval_end < 0
    if entirely_negative:
        return 'darksalmon', 'darkred'

    return 'lightgray', 'gray'
def plot_single_group(ax, sub_df: pd.DataFrame) -> Tuple[float, float]:
    """Draw one horizontal interval bar per row of ``sub_df`` on a single axis.

    Args:
        ax: Matplotlib axis object to draw on.
        sub_df: Rows to plot; each row is read for ``uplift``, ``std_err``
            and ``alpha``. The index value is used as the row's tick label
            (for tuple index values, the second element is used).

    Returns:
        ``(x_min, x_max)``: the smallest interval start minus 0.01 and the
        largest interval end plus 0.01 across all rows, each clamped so the
        range always includes zero.
    """
    standard_normal = norm(0, 1)
    labels = []
    lowest, highest = 0, 0

    # Walk the rows last-to-first: bar position 0 is drawn at the bottom of the axis.
    for position, (label, record) in enumerate(sub_df.iloc[::-1].iterrows()):
        if isinstance(label, tuple):
            label = label[1]
        labels.append(label)

        # Two-sided critical value for this row's own alpha.
        critical_value = standard_normal.ppf(1 - record.alpha / 2)
        half_width = critical_value * record.std_err
        lower = record.uplift - half_width
        upper = record.uplift + half_width

        # Colour encodes whether the interval excludes zero.
        face, edge = get_colors(lower, upper)

        # Two adjacent half-width bars, so an edge is drawn where they meet in the middle.
        ax.barh(position,
                [half_width, half_width],
                left=[lower, lower + half_width],
                height=0.8,
                color=face,
                edgecolor=edge,
                linewidth=0.8,
                zorder=3)

        lowest = min(lowest, lower - 0.01)
        highest = max(highest, upper + 0.01)

    # Vertical grid lines and a zero reference line behind the bars.
    ax.xaxis.grid(True, alpha=0.4)
    ax.xaxis.set_ticks_position('none')
    ax.axvline(0.00, color='black', linewidth=1.1, zorder=2)

    # Row labels go on the right-hand side, one per bar.
    ax.yaxis.tick_right()
    ax.set_yticks(np.arange(len(sub_df)))
    ax.set_yticklabels(labels)

    # Pad the vertical limits by 0.4 on each side.
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom - 0.4, top + 0.4)
    ax.yaxis.set_ticks_position('none')

    return lowest, highest
def plot_experiment_results(df: pd.DataFrame, title: str = None, sample_size: int = None, combine_axes: bool = False) -> None:
    """ Plot a (possibly MultiIndex) DataFrame on one or more matplotlib axes.

    One axis is created per unique value of the first index level when the
    index has two levels; otherwise a single axis is used. All axes share the
    same x-axis.

    Args:
        df (pd.DataFrame): DataFrame with an index of one or two levels representing
            dimensions or KPIs, and following cols: uplift, std_err, alpha
        title (str): Title displayed above plot. Omitted when None.
        sample_size (int): Used to add contextual information to bottom corner of plot.
            Omitted when None or 0.
        combine_axes (bool): If true and more than one axis is drawn, collapse
            axes together into one visible axis.

    Raises:
        ValueError: If the index of ``df`` has more than two levels.

    Note:
        Sets ``plt.rcParams['figure.facecolor']`` globally and leaves the new
        figure as pyplot's current figure; nothing is returned.
    """
    plt.rcParams['figure.facecolor'] = 'white'

    n_levels = df.index.nlevels
    if n_levels > 2:
        raise ValueError(f'df index may have at most 2 levels, got {n_levels}')

    if n_levels == 2:
        # sort=False keeps the groups in the order they appear in df.
        groups = list(df.groupby(level=0, sort=False))
    else:
        groups = [(None, df)]

    # Make an axis for each group. squeeze=False guarantees an array even for
    # a single row, so the indexing below works for any number of groups.
    fig, axes_grid = plt.subplots(nrows=len(groups),
                                  ncols=1,
                                  sharex=True,
                                  squeeze=False,
                                  figsize=(6, 0.5 * df.shape[0] + 0.2), dpi=100)
    axes = axes_grid[:, 0]

    x_mins, x_maxs = [], []
    for axis, (group, results) in zip(axes, groups):
        group_min, group_max = plot_single_group(axis, results)
        x_mins.append(group_min)
        x_maxs.append(group_max)
        if n_levels == 2:
            axis.set_ylabel(group)

    x_min = min(x_mins)
    x_max = max(x_maxs)

    # The bottom axis carries the shared x-axis formatting and footer labels.
    ax = axes[-1]

    if combine_axes and len(axes) > 1:
        fig.subplots_adjust(hspace=0)
        # Hide every spine that sits between two stacked axes.
        for axis in axes[:-1]:
            axis.spines['bottom'].set_visible(False)
        for axis in axes[1:]:
            axis.spines['top'].set_visible(False)

    ax.set_xlim(x_min, x_max)

    # Tick spacing starts at 0.01 and grows by 0.01 for every full 0.10 of x-range.
    x_tick_width = (1 + np.floor((x_max - x_min) / 0.10)) / 100
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=x_tick_width))
    # A formatter (instead of fixed label strings) keeps labels in step with
    # whatever ticks the locator produces.
    ax.xaxis.set_major_formatter(
        plticker.FuncFormatter(lambda value, _pos: '{:.0%}'.format(value)))
    ax.set_xlabel('Uplift (relative)')

    # Add title, sample size, and timestamp labels to plot
    if title is not None:
        fig.text(0.5, 0.95 - 0.025 * n_levels, title, size='x-large', horizontalalignment='center')

    # Footer labels are positioned in the bottom axis' coordinates, below the axis.
    vertical_offset = - (0.1 + 0.2 * n_levels)
    timestamp_text = dt.datetime.now().strftime('Analyzed: %Y-%m-%d')
    fig.text(1, vertical_offset,
             timestamp_text,
             size='small', color='grey',
             ha='right', wrap=True, transform=ax.transAxes)

    if sample_size:
        sample_size_text = f'Sample size: {int(sample_size/1000)}K'
        fig.text(0, vertical_offset,
                 sample_size_text,
                 size='small', color='grey',
                 ha='left', wrap=True, transform=ax.transAxes)
if __name__ == '__main__':
    # Demo: plot randomly generated results for 4 metrics x 3 groups.
    np.random.seed(24)

    # Defined once so the intervals drawn and the title text cannot disagree.
    alpha = 0.05
    metric_order = ['Received rate', 'Open rate', 'Click rate', 'Purchase rate']

    example_data = (
        pd.DataFrame({
            'uplift': (uniform().rvs(12) - 0.50) / 30,
            'std_err': uniform(0, 0.01).rvs(12),
            'dimension': ['Group A'] * 4 + ['Group B'] * 4 + ['Group C'] * 4,
            'metric': metric_order * 3,
            'alpha': alpha,
        })
        .set_index(['metric', 'dimension'])
        .sort_index()
        # sort_index orders metrics alphabetically; restore the intended order.
        .reindex(metric_order, level=0)
    )

    plot_experiment_results(df=example_data,
                            title=f'Example email campaign (α={alpha:.2f})',
                            sample_size=123456,
                            combine_axes=False)

    # Display the figure when run as a script.
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment