Skip to content

Instantly share code, notes, and snippets.

@FBosler
Last active August 18, 2019 10:48
Show Gist options
  • Save FBosler/ba7a2c994640bd2f01959f294c55ff52 to your computer and use it in GitHub Desktop.
Save FBosler/ba7a2c994640bd2f01959f294c55ff52 to your computer and use it in GitHub Desktop.
Function that generates a cohort analysis, returned either as a table (e.g. for export to Excel) or as a figure.
def generate_cohort_analysis(df, metric, record_type='all', period_agg='quarterly', fig=True, size=10, save_fig=True):
    """Build a cohort analysis and either return it as a table or plot it.

    Customers are assigned to a cohort by the period of their first order and
    each order is assigned to its order period; ``_generate_cohorts`` (defined
    elsewhere) then aggregates ``metric`` over those periods.

    Parameters
    ----------
    df : pandas.DataFrame
        Order-level data. Columns read here: ``customer``,
        ``customer_first_order``, ``order_date`` and, when
        ``record_type != 'all'``, ``customer_type``. The two date columns must
        hold datetime-like values (monthly mode calls ``.strftime`` on each).
    metric : str
        One of 'number_of_orders', 'number_of_items_bought' or
        'total_order_value'; passed through to the cohort helpers.
    record_type : str
        'all' to keep every row, or a single customer type
        (e.g. 'private', 'company', 'government') to keep only the rows whose
        ``customer_type`` equals it.
    period_agg : str
        Period granularity: 'quarterly' (labels produced by the module-level
        ``fortmat_quarter`` helper) or 'monthly' (labels like '2019-08').
    fig : bool
        If False, nothing is plotted and the cohort table is returned.
        If True (default), the figure is drawn and None is returned.
    size : int
        Font size of the heatmap cell annotations.
    save_fig : bool
        If True (default), also save the figure to
        ``<metric>RetentionMatrix<record_type>.png`` at 600 dpi.
        Has no effect when ``fig`` is False.

    Returns
    -------
    pandas.DataFrame or None
        When ``fig`` is False: ``cohorts.T`` joined (on its index, presumably
        the cohort label) with a 'New Accounts' column holding the number of
        distinct customers per cohort, with NaNs filled with 0.
        Otherwise None.

    Raises
    ------
    NotImplementedError
        If ``period_agg`` is neither 'quarterly' nor 'monthly'.
    """
    # Validate before doing any copying/formatting work. The lambdas defer the
    # lookup of ``fortmat_quarter`` until it is actually used, so monthly mode
    # does not depend on that helper.
    period_formatters = {
        'quarterly': lambda date: fortmat_quarter(date),
        'monthly': lambda date: date.strftime('%Y-%m'),
    }
    if period_agg not in period_formatters:
        raise NotImplementedError(f'period_agg: {period_agg} is not implemented')
    to_period = period_formatters[period_agg]

    # Work on a copy (made once) so the columns added below never touch ``df``.
    if record_type == 'all':
        dataset = df.copy()
    else:
        dataset = df[df.customer_type == record_type].copy()

    # Map customers into their cohort and orders into their order period.
    dataset['cohort'] = dataset['customer_first_order'].apply(to_period)
    dataset['order_period'] = dataset['order_date'].apply(to_period)

    cohorts = _generate_cohorts(dataset, metric)

    # Cohort size: distinct customers per cohort, as columns ['cohort', 'New Accounts'].
    new_accs = (
        dataset.groupby('cohort')['customer']
        .nunique()
        .reset_index(name='New Accounts')
    )

    # Table-only mode: return the data without plotting anything.
    if not fig:
        return cohorts.T.join(new_accs.set_index('cohort')).fillna(0)

    # Only the plot needs the repeat percentages, so compute them after the early return.
    repeat_perc, selection = _generate_repeat_percentages(dataset, metric)

    from matplotlib.ticker import FuncFormatter

    # Three panels side by side: cohort sizes | retention heatmap | repeat percentages.
    # Named ``figure`` so it does not shadow the boolean ``fig`` parameter.
    figure, (ax1, ax2, ax3) = plt.subplots(
        1, 3, figsize=(16, 7), gridspec_kw={'width_ratios': (1, 14, 1)}
    )
    sns.despine(left=True, bottom=True, right=True)

    # Left panel: number of new accounts per cohort.
    sns.barplot(x="New Accounts", y='cohort', data=new_accs, palette="Blues", ax=ax1)

    # Middle panel: the retention matrix. Its row labels are omitted because the
    # left panel already shows the cohort names.
    heatmap = sns.heatmap(
        cohorts.T,
        cmap="Blues",
        annot=True,
        fmt=".0f",
        annot_kws={"size": size},
        cbar=False,
        yticklabels=False,
        ax=ax2,
    )
    heatmap.set_title('Retention Matrix for "{}" - for Account Type "{}"'.format(metric, record_type))
    heatmap.yaxis.get_label().set_visible(False)
    heatmap.set_xlabel('order_period')

    # Right panel: repeat percentages per cohort, with no y label, tick labels or ticks.
    repeats = sns.barplot(x=selection, y='cohort', data=repeat_perc, palette="Blues", ax=ax3)
    repeats.yaxis.get_label().set_visible(False)
    repeats.set(yticklabels=[])
    repeats.set(yticks=[])
    # Render fractional ticks as percentages via a formatter. Overwriting the
    # labels of auto-located ticks (set_xticklabels) triggers matplotlib's
    # "FixedFormatter should only be used together with FixedLocator" warning
    # and mislabels ticks if the axis limits change afterwards.
    repeats.xaxis.set_major_formatter(
        FuncFormatter(lambda value, _pos: '{:,.0f}%'.format(value * 100))
    )

    plt.tight_layout()

    if save_fig:
        figure.savefig(metric + 'RetentionMatrix' + record_type + '.png', bbox_inches='tight', dpi=600)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment