-
-
Save benmiroglio/fc708e5905fad33b43adb9c90e38ebf4 to your computer and use it in GitHub Desktop.
retention-cookbook-example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8

# # N-Week Retention
#
#

# # Pyspark
# Util and imports

# In[57]:

import datetime as dt
import pandas as pd
import pyspark.sql.types as st
import pyspark.sql.functions as F

# Shorthand so the functions below can be decorated with @udf(...).
udf = F.udf
# Retention periods: period i covers days [7*i, 7*i + 6] after the anchor
# date, i.e. week i of a client's life (week 0, days 0-6, is excluded later).
N_WEEKS = 6
PERIODS = {
    week: {'start': week * 7, 'end': week * 7 + 6}
    for week in range(1, N_WEEKS + 1)
}
def date_diff(d1, d2, fmt='%Y%m%d'):
    """
    Return the number of days elapsed from d2 to d1 as an integer,
    or None if either date cannot be parsed.

    Params:
        d1 (str)
        d2 (str)
        fmt (str): format of d1 and d2 (must be the same)

    >>> date_diff('20170205', '20170201')
    4
    >>> date_diff('20170201', '20170205')
    -4
    """
    try:
        return (pd.to_datetime(d1, format=fmt) -
                pd.to_datetime(d2, format=fmt)).days
    except (ValueError, TypeError):
        # Unparsable or non-string inputs are treated as "unknown" rather
        # than failing the whole Spark task (original used a bare except,
        # which also hid genuine bugs such as KeyboardInterrupt).
        return None
@udf(returnType=st.IntegerType())
def get_period(anchor, submission_date_s3):
    """
    Given an anchor date and a submission_date_s3, return the retention
    period a ping belongs to. This is a Spark UDF.

    Params:
        anchor (col): anchor date, %Y%m%d string
        submission_date_s3 (col): a ping's submission_date to s3

    Global:
        PERIODS (dict): defined globally based on n-week method

    Returns an integer retention period (1..N_WEEKS), or None when the
    anchor is missing, the ping falls in the first 7 days, the dates are
    unparsable, or the ping is beyond the last period.
    """
    if anchor is None:
        return None
    diff = date_diff(submission_date_s3, anchor)
    # date_diff returns None for unparsable dates; the original compared
    # `diff >= 7` unguarded, which raises TypeError on None in Python 3.
    if diff is None or diff < 7:  # exclude first 7 days
        return None
    for period in sorted(PERIODS):
        if diff <= PERIODS[period]['end']:
            return period
    return None
@udf(returnType=st.StringType())
def from_unixtime_handler(ut):
    """
    Convert unix time (in days) to a string in %Y%m%d format.
    This is a Spark UDF.

    Params:
        ut (int): unix time in days

    Returns a date string if it is parsable by datetime, otherwise None.

    NOTE(review): fromtimestamp uses the executor's local timezone;
    presumably the cluster runs UTC — confirm.
    """
    if ut is None:
        return None
    try:
        # 24 * 60 * 60 = seconds per day.
        return dt.datetime.fromtimestamp(ut * 24 * 60 * 60).strftime("%Y%m%d")
    except (ValueError, OverflowError, OSError):
        # Out-of-range profile_creation_date values become None instead of
        # failing the task (original used a bare except, which also swallowed
        # unrelated errors).
        return None
# In[58]:

# Notebook cell: echo the period boundaries for inspection.
PERIODS

# Load a 1% sample of `main_summary` on the release channel. We'll select
# `client_id`, `submission_date_s3` and `os`, with the intention of comparing
# retention between the three main OSs reported by users: `Darwin` (MacOS),
# `Windows_NT` and `Linux`.

# In[59]:

# NOTE(review): relies on a live SparkSession `spark` with a `main_summary`
# table registered — not runnable outside the analysis cluster.
ms = spark.sql("""
select client_id,
submission_date_s3,
profile_creation_date,
os
from main_summary
where submission_date_s3 >= '20180401'
and submission_date_s3 <= '20180603'
and sample_id = '42'
and app_name = 'Firefox'
and normalized_channel = 'release'
and os in ('Darwin', 'Windows_NT', 'Linux')
""")

# Anchor cohort: profiles created in the first half of April 2018.
PCD_CUTS = ('20180401', '20180415')

ms = (
    # Convert profile_creation_date (days since epoch) to a %Y%m%d string,
    # keep only the cohort, and tag each ping with its retention period.
    ms.withColumn("pcd", from_unixtime_handler("profile_creation_date"))  # i.e. 17500 -> '20171130'
    .filter("pcd >= '{}'".format(PCD_CUTS[0]))
    .filter("pcd <= '{}'".format(PCD_CUTS[1]))
    .withColumn("period", get_period("pcd", "submission_date_s3"))
)
# In[60]:

# Denominator: distinct clients per OS across the whole cohort window.
os_counts = (
    ms
    .groupby("os")
    .agg(F.countDistinct("client_id").alias("total_clients"))
)

# Numerator: distinct clients per (retention period, OS).
weekly_counts = (
    ms
    .groupby("period", "os")
    .agg(F.countDistinct("client_id").alias("n_week_clients"))
)

# In[61]:

# Retention = clients active in period n / clients in the cohort, per OS.
retention_by_os = (
    weekly_counts
    .join(os_counts, on='os')
    .withColumn("retention", F.col("n_week_clients") / F.col("total_clients"))
)

# In[68]:

# Display 6-week retention by OS.
retention_by_os.filter("period = 6").show()
# # Pure SQL

# In[71]:

# Same pipeline expressed in SQL only: pcd and day-diff are computed inline.
# NOTE(review): relies on a live SparkSession `spark` and `main_summary`.
ms = spark.sql("""
SELECT client_id,
submission_date_s3,
from_unixtime(profile_creation_date * 60 * 60 * 24, "yyyy-MM-dd") as pcd,
datediff(CONCAT(SUBSTR(submission_date_s3, 1, 4), '-',
SUBSTR(submission_date_s3, 5, 2), '-',
SUBSTR(Submission_date_s3, 7, 2)),
from_unixtime(profile_creation_date * 60 * 60 * 24, "yyyy-MM-dd")) as diff,
os
FROM main_summary
WHERE submission_date_s3 >= '20180401'
AND submission_date_s3 <= '20180603'
AND sample_id = '42'
AND app_name = 'Firefox'
AND normalized_channel = 'release'
AND os in ('Darwin', 'Windows_NT', 'Linux')
AND from_unixtime(profile_creation_date * 60 * 60 * 24, "yyyy-MM-dd")
BETWEEN '2018-04-01' and '2018-04-15'
""")
# registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView
# is the documented, behavior-equivalent replacement.
ms.createOrReplaceTempView('ms')
# Bucket each ping's day-diff into a retention period (mirrors PERIODS above)
# and count distinct clients per (os, period).
week_counts = spark.sql("""
SELECT os,
case
when diff < 7 then 0
when diff <= 13 then 1
when diff <= 20 then 2
when diff <= 27 then 3
when diff <= 34 then 4
when diff <= 41 then 5
when diff <= 48 then 6
else null
end as period,
COUNT(DISTINCT client_id) as n_clients
from ms
GROUP BY 1, 2
""")
# registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView
# is the documented, behavior-equivalent replacement.
week_counts.createOrReplaceTempView("week_counts")
# Join per-period counts against the per-OS cohort totals and compute the
# retention rate; restricted to the 6-week period for display.
retention = spark.sql("""
SELECT l.os,
period,
n_clients,
r.total_clients,
n_clients / r.total_clients as retention
FROM week_counts l
JOIN (
SELECT os,
COUNT(DISTINCT client_id) as total_clients
FROM ms
GROUP BY 1) r
ON l.os = r.os
WHERE period = 6
""")

# In[72]:

retention.show()

# In[ ]:
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment