{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# N-Week Retention\n",
"\n"
]
},
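{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concretely, for each OS the notebook computes, for weeks $n = 1, \\ldots, 6$,\n",
"\n",
"$$\\mathrm{retention}_n = \\frac{\\#\\{\\text{cohort clients active on days } 7n \\text{ to } 7n+6\\}}{\\#\\{\\text{clients in the cohort}\\}}$$\n",
"\n",
"where the cohort is every client whose profile creation date falls in the anchor window (`PCD_CUTS` below)."
]
},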
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pyspark\n",
"Util and imports"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"import datetime as dt\n",
"import pandas as pd\n",
"import pyspark.sql.types as st\n",
"import pyspark.sql.functions as F\n",
"\n",
"\n",
"udf = F.udf\n",
"\n",
"\n",
"PERIODS = {}\n",
"N_WEEKS = 6\n",
"for i in range(1, N_WEEKS + 1):\n",
" PERIODS[i] = {\n",
" 'start': i * 7,\n",
" 'end': i * 7 + 6\n",
" }\n",
" \n",
"\n",
"def date_diff(d1, d2, fmt='%Y%m%d'):\n",
" \"\"\"\n",
" Returns days elapsed from d2 to d1 as an integer\n",
" \n",
" Params:\n",
" d1 (str)\n",
" d2 (str)\n",
" fmt (str): format of d1 and d2 (must be the same)\n",
" \n",
" >>> date_diff('20170205', '20170201')\n",
" 4\n",
" \n",
" >>> date_diff('20170201', '20170205)\n",
" -4\n",
" \"\"\"\n",
" try:\n",
" return (pd.to_datetime(d1, format=fmt) - \n",
" pd.to_datetime(d2, format=fmt)).days\n",
" except:\n",
" return None\n",
" \n",
"\n",
"@udf(returnType=st.IntegerType())\n",
"def get_period(anchor, submission_date_s3):\n",
" \"\"\"\n",
" Given an anchor and a submission_date_s3,\n",
" returns what period a ping belongs to. This \n",
" is a spark UDF.\n",
" \n",
" Params:\n",
" anchor (col): anchor date\n",
" submission_date_s3 (col): a ping's submission_date to s3\n",
" \n",
" Global:\n",
" PERIODS (dict): defined globally based on n-week method\n",
" \n",
" Returns an integer indicating the retention period\n",
" \"\"\"\n",
" if anchor is not None:\n",
" diff = date_diff(submission_date_s3, anchor)\n",
" if diff >= 7: # exclude first 7 days\n",
" for period in sorted(PERIODS):\n",
" if diff <= PERIODS[period]['end']:\n",
" return period\n",
"\n",
"@udf(returnType=st.StringType())\n",
"def from_unixtime_handler(ut):\n",
" \"\"\"\n",
" Converts unix time (in days) to a string in %Y%m%d format.\n",
" This is spark UDF.\n",
" \n",
" Params:\n",
" ut (int): unix time in days\n",
" \n",
" Returns a date as a string if it is parsable by datetime, otherwise None\n",
" \"\"\"\n",
" if ut is not None:\n",
" try:\n",
" return (dt.datetime.fromtimestamp(ut * 24 * 60 * 60).strftime(\"%Y%m%d\"))\n",
" except:\n",
" return None"
]
},
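{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, Spark-free sanity check of the helpers above (a sketch; the period lookup reimplements `get_period`'s body in plain Python, since the decorated version is only callable on Spark columns):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# date_diff is a plain function, so it can be called directly.\n",
"assert date_diff('20170205', '20170201') == 4\n",
"assert date_diff('20170201', '20170205') == -4\n",
"\n",
"# A ping 14 days after the anchor should land in period 2 (days 14-20).\n",
"diff = date_diff('20180415', '20180401')\n",
"assert next(p for p in sorted(PERIODS) if diff <= PERIODS[p]['end']) == 2"
]
},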
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{1: {'end': 13, 'start': 7},\n",
" 2: {'end': 20, 'start': 14},\n",
" 3: {'end': 27, 'start': 21},\n",
" 4: {'end': 34, 'start': 28},\n",
" 5: {'end': 41, 'start': 35},\n",
" 6: {'end': 48, 'start': 42}}"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"PERIODS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" Load a 1% sample of `main_summary` on the release channel. We'll select `client_id`, `submission_date_s3` and `os`, with the intention of comparing retention between the three main OS's reported by users: `Darwin` (MacOS), `Windows_NT` and `Linux`."
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"ms = spark.sql(\"\"\"\n",
" select client_id, \n",
" submission_date_s3,\n",
" profile_creation_date,\n",
" os\n",
" from main_summary \n",
" where submission_date_s3 >= '20180401'\n",
" and submission_date_s3 <= '20180603'\n",
" and sample_id = '42'\n",
" and app_name = 'Firefox'\n",
" and normalized_channel = 'release'\n",
" and os in ('Darwin', 'Windows_NT', 'Linux')\n",
" \"\"\")\n",
"\n",
"PCD_CUTS = ('20180401', '20180415')\n",
"\n",
"ms = (\n",
" ms.withColumn(\"pcd\", from_unixtime_handler(\"profile_creation_date\")) # i.e. 17500 -> '20171130'\n",
" .filter(\"pcd >= '{}'\".format(PCD_CUTS[0]))\n",
" .filter(\"pcd <= '{}'\".format(PCD_CUTS[1]))\n",
" .withColumn(\"period\", get_period(\"pcd\", \"submission_date_s3\"))\n",
")"
]
},
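{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, spot-check a few rows to confirm the `pcd` conversion and period assignment look sane (pings from the first 7 days have a null `period` by construction):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ms.select(\"pcd\", \"submission_date_s3\", \"period\").show(5)"
]
},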
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"os_counts = (\n",
" ms\n",
" .groupby(\"os\")\n",
" .agg(F.countDistinct(\"client_id\").alias(\"total_clients\"))\n",
")\n",
"\n",
"weekly_counts = (\n",
" ms\n",
" .groupby(\"period\", \"os\")\n",
" .agg(F.countDistinct(\"client_id\").alias(\"n_week_clients\"))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"retention_by_os = (\n",
" weekly_counts\n",
" .join(os_counts, on='os')\n",
" .withColumn(\"retention\", F.col(\"n_week_clients\") / F.col(\"total_clients\"))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+------+--------------+-------------+-------------------+\n",
"| os|period|n_week_clients|total_clients| retention|\n",
"+----------+------+--------------+-------------+-------------------+\n",
"| Linux| 6| 1495| 22422|0.06667558647756668|\n",
"| Darwin| 6| 1288| 4734|0.27207435572454586|\n",
"|Windows_NT| 6| 29024| 124872|0.23243000832852842|\n",
"+----------+------+--------------+-------------+-------------------+\n",
"\n"
]
}
],
"source": [
"retention_by_os.filter(\"period = 6\").show()"
]
},
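{
"cell_type": "markdown",
"metadata": {},
"source": [
"The full curve (periods 1-6) is small enough to pull onto the driver. A minimal plotting sketch, assuming `matplotlib` is available in the notebook environment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"curve = retention_by_os.orderBy(\"os\", \"period\").toPandas()\n",
"\n",
"# One line per OS, retention by week.\n",
"curve.pivot(index=\"period\", columns=\"os\", values=\"retention\").plot(marker=\"o\")"
]
},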
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pure SQL"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"ms = spark.sql(\"\"\"\n",
" SELECT client_id,\n",
" submission_date_s3,\n",
" from_unixtime(profile_creation_date * 60 * 60 * 24, \"yyyy-MM-dd\") as pcd,\n",
" datediff(CONCAT(SUBSTR(submission_date_s3, 1, 4), '-',\n",
" SUBSTR(submission_date_s3, 5, 2), '-',\n",
" SUBSTR(Submission_date_s3, 7, 2)), \n",
" from_unixtime(profile_creation_date * 60 * 60 * 24, \"yyyy-MM-dd\")) as diff,\n",
" os\n",
" FROM main_summary\n",
" WHERE submission_date_s3 >= '20180401'\n",
" AND submission_date_s3 <= '20180603'\n",
" AND sample_id = '42'\n",
" AND app_name = 'Firefox'\n",
" AND normalized_channel = 'release'\n",
" AND os in ('Darwin', 'Windows_NT', 'Linux')\n",
" AND from_unixtime(profile_creation_date * 60 * 60 * 24, \"yyyy-MM-dd\") \n",
" BETWEEN '2018-04-01' and '2018-04-15'\n",
"\"\"\")\n",
"\n",
"ms.registerTempTable('ms')\n",
"\n",
"week_counts = spark.sql(\"\"\"\n",
"SELECT os,\n",
" case\n",
" when diff < 7 then 0\n",
" when diff <= 13 then 1\n",
" when diff <= 20 then 2\n",
" when diff <= 27 then 3\n",
" when diff <= 34 then 4\n",
" when diff <= 41 then 5\n",
" when diff <= 48 then 6\n",
" else null\n",
" end as period,\n",
" COUNT(DISTINCT client_id) as n_clients\n",
"from ms\n",
"GROUP BY 1, 2\n",
"\"\"\")\n",
"\n",
"week_counts.registerTempTable(\"week_counts\")\n",
"\n",
"retention = spark.sql(\"\"\"\n",
"SELECT l.os,\n",
" period,\n",
" n_clients,\n",
" r.total_clients,\n",
" n_clients / r.total_clients as retention\n",
"FROM week_counts l\n",
"JOIN (\n",
" SELECT os,\n",
" COUNT(DISTINCT client_id) as total_clients\n",
" FROM ms\n",
" GROUP BY 1) r\n",
"ON l.os = r.os\n",
"WHERE period = 6\n",
"\"\"\")\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+------+---------+-------------+-------------------+\n",
"| os|period|n_clients|total_clients| retention|\n",
"+----------+------+---------+-------------+-------------------+\n",
"| Linux| 6| 1495| 22422|0.06667558647756668|\n",
"| Darwin| 6| 1288| 4734|0.27207435572454586|\n",
"|Windows_NT| 6| 29024| 124872|0.23243000832852842|\n",
"+----------+------+---------+-------------+-------------------+\n",
"\n"
]
}
],
"source": [
"retention.show()"
]
},
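{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final check, the pure-SQL result should match the PySpark result row for row. A sketch (after renaming `n_clients` so the schemas line up, `subtract` should return nothing):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(retention\n",
" .withColumnRenamed(\"n_clients\", \"n_week_clients\")\n",
" .subtract(retention_by_os.filter(\"period = 6\"))\n",
" .show())  # an empty result means the two approaches agree"
]
}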
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}