Last active
August 18, 2019 10:48
-
-
Save FBosler/ba7a2c994640bd2f01959f294c55ff52 to your computer and use it in GitHub Desktop.
Function that generates a cohort analysis, returned either as data (e.g. for export to Excel) or as a figure.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generate_cohort_analysis(df, metric, record_type='all', period_agg='quarterly',
                             fig=True, size=10, save_fig=True):
    """Build a cohort (retention) analysis and either return it as data or plot it.

    Every row gets a ``cohort`` derived from ``customer_first_order`` and an
    ``order_period`` derived from ``order_date``. Both are bucketed according
    to ``period_agg``.

    Parameters
    ----------
    df : pandas.DataFrame
        Order-level data. Columns read directly here: ``customer_first_order``,
        ``order_date``, ``customer``, and ``customer_type`` (only when
        ``record_type`` is not ``'all'``). The helpers ``_generate_cohorts`` and
        ``_generate_repeat_percentages`` may read further columns.
    metric : str
        Value to analyse, e.g. ``'number_of_orders'``,
        ``'number_of_items_bought'`` or ``'total_order_value'``.
    record_type : str
        ``'all'`` to use every row, or a specific ``customer_type`` such as
        ``'private'``, ``'company'`` or ``'government'``.
    period_agg : str
        ``'quarterly'`` or ``'monthly'``.
    fig : bool
        If True (the default), draw the figure and return None.
        If False, draw nothing and return the data instead.
    size : int
        Font size of the heatmap cell annotations.
    save_fig : bool
        When plotting, also save the figure to
        ``<metric>RetentionMatrix<record_type>.png`` at 600 dpi.

    Returns
    -------
    pandas.DataFrame or None
        With ``fig=False``: the transposed output of ``_generate_cohorts``,
        joined on cohort with a ``'New Accounts'`` column (distinct customers
        per cohort), with missing values filled with 0.
        With ``fig=True``: None.

    Raises
    ------
    NotImplementedError
        If ``period_agg`` is neither ``'quarterly'`` nor ``'monthly'``.
    """
    # Choose the period formatter first so an unsupported option fails
    # before any copying or aggregation work is done.
    if period_agg == 'quarterly':
        to_period = fortmat_quarter
    elif period_agg == 'monthly':
        def to_period(x):
            return x.strftime('%Y-%m')
    else:
        raise NotImplementedError(f'period_agg: {period_agg} is not implemented')

    # Copy exactly once: either the whole frame or the requested subset,
    # so the columns added below never touch the caller's DataFrame.
    if record_type == 'all':
        dataset = df.copy()
    else:
        dataset = df[df.customer_type == record_type].copy()

    # Map customers into their cohort and orders into their order period.
    dataset['cohort'] = dataset['customer_first_order'].apply(to_period)
    dataset['order_period'] = dataset['order_date'].apply(to_period)

    cohorts = _generate_cohorts(dataset, metric)

    # Number of distinct customers per cohort.
    new_accs = (dataset.groupby('cohort')
                .agg({'customer': pd.Series.nunique})
                .reset_index())
    new_accs.columns = ['cohort', 'New Accounts']

    if not fig:
        # Data-only mode: nothing is plotted.
        return cohorts.T.join(new_accs.set_index('cohort')).fillna(0)

    # The repeat data is only used by the plot's right-hand panel.
    repeat_perc, selection = _generate_repeat_percentages(dataset, metric)

    # #### Plot the data ####
    # Local import: the tick formatter is only needed on the plotting path.
    from matplotlib.ticker import FuncFormatter

    # Named `figure` rather than `fig` so the `fig` parameter is not shadowed.
    figure, (ax1, ax2, ax3) = plt.subplots(
        1, 3, figsize=(16, 7), gridspec_kw={'width_ratios': (1, 14, 1)})
    sns.despine(left=True, bottom=True, right=True)

    # Left panel: new accounts per cohort.
    sns.barplot(x="New Accounts", y='cohort', data=new_accs, palette="Blues", ax=ax1)

    # Middle panel: retention matrix. Its y tick labels are disabled; the
    # cohort labels are drawn by the left panel instead.
    heatmap = sns.heatmap(cohorts.T,
                          cmap="Blues",
                          annot=True,
                          fmt=".0f",
                          annot_kws={"size": size},
                          cbar=False,
                          yticklabels=False,
                          ax=ax2)
    title = 'Retention Matrix for "{}" - for Account Type "{}"'.format(metric, record_type)
    heatmap.set_title(title)
    heatmap.yaxis.get_label().set_visible(False)
    heatmap.set_xlabel('order_period')

    # Right panel: repeat percentages per cohort, with the y axis fully hidden.
    repeats = sns.barplot(x=selection, y='cohort', data=repeat_perc, palette="Blues", ax=ax3)
    repeats.yaxis.get_label().set_visible(False)
    # Removing the ticks also removes their labels.
    repeats.set_yticks([])
    # Render x ticks as percentages with a formatter instead of fixed labels,
    # so the labels always match the ticks matplotlib actually places.
    repeats.xaxis.set_major_formatter(
        FuncFormatter(lambda x, _pos: '{:,.0f}%'.format(x * 100)))

    plt.tight_layout()

    if save_fig:
        figure.savefig(metric + 'RetentionMatrix' + record_type + '.png',
                       bbox_inches='tight', dpi=600)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment