@afraenkel · Created May 10, 2019
PySpark: partitioning groups of events that occur within 90 days of the previous event.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Window

# Spark 1.x/2.x-style entry points
sc = SparkContext()
sqlc = SQLContext(sc)
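## Aside (not in the original gist): on Spark >= 2.0 the usual entry point
## is a SparkSession, which exposes createDataFrame() directly. An
## equivalent setup would be (the app name here is arbitrary):
#
# from pyspark.sql import SparkSession
# spark = SparkSession.builder.appName('event-clusters').getOrCreate()
# df = spark.createDataFrame(...)  # in place of sqlc.createDataFrame(...)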
# Test data, e.g. user IDs and server log-in dates
L = '''
a,2016-01-01
a,2016-06-01
a,2016-02-01
a,2016-05-01
b,2016-10-01
b,2016-11-01
a,2016-11-01
a,2017-01-01
b,2017-06-01
a,2017-04-01
b,2017-04-01
'''.strip().split('\n')
L = [line.split(',') for line in L]
## DataFrame with parsed dates
df = sqlc.createDataFrame(
    L,
    schema=T.StructType([
        T.StructField('ID', T.StringType()),
        T.StructField('DATE', T.StringType()),
    ])
).withColumn('DATE', F.to_date('DATE'))  # parse the strings into DateType
# df.collect()
# [Row(ID='a', DATE=datetime.date(2016, 1, 1)),
# Row(ID='a', DATE=datetime.date(2016, 6, 1)),
# Row(ID='a', DATE=datetime.date(2016, 2, 1)),
# Row(ID='a', DATE=datetime.date(2016, 5, 1)),
# Row(ID='b', DATE=datetime.date(2016, 10, 1)),
# Row(ID='b', DATE=datetime.date(2016, 11, 1)),
# Row(ID='a', DATE=datetime.date(2016, 11, 1)),
# Row(ID='a', DATE=datetime.date(2017, 1, 1)),
# Row(ID='b', DATE=datetime.date(2017, 6, 1)),
# Row(ID='a', DATE=datetime.date(2017, 4, 1)),
# Row(ID='b', DATE=datetime.date(2017, 4, 1))]
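## Aside (not in the original gist): F.to_date with no format argument
## expects the default 'yyyy-MM-dd' pattern, which the test data follows.
## On Spark >= 2.2 an explicit pattern can be supplied instead, e.g. by
## replacing the withColumn step above with:
#
# .withColumn('DATE', F.to_date('DATE', 'yyyy-MM-dd'))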
### Partition the data by ID (e.g. by user)
## Attach a unique integer to each cluster of events within a partition,
## sub-partitioning each ID's events. An event belongs to the current
## cluster if it occurs strictly less than 90 days after the previous
## event; a gap of 90+ days starts a new cluster.
wind = Window.partitionBy('ID').orderBy('DATE')

with_diffs = (
    df
    # days since the previous event for the same ID (null for the first event)
    .withColumn('time_between',
                F.datediff(F.col('DATE'), F.lag(F.col('DATE'), 1).over(wind)))
    # the first event in each partition has no predecessor; treat its gap as 0
    .fillna(0, subset='time_between')
    # flag events that start a new cluster
    .withColumn('change', F.when(F.col('time_between') < 90, 0).otherwise(1))
    # a running sum of the flags gives each event its cluster number
    .withColumn('cumsum', F.sum(F.col('change')).over(wind))
)
# keeps all the intermediate columns
# with_diffs.collect()
# [Row(ID='b', DATE=datetime.date(2016, 10, 1), time_between=0, change=0, cumsum=0),
# Row(ID='b', DATE=datetime.date(2016, 11, 1), time_between=31, change=0, cumsum=0),
# Row(ID='b', DATE=datetime.date(2017, 4, 1), time_between=151, change=1, cumsum=1),
# Row(ID='b', DATE=datetime.date(2017, 6, 1), time_between=61, change=0, cumsum=1),
# Row(ID='a', DATE=datetime.date(2016, 1, 1), time_between=0, change=0, cumsum=0),
# Row(ID='a', DATE=datetime.date(2016, 2, 1), time_between=31, change=0, cumsum=0),
# Row(ID='a', DATE=datetime.date(2016, 5, 1), time_between=90, change=1, cumsum=1),
# Row(ID='a', DATE=datetime.date(2016, 6, 1), time_between=31, change=0, cumsum=1),
# Row(ID='a', DATE=datetime.date(2016, 11, 1), time_between=153, change=1, cumsum=2),
# Row(ID='a', DATE=datetime.date(2017, 1, 1), time_between=61, change=0, cumsum=2),
# Row(ID='a', DATE=datetime.date(2017, 4, 1), time_between=90, change=1, cumsum=3)]
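## A possible next step (a sketch, not part of the original gist): keep only
## the cluster label per event, and summarize each cluster. The column names
## CLUSTER, START, END, and N_EVENTS are my own choices.
clusters = with_diffs.select('ID', 'DATE', F.col('cumsum').alias('CLUSTER'))

cluster_summary = (
    with_diffs
    .groupBy('ID', F.col('cumsum').alias('CLUSTER'))
    .agg(F.min('DATE').alias('START'),    # first event in the cluster
         F.max('DATE').alias('END'),      # last event in the cluster
         F.count('*').alias('N_EVENTS'))  # number of events in the cluster
)
# e.g. cluster_summary.collect() on the data above includes (order may vary):
# Row(ID='a', CLUSTER=0, START=datetime.date(2016, 1, 1), END=datetime.date(2016, 2, 1), N_EVENTS=2)
# Row(ID='b', CLUSTER=1, START=datetime.date(2017, 4, 1), END=datetime.date(2017, 6, 1), N_EVENTS=2)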