retention-cookbook-example
# coding: utf-8
# # N-Week Retention
#
#
# # PySpark
# Util and imports
# In[57]:
import datetime as dt
import pandas as pd
import pyspark.sql.types as st
import pyspark.sql.functions as F

udf = F.udf

PERIODS = {}
N_WEEKS = 6
for i in range(1, N_WEEKS + 1):
    PERIODS[i] = {
        'start': i * 7,
        'end': i * 7 + 6
    }


def date_diff(d1, d2, fmt='%Y%m%d'):
    """
    Returns days elapsed from d2 to d1 as an integer.

    Params:
        d1 (str)
        d2 (str)
        fmt (str): format of d1 and d2 (must be the same)

    >>> date_diff('20170205', '20170201')
    4
    >>> date_diff('20170201', '20170205')
    -4
    """
    try:
        return (pd.to_datetime(d1, format=fmt) -
                pd.to_datetime(d2, format=fmt)).days
    except Exception:
        return None
@udf(returnType=st.IntegerType())
def get_period(anchor, submission_date_s3):
    """
    Given an anchor and a submission_date_s3,
    returns what period a ping belongs to. This
    is a Spark UDF.

    Params:
        anchor (col): anchor date
        submission_date_s3 (col): a ping's submission_date to s3

    Global:
        PERIODS (dict): defined globally based on the n-week method

    Returns an integer indicating the retention period
    """
    if anchor is not None:
        diff = date_diff(submission_date_s3, anchor)
        if diff is not None and diff >= 7:  # exclude first 7 days
            for period in sorted(PERIODS):
                if diff <= PERIODS[period]['end']:
                    return period
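
# Quick sanity check of the mapping above (the dates here are made up for
# illustration): pings 7-13 days after the anchor land in period 1,
# 14-20 days after in period 2, and the first 7 days are excluded.
assert date_diff('20180410', '20180401') == 9   # period 1 (7 <= 9 <= 13)
assert date_diff('20180417', '20180401') == 16  # period 2 (14 <= 16 <= 20)
assert date_diff('20180403', '20180401') == 2   # excluded: diff < 7
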
@udf(returnType=st.StringType())
def from_unixtime_handler(ut):
    """
    Converts unix time (in days) to a string in %Y%m%d format.
    This is a Spark UDF.

    Params:
        ut (int): unix time in days

    Returns a date as a string if it is parsable by datetime, otherwise None
    """
    if ut is not None:
        try:
            return dt.datetime.fromtimestamp(ut * 24 * 60 * 60).strftime("%Y%m%d")
        except Exception:
            return None
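
# For reference, `profile_creation_date` in `main_summary` is stored as days
# since the Unix epoch, which is why the UDF above multiplies by 86400 before
# parsing. A plain-datetime spot check (the exact day can shift by one with
# the timezone of the machine evaluating the UDF):
#
#   dt.datetime(1970, 1, 1) + dt.timedelta(days=17622)  # -> 2018-04-01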
# In[58]:
PERIODS
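# With N_WEEKS = 6, the dict above spans days 7-48 after the anchor:
#
#   {1: {'start': 7,  'end': 13},
#    2: {'start': 14, 'end': 20},
#    3: {'start': 21, 'end': 27},
#    4: {'start': 28, 'end': 34},
#    5: {'start': 35, 'end': 41},
#    6: {'start': 42, 'end': 48}}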
# Load a 1% sample of `main_summary` on the release channel. We'll select `client_id`, `submission_date_s3`, `profile_creation_date`, and `os`, with the intention of comparing retention between the three main OSes reported by users: `Darwin` (MacOS), `Windows_NT`, and `Linux`.
# In[59]:
ms = spark.sql("""
    select client_id,
           submission_date_s3,
           profile_creation_date,
           os
    from main_summary
    where submission_date_s3 >= '20180401'
      and submission_date_s3 <= '20180603'
      and sample_id = '42'
      and app_name = 'Firefox'
      and normalized_channel = 'release'
      and os in ('Darwin', 'Windows_NT', 'Linux')
""")
PCD_CUTS = ('20180401', '20180415')

ms = (
    ms.withColumn("pcd", from_unixtime_handler("profile_creation_date"))  # i.e. 17500 -> '20171130'
      .filter("pcd >= '{}'".format(PCD_CUTS[0]))
      .filter("pcd <= '{}'".format(PCD_CUTS[1]))
      .withColumn("period", get_period("pcd", "submission_date_s3"))
)
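
# To spot-check the period assignment before aggregating, one could peek at a
# few rows (a sketch; output not shown here):
#
#   ms.select("client_id", "pcd", "submission_date_s3", "period").show(5)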
# In[60]:
os_counts = (
    ms
    .groupby("os")
    .agg(F.countDistinct("client_id").alias("total_clients"))
)

weekly_counts = (
    ms
    .groupby("period", "os")
    .agg(F.countDistinct("client_id").alias("n_week_clients"))
)
# In[61]:
retention_by_os = (
    weekly_counts
    .join(os_counts, on='os')
    .withColumn("retention", F.col("n_week_clients") / F.col("total_clients"))
)
# In[68]:
retention_by_os.filter("period = 6").show()
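
# To look at the full retention curve rather than just week 6, the small
# aggregated table can be pulled into pandas and pivoted. A sketch using the
# `retention_by_os` DataFrame defined above (the result name is arbitrary):
retention_curve = (
    retention_by_os
    .toPandas()
    .pivot(index="period", columns="os", values="retention")
    .sort_index()
)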
# # Pure SQL
# In[71]:
ms = spark.sql("""
    SELECT client_id,
           submission_date_s3,
           from_unixtime(profile_creation_date * 60 * 60 * 24, "yyyy-MM-dd") AS pcd,
           datediff(CONCAT(SUBSTR(submission_date_s3, 1, 4), '-',
                           SUBSTR(submission_date_s3, 5, 2), '-',
                           SUBSTR(submission_date_s3, 7, 2)),
                    from_unixtime(profile_creation_date * 60 * 60 * 24, "yyyy-MM-dd")) AS diff,
           os
    FROM main_summary
    WHERE submission_date_s3 >= '20180401'
      AND submission_date_s3 <= '20180603'
      AND sample_id = '42'
      AND app_name = 'Firefox'
      AND normalized_channel = 'release'
      AND os IN ('Darwin', 'Windows_NT', 'Linux')
      AND from_unixtime(profile_creation_date * 60 * 60 * 24, "yyyy-MM-dd")
          BETWEEN '2018-04-01' AND '2018-04-15'
""")
ms.registerTempTable('ms')
week_counts = spark.sql("""
    SELECT os,
           CASE
               WHEN diff < 7 THEN 0
               WHEN diff <= 13 THEN 1
               WHEN diff <= 20 THEN 2
               WHEN diff <= 27 THEN 3
               WHEN diff <= 34 THEN 4
               WHEN diff <= 41 THEN 5
               WHEN diff <= 48 THEN 6
               ELSE NULL
           END AS period,
           COUNT(DISTINCT client_id) AS n_clients
    FROM ms
    GROUP BY 1, 2
""")
week_counts.registerTempTable("week_counts")
retention = spark.sql("""
    SELECT l.os,
           period,
           n_clients,
           r.total_clients,
           n_clients / r.total_clients AS retention
    FROM week_counts l
    JOIN (
        SELECT os,
               COUNT(DISTINCT client_id) AS total_clients
        FROM ms
        GROUP BY 1
    ) r
    ON l.os = r.os
    WHERE period = 6
""")
# In[72]:
retention.show()
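
# The week-6 figures here should line up with the PySpark result above; a
# quick side-by-side check (sketch, output not shown):
#
#   retention.orderBy("os").show()
#   retention_by_os.filter("period = 6").orderBy("os").show()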
# In[ ]: