This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Distribution of the three engagement features: raw values (top row)
# vs. skew-corrected transforms (bottom row).
# NOTE: sns.distplot was deprecated in seaborn 0.11 and removed in 0.14;
# histplot(..., kde=True, stat="density") is the modern equivalent plot.
f, axes = plt.subplots(2, 3, figsize=(14, 7), sharex=False)

# Raw distributions.
sns.histplot(joined_pandas["sessionCount"], color="skyblue", kde=True,
             stat="density", ax=axes[0, 0])
sns.histplot(joined_pandas["meanSongCount"], color="olive", kde=True,
             stat="density", ax=axes[0, 1])
sns.histplot(joined_pandas["sessionsFreqDay"], color="gold", kde=True,
             stat="density", ax=axes[0, 2])

# Skew handling: log for sessionCount (strictly positive counts), sqrt
# for the other two (sqrt tolerates zeros, log does not).
sns.histplot(np.log(joined_pandas["sessionCount"]), color="skyblue", kde=True,
             stat="density", ax=axes[1, 0])
sns.histplot(np.sqrt(joined_pandas["meanSongCount"]), color="olive", kde=True,
             stat="density", ax=axes[1, 1])
sns.histplot(np.sqrt(joined_pandas["sessionsFreqDay"]), color="gold", kde=True,
             stat="density", ax=axes[1, 2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Index the two categorical columns ('gender', 'level') into numeric
# index columns so they can be one-hot encoded in the next step.
for raw_col, idx_col in [('gender', 'gender_idx'), ('level', 'level_idx')]:
    indexer = StringIndexer(inputCol=raw_col, outputCol=idx_col)
    joined = indexer.fit(joined).transform(joined)
# One-hot encode the indexed categoricals into dummy vector columns.
# NOTE(review): OneHotEncoderEstimator was renamed OneHotEncoder in
# Spark 3.0 — this code targets Spark 2.x; confirm cluster version.
joined = OneHotEncoderEstimator(inputCols=['gender_idx', 'level_idx'],
                                outputCols=['gender_dummy','level_dummy'])\
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Assemble the per-user modelling table: start from the static user
# features and left-join each aggregated feature set on userId, so that
# users missing from any aggregate are kept (with nulls) rather than
# silently dropped by an inner join.
joined = user_features\
    .join(churn_data_summary,
          on=['userId'],
          how='left')\
    .join(user_engagement,
          on=['userId'],
          how='left')\
    .join(listen_freq,
          on=['userId'],
          how='left')\
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show that we can do the same calculation above using SQL: register the
# events DataFrame as a temp view, then take max(itemInSession) per
# (userId, sessionId) as that session's song count.
data.createOrReplaceTempView('sparkify')
sub_query = """
SELECT
userId,
sessionId,
max(itemInSession) as itemCount
FROM
sparkify
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-user engagement features: number of sessions and mean songs per
# session (max itemInSession within a session approximates the number of
# items played in that session).
# Fix: the original used dict-style .agg plus two withColumnRenamed
# calls, which depends on Spark's auto-generated column names
# ('avg(itemCount)', 'count(sessionId)') — fragile across versions.
# Explicit aliases express the same result directly; the final column
# order (userId, meanSongCount, sessionCount) is unchanged.
user_engagement = data\
    .groupBy('userId', 'sessionId')\
    .agg(F.max('itemInSession').alias('itemCount'))\
    .groupBy('userId')\
    .agg(F.avg('itemCount').alias('meanSongCount'),
         F.count('sessionId').alias('sessionCount'))\
    .orderBy('userId')
user_engagement.show(10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a new aggregated dataframe called listen_freq (listening
# frequency) per user: each session's start time is its earliest
# timestamp, then per user we take the first and last session start
# times — presumably used downstream to derive sessionsFreqDay; verify
# against the plotting cell.
listen_freq = data.select('userId','sessionId', 'timeStamp')\
    .groupBy('userId','sessionId')\
    .agg(F.min('timeStamp').alias('sessionTime'))\
    .orderBy('userId', 'sessionId')\
    .groupBy('userId')\
    .agg(F.min('sessionTime').alias('minSessionTime'),
         F.max('sessionTime').alias('maxSessionTime'),
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# First let's have a look if we have any NaN values in our dataset:
# build one count-of-NaN expression per column, evaluate them in a
# single select, and collapse the one-row result into a plain dict.
nan_counts = [count(when(isnan(c), c)).alias(c) for c in data.columns]
data.select(*nan_counts).head().asDict()
>> {'artist': 0, | |
'auth': 0, | |
'firstName': 0, | |
'gender': 0, | |
'itemInSession': 0, | |
'lastName': 0, | |
'length': 0, | |
'level': 0, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Read data into Spark from local disk.
# Note: ideally big data lives in a schema-supported, partitionable
# format like parquet, on a distributed filesystem (HDFS) or a cloud
# bucket (AWS S3 / Google Cloud Storage) for faster reads; here we only
# read a local JSON file.
data = spark.read.format('json').load('mini_sparkify_event_data.json')

# How many user activity rows do we have?
data.count()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a Spark session, or get the existing one if already running.
builder = SparkSession.builder
builder = builder.appName("Sparkify The music streaming platform churn detection")
spark = builder.getOrCreate()

# Check the current Spark config.
spark.sparkContext.getConf().getAll()
>> [('spark.app.id', 'local-1569248217329'), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# import libraries | |
from pyspark import SparkConf, SparkContext | |
from pyspark.sql import SparkSession, Window | |
from pyspark.sql.functions import count, when, isnan, isnull, desc_nulls_first, desc, \ | |
from_unixtime, col, dayofweek, dayofyear, hour, to_date, month | |
import pyspark.sql.functions as F | |
from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler, StandardScaler, MinMaxScaler | |
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier | |
# sc = SparkContext(appName="Project_workspace") |