Skip to content

Instantly share code, notes, and snippets.

@sserrano44
Created March 23, 2022 21:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sserrano44/8f56e964c426abcee4b278c0dba72f07 to your computer and use it in GitHub Desktop.
Save sserrano44/8f56e964c426abcee4b278c0dba72f07 to your computer and use it in GitHub Desktop.
Cohort analysis
"""
Cohort (user retention) analysis following http://www.gregreda.com/2015/08/23/cohort-analysis-with-python/

Usage: python <script> ORDERS.csv -- writes ORDERS.png and ORDERS-heatmap.png.
"""
import datetime
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
# Show up to 50 columns when printing DataFrames. The fully qualified key is
# required: on pandas >= 1.4 the bare 'max_columns' also matches
# 'styler.render.max_columns' and raises OptionError.
pd.set_option('display.max_columns', 50)
# Thicker default line width for the retention line chart.
mpl.rcParams['lines.linewidth'] = 2
def cohort_period(df):
    """
    Return a copy of `df` with a `CohortPeriod` column numbering its rows 1..N.

    Meant to be applied per group via ``groupby(...).apply``, so that
    CohortPeriod is the Nth row (period) within each group, in the group's
    existing row order. The frame passed in is not modified.

    Example
    -------
    Say you want to get the 3rd period for every user:

        df = df.sort_values(['UserId', 'OrderTime'])
        df = df.groupby('UserId', group_keys=False).apply(cohort_period)
        df[df.CohortPeriod == 3]

    Parameters
    ----------
    df : pandas.DataFrame
        The rows of one group (or any frame).

    Returns
    -------
    pandas.DataFrame
        A new frame with the same index and columns as `df`, plus
        `CohortPeriod` (1-based integer position of each row).
    """
    # assign() builds a new frame: pandas advises against mutating the object
    # handed to GroupBy.apply, and callers get the same result either way.
    return df.assign(CohortPeriod=np.arange(1, len(df) + 1))
def _load_orders(fname):
    """
    Read the orders CSV at `fname` and add OrderDate, OrderPeriod, CohortGroup.

    Reads the 'Date Created' column (only its first 10 characters, parsed as
    YYYY-MM-DD) and the 'Account ID' column.
    """
    df = pd.read_csv(fname)
    df['OrderDate'] = pd.to_datetime(df['Date Created'].str[:10], format='%Y-%m-%d')
    df['OrderPeriod'] = df['OrderDate'].dt.strftime('%Y-%m')
    # A user's cohort is the month of that user's earliest order; transform()
    # broadcasts the per-account minimum back onto every row of that account.
    first_order = df.groupby('Account ID')['OrderDate'].transform('min')
    df['CohortGroup'] = first_order.dt.strftime('%Y-%m')
    return df


def _build_cohorts(df):
    """
    Aggregate orders per (CohortGroup, OrderPeriod) and number the periods.

    Returns a DataFrame indexed by (CohortGroup, CohortPeriod) with columns
    OrderPeriod, TotalUsers (distinct 'Account ID'), TotalOrders (distinct
    'ID') and 'Total Usd' (summed).

    To spot-check a row, filter `df` on the same CohortGroup/OrderPeriod and
    compare 'Account ID'.nunique(), 'ID'.nunique() and 'Total Usd'.sum().
    """
    grouped = df.groupby(['CohortGroup', 'OrderPeriod'])
    cohorts = grouped.agg({'Account ID': 'nunique',
                           'ID': 'nunique',
                           'Total Usd': 'sum'})
    cohorts = cohorts.rename(columns={'Account ID': 'TotalUsers',
                                      'ID': 'TotalOrders'})
    # group_keys=False keeps the (CohortGroup, OrderPeriod) index unchanged.
    # Since pandas 2.0 the default prepends CohortGroup a second time, and the
    # reset_index() below then fails with "cannot insert CohortGroup".
    # NOTE: CohortPeriod counts only periods in which the cohort actually
    # ordered; a month with no orders does not get its own period number.
    cohorts = cohorts.groupby(level=0, group_keys=False).apply(cohort_period)
    return cohorts.reset_index().set_index(['CohortGroup', 'CohortPeriod'])


def _user_retention(cohorts):
    """
    Divide each cohort's TotalUsers per period by its TotalUsers in its first row.

    Rows are CohortPeriod, columns are CohortGroup; periods a cohort has no
    row for come out as NaN from unstack().
    """
    cohort_group_size = cohorts['TotalUsers'].groupby(level=0).first()
    return cohorts['TotalUsers'].unstack(0).divide(cohort_group_size, axis=1)


def _plot_retention(user_retention, fname):
    """
    Save a retention line chart to '<name>.png' and a heatmap to '<name>-heatmap.png'.

    <name> is `fname` with a trailing '.csv' removed.
    """
    # Strip only a trailing '.csv'; str.replace would also remove '.csv'
    # appearing elsewhere in the path (e.g. in a directory name).
    pltname = fname[:-len('.csv')] if fname.endswith('.csv') else fname

    user_retention.plot(figsize=(10, 5))
    plt.title('Cohorts: User Retention - %s' % fname)
    plt.xticks(np.arange(1, 12.1, 1))  # show the first 12 periods
    plt.xlim(1, 12)
    plt.ylabel('% of Cohort Purchasing')
    plt.ylim(0, 1)
    plt.savefig('%s.png' % pltname)
    # Close the figure so repeated run() calls do not accumulate open figures.
    plt.close()

    sns.set(style='white')
    plt.figure(figsize=(12, 8))
    plt.title('Cohorts: User Retention - %s' % fname)
    sns.heatmap(user_retention.T, mask=user_retention.T.isnull(), annot=True, fmt='.0%')
    plt.savefig('%s-heatmap.png' % pltname)
    plt.close()


def run(fname):
    """
    Run the cohort / user-retention analysis on the orders CSV at `fname`.

    The CSV must contain the columns 'Date Created', 'Account ID', 'ID' and
    'Total Usd'. Prints the first 10 periods of the retention table and
    writes two PNG charts named after `fname` (see _plot_retention).

    Returns the retention DataFrame (rows: CohortPeriod, columns: CohortGroup).
    """
    df = _load_orders(fname)
    cohorts = _build_cohorts(df)
    user_retention = _user_retention(cohorts)
    print(user_retention.head(10))
    _plot_retention(user_retention, fname)
    return user_retention
if __name__ == "__main__":
    # Expect the path of the orders CSV as the first argument; exit with a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: %s ORDERS.csv' % sys.argv[0])
    run(sys.argv[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment