Skip to content

Instantly share code, notes, and snippets.

@amankharwal
Created January 8, 2021 07:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amankharwal/d73b6d6eb6aa760180b21222b2cf982f to your computer and use it in GitHub Desktop.
Save amankharwal/d73b6d6eb6aa760180b21222b2cf982f to your computer and use it in GitHub Desktop.
def get_month(x):
    """Return the first day of the month that *x* falls in.

    *x* only needs ``year`` and ``month`` attributes; the result is
    ``dt.datetime(x.year, x.month, 1)``, so any day/time part of *x* is dropped.
    """
    return dt.datetime(year=x.year, month=x.month, day=1)
# Stamp every invoice row with the first day of its invoice month.
df['InvoiceMonth'] = df['InvoiceDate'].map(get_month)
# A customer's cohort month is the earliest invoice month seen for that
# CustomerID; transform('min') broadcasts it back onto each of their rows.
grouping = df.groupby('CustomerID')['InvoiceMonth']
df['CohortMonth'] = grouping.transform('min')
def get_month_int(dframe, column):
    """Split a datetime column into its integer year, month and day parts.

    Args:
        dframe: DataFrame holding the column.
        column: Label of a column that supports the pandas ``.dt`` accessor
            (i.e. datetime-like values).

    Returns:
        A ``(year, month, day)`` tuple of Series, each indexed like
        ``dframe[column]``.
    """
    # Look the column and its .dt accessor up once instead of three times.
    parts = dframe[column].dt
    return parts.year, parts.month, parts.day
# Break both month columns into integer parts; the day component is unused.
invoice_year, invoice_month, _ = get_month_int(df, 'InvoiceMonth')
cohort_year, cohort_month, _ = get_month_int(df, 'CohortMonth')

# Distance between the invoice month and the cohort month, year and month
# parts computed separately.
year_diff = invoice_year - cohort_year
month_diff = invoice_month - cohort_month

# Whole months since the cohort month, offset by one so that the cohort
# month itself gets index 1.
df['CohortIndex'] = 12 * year_diff + month_diff + 1
# Count monthly active customers from each cohort: the number of distinct
# CustomerID values in every (CohortMonth, CohortIndex) group.
grouping = df.groupby(['CohortMonth', 'CohortIndex'])
# The built-in group-wise nunique() yields the same counts as applying
# pd.Series.nunique to each group, without a Python-level call per group.
cohort_data = grouping['CustomerID'].nunique()
# Move the group keys out of the index into ordinary columns so that the
# result can be pivoted.
cohort_data = cohort_data.reset_index()
# One row per cohort month, one column per cohort index, customer counts as values.
cohort_counts = cohort_data.pivot(
    index='CohortMonth',
    columns='CohortIndex',
    values='CustomerID',
)
# Retention table
# Each cohort's size is taken from the first column of cohort_counts
# (presumably CohortIndex 1, the cohort's own first month — confirm the
# column order if the pivot changes).
cohort_size = cohort_counts.iloc[:, 0]
# axis=0 aligns cohort_size with the row index, so every row is divided by
# its own cohort size and the cells become fractions of that cohort.
retention = cohort_counts.divide(cohort_size, axis=0)
# Percentage view of the same table. The original computed this expression
# and discarded the result; it is now kept under a name so it can be
# inspected or exported. `retention` itself stays as fractions.
retention_pct = retention.round(3) * 100
# Draw the retention table as an annotated heatmap.
plt.figure(figsize=(15, 8))
plt.title('Retention rates')
sns.heatmap(
    data=retention,
    annot=True,   # write the value inside every cell
    fmt='.0%',    # retention holds fractions; render them as whole percentages
    vmin=0.0,
    vmax=0.5,     # colour scale spans 0.0 to 0.5
    cmap="BuPu_r",
)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment