{
"cells": [
{
"cell_type": "markdown",
"id": "26357b1e",
"metadata": {
"papermill": {
"duration": 0.125431,
"end_time": "2022-04-18T16:05:28.535018",
"exception": false,
"start_time": "2022-04-18T16:05:28.409587",
"status": "completed"
},
"tags": []
},
"source": [
"# PySpark Demo"
]
},
{
"cell_type": "markdown",
"id": "527fc531",
"metadata": {
"papermill": {
"duration": 0.12404,
"end_time": "2022-04-18T16:05:28.789185",
"exception": false,
"start_time": "2022-04-18T16:05:28.665145",
"status": "completed"
},
"tags": []
},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b2bd8328",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:05:29.043014Z",
"iopub.status.busy": "2022-04-18T16:05:29.037694Z",
"iopub.status.idle": "2022-04-18T16:06:14.896025Z",
"shell.execute_reply": "2022-04-18T16:06:14.895187Z",
"shell.execute_reply.started": "2022-04-18T14:25:04.978612Z"
},
"papermill": {
"duration": 45.983882,
"end_time": "2022-04-18T16:06:14.896217",
"exception": false,
"start_time": "2022-04-18T16:05:28.912335",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting pyspark\r\n",
" Downloading pyspark-3.2.1.tar.gz (281.4 MB)\r\n",
" |████████████████████████████████| 281.4 MB 31 kB/s \r\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l-\b \b\\\b \bdone\r\n",
"\u001b[?25hRequirement already satisfied: py4j==0.10.9.3 in /opt/conda/lib/python3.7/site-packages (from pyspark) (0.10.9.3)\r\n",
"Building wheels for collected packages: pyspark\r\n",
" Building wheel for pyspark (setup.py) ... \u001b[?25l-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \b|\b \b/\b \b-\b \b\\\b \bdone\r\n",
"\u001b[?25h Created wheel for pyspark: filename=pyspark-3.2.1-py2.py3-none-any.whl size=281853642 sha256=35e0e2a73e1607e6ad7d7747f4ac7772d708c8d326820dc3d3ac2f8700284971\r\n",
" Stored in directory: /root/.cache/pip/wheels/9f/f5/07/7cd8017084dce4e93e84e92efd1e1d5334db05f2e83bcef74f\r\n",
"Successfully built pyspark\r\n",
"Installing collected packages: pyspark\r\n",
"Successfully installed pyspark-3.2.1\r\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\r\n"
]
}
],
"source": [
"!pip install pyspark"
]
},
{
"cell_type": "markdown",
"id": "8097f07f",
"metadata": {
"papermill": {
"duration": 0.249348,
"end_time": "2022-04-18T16:06:15.396520",
"exception": false,
"start_time": "2022-04-18T16:06:15.147172",
"status": "completed"
},
"tags": []
},
"source": [
"`SparkContext` was the entry point before Spark 2.0. Spark 2.0 introduced the `SparkSession` class, a single entry point that wraps the contexts that existed before the 2.0 update (`SQLContext`, `HiveContext`, etc.). Note that `SparkSession` has not completely replaced `SparkContext`; the context is still available (via `spark.sparkContext`) and some of its functions are still used in Spark 2.0 and later."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "932a9587",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:15.895836Z",
"iopub.status.busy": "2022-04-18T16:06:15.895188Z",
"iopub.status.idle": "2022-04-18T16:06:22.067809Z",
"shell.execute_reply": "2022-04-18T16:06:22.068653Z",
"shell.execute_reply.started": "2022-04-18T14:26:05.605435Z"
},
"papermill": {
"duration": 6.426796,
"end_time": "2022-04-18T16:06:22.068967",
"exception": false,
"start_time": "2022-04-18T16:06:15.642171",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: An illegal reflective access operation has occurred\n",
"WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/opt/conda/lib/python3.7/site-packages/pyspark/jars/spark-unsafe_2.12-3.2.1.jar) to constructor java.nio.DirectByteBuffer(long,int)\n",
"WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform\n",
"WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n",
"WARNING: All illegal access operations will be denied in a future release\n",
"Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
"22/04/18 16:06:18 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <p><b>SparkSession - in-memory</b></p>\n",
" \n",
" <div>\n",
" <p><b>SparkContext</b></p>\n",
"\n",
" <p><a href=\"http://2f64637d1284:4040\">Spark UI</a></p>\n",
"\n",
" <dl>\n",
" <dt>Version</dt>\n",
" <dd><code>v3.2.1</code></dd>\n",
" <dt>Master</dt>\n",
" <dd><code>local[*]</code></dd>\n",
" <dt>AppName</dt>\n",
" <dd><code>Spark Demo</code></dd>\n",
" </dl>\n",
" </div>\n",
" \n",
" </div>\n",
" "
],
"text/plain": [
"<pyspark.sql.session.SparkSession at 0x7fbe894c2b90>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"\n",
"spark = SparkSession.builder.appName('Spark Demo').master('local[*]').getOrCreate()\n",
"spark"
]
},
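{
"cell_type": "markdown",
"id": "e1a00001",
"metadata": {
"tags": []
},
"source": [
"As a small illustration of the note on `SparkContext` above, the sketch below reaches the context that lives inside a session through `spark.sparkContext`. It is only a sketch: the app name `'Context Demo'` and the toy list are placeholders that do not come from the original notebook.\n",
"\n",
"```python\n",
"from pyspark.sql import SparkSession\n",
"\n",
"# getOrCreate() returns the running session if one exists, otherwise it starts a new one\n",
"spark = SparkSession.builder.appName('Context Demo').master('local[*]').getOrCreate()\n",
"\n",
"# The SparkContext is still there for RDD-level work\n",
"sc = spark.sparkContext\n",
"rdd = sc.parallelize([1, 2, 3, 4])\n",
"print(rdd.map(lambda x: x * 2).collect())\n",
"```\n",
"\n",
"Because `getOrCreate()` reuses an existing session, the app name and master given here only take effect when no session is running yet."
]
},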
{
"cell_type": "markdown",
"id": "b93cc0e4",
"metadata": {
"papermill": {
"duration": 0.247276,
"end_time": "2022-04-18T16:06:22.569197",
"exception": false,
"start_time": "2022-04-18T16:06:22.321921",
"status": "completed"
},
"tags": []
},
"source": [
"## Load"
]
},
{
"cell_type": "markdown",
"id": "d7e07b61",
"metadata": {
"papermill": {
"duration": 0.248472,
"end_time": "2022-04-18T16:06:23.066790",
"exception": false,
"start_time": "2022-04-18T16:06:22.818318",
"status": "completed"
},
"tags": []
},
"source": [
"A `SparkSession` object exposes a `read` property that returns a `DataFrameReader`, which reads data into a `DataFrame`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "055074cb",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:23.567222Z",
"iopub.status.busy": "2022-04-18T16:06:23.566214Z",
"iopub.status.idle": "2022-04-18T16:06:41.698827Z",
"shell.execute_reply": "2022-04-18T16:06:41.699644Z",
"shell.execute_reply.started": "2022-04-18T14:26:11.566310Z"
},
"papermill": {
"duration": 18.385275,
"end_time": "2022-04-18T16:06:41.699910",
"exception": false,
"start_time": "2022-04-18T16:06:23.314635",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 1:> (0 + 4) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- step: integer (nullable = true)\n",
" |-- type: string (nullable = true)\n",
" |-- amount: double (nullable = true)\n",
" |-- nameOrig: string (nullable = true)\n",
" |-- oldbalanceOrg: double (nullable = true)\n",
" |-- newbalanceOrig: double (nullable = true)\n",
" |-- nameDest: string (nullable = true)\n",
" |-- oldbalanceDest: double (nullable = true)\n",
" |-- newbalanceDest: double (nullable = true)\n",
" |-- isFraud: integer (nullable = true)\n",
" |-- isFlaggedFraud: integer (nullable = true)\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"data_path = '../input/fraudulent-transactions-data/Fraud.csv'\n",
"df = spark.read.csv(data_path, header=True, inferSchema=True)\n",
"df.printSchema()"
]
},
{
"cell_type": "markdown",
"id": "c80cc782",
"metadata": {
"papermill": {
"duration": 0.246857,
"end_time": "2022-04-18T16:06:42.202205",
"exception": false,
"start_time": "2022-04-18T16:06:41.955348",
"status": "completed"
},
"tags": []
},
"source": [
"`inferSchema` requires an extra pass over the data just to infer the schema. We can define the schema ourselves to achieve better performance."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "37127ed9",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:42.704055Z",
"iopub.status.busy": "2022-04-18T16:06:42.701230Z",
"iopub.status.idle": "2022-04-18T16:06:42.846874Z",
"shell.execute_reply": "2022-04-18T16:06:42.846278Z",
"shell.execute_reply.started": "2022-04-18T14:26:26.542279Z"
},
"papermill": {
"duration": 0.401303,
"end_time": "2022-04-18T16:06:42.847040",
"exception": false,
"start_time": "2022-04-18T16:06:42.445737",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- step: integer (nullable = true)\n",
" |-- type: string (nullable = true)\n",
" |-- amount: double (nullable = true)\n",
" |-- nameOrig: string (nullable = true)\n",
" |-- oldBalanceOrig: double (nullable = true)\n",
" |-- newBalanceOrig: double (nullable = true)\n",
" |-- nameDest: string (nullable = true)\n",
" |-- oldBalanceDest: double (nullable = true)\n",
" |-- newBalanceDest: double (nullable = true)\n",
" |-- isFraud: integer (nullable = true)\n",
" |-- isFlaggedFraud: integer (nullable = true)\n",
"\n"
]
}
],
"source": [
"from pyspark.sql import types as T\n",
"\n",
"predefined_schema = T.StructType([\n",
"    T.StructField('step', T.IntegerType()),\n",
"    T.StructField('type', T.StringType()),\n",
"    T.StructField('amount', T.DoubleType()),\n",
"    T.StructField('nameOrig', T.StringType()),\n",
"    T.StructField('oldbalanceOrg', T.DoubleType()),\n",
"    T.StructField('newbalanceOrig', T.DoubleType()),\n",
"    T.StructField('nameDest', T.StringType()),\n",
"    T.StructField('oldbalanceDest', T.DoubleType()),\n",
"    T.StructField('newbalanceDest', T.DoubleType()),\n",
"    T.StructField('isFraud', T.IntegerType()),\n",
"    T.StructField('isFlaggedFraud', T.IntegerType())\n",
"])\n",
"\n",
"df = spark.read.csv(data_path, schema=predefined_schema, header=True)\n",
"\n",
"# Rename the balance columns to consistent camelCase names\n",
"corrected_cols = {\n",
"    'oldbalanceOrg': 'oldBalanceOrig', 'newbalanceOrig': 'newBalanceOrig',\n",
"    'oldbalanceDest': 'oldBalanceDest', 'newbalanceDest': 'newBalanceDest',\n",
"}\n",
"for old_col, new_col in corrected_cols.items():\n",
"    df = df.withColumnRenamed(old_col, new_col)\n",
"\n",
"df.printSchema()"
]
},
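{
"cell_type": "markdown",
"id": "e1a00002",
"metadata": {
"tags": []
},
"source": [
"The schema above can also be written more compactly. As a hedged sketch, `DataFrameReader.csv` accepts the schema as a DDL-formatted string; the temporary file `toy.csv` and its three columns are made up for this example so that it runs without the Kaggle data.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder.master('local[*]').getOrCreate()\n",
"\n",
"# Write a tiny CSV file (hypothetical data) so the example is self-contained\n",
"toy_path = os.path.join(tempfile.mkdtemp(), 'toy.csv')\n",
"with open(toy_path, 'w') as f:\n",
"    print('step,type,amount', file=f)\n",
"    print('1,PAYMENT,9839.64', file=f)\n",
"    print('1,TRANSFER,181.0', file=f)\n",
"\n",
"# Same idea as the StructType above, expressed as a DDL string\n",
"toy_df = spark.read.csv(toy_path, schema='step INT, type STRING, amount DOUBLE', header=True)\n",
"toy_df.printSchema()\n",
"```\n",
"\n",
"A `StructType` is more verbose but easier to build programmatically; the DDL string is convenient for short, fixed schemas."
]
},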
{
"cell_type": "markdown",
"id": "9598135d",
"metadata": {
"papermill": {
"duration": 0.246272,
"end_time": "2022-04-18T16:06:43.342538",
"exception": false,
"start_time": "2022-04-18T16:06:43.096266",
"status": "completed"
},
"tags": []
},
"source": [
"## Overview"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6b477790",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:43.866558Z",
"iopub.status.busy": "2022-04-18T16:06:43.865857Z",
"iopub.status.idle": "2022-04-18T16:06:44.314351Z",
"shell.execute_reply": "2022-04-18T16:06:44.313470Z",
"shell.execute_reply.started": "2022-04-18T14:26:26.666579Z"
},
"papermill": {
"duration": 0.726734,
"end_time": "2022-04-18T16:06:44.314550",
"exception": false,
"start_time": "2022-04-18T16:06:43.587816",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----+--------+--------+-----------+--------------+--------------+-----------+--------------+--------------+-------+--------------+\n",
"|step| type| amount| nameOrig|oldBalanceOrig|newBalanceOrig| nameDest|oldBalanceDest|newBalanceDest|isFraud|isFlaggedFraud|\n",
"+----+--------+--------+-----------+--------------+--------------+-----------+--------------+--------------+-------+--------------+\n",
"| 1| PAYMENT| 9839.64|C1231006815| 170136.0| 160296.36|M1979787155| 0.0| 0.0| 0| 0|\n",
"| 1| PAYMENT| 1864.28|C1666544295| 21249.0| 19384.72|M2044282225| 0.0| 0.0| 0| 0|\n",
"| 1|TRANSFER| 181.0|C1305486145| 181.0| 0.0| C553264065| 0.0| 0.0| 1| 0|\n",
"| 1|CASH_OUT| 181.0| C840083671| 181.0| 0.0| C38997010| 21182.0| 0.0| 1| 0|\n",
"| 1| PAYMENT|11668.14|C2048537720| 41554.0| 29885.86|M1230701703| 0.0| 0.0| 0| 0|\n",
"| 1| PAYMENT| 7817.71| C90045638| 53860.0| 46042.29| M573487274| 0.0| 0.0| 0| 0|\n",
"| 1| PAYMENT| 7107.77| C154988899| 183195.0| 176087.23| M408069119| 0.0| 0.0| 0| 0|\n",
"| 1| PAYMENT| 7861.64|C1912850431| 176087.23| 168225.59| M633326333| 0.0| 0.0| 0| 0|\n",
"| 1| PAYMENT| 4024.36|C1265012928| 2671.0| 0.0|M1176932104| 0.0| 0.0| 0| 0|\n",
"| 1| DEBIT| 5337.77| C712410124| 41720.0| 36382.23| C195600860| 41898.0| 40348.79| 0| 0|\n",
"+----+--------+--------+-----------+--------------+--------------+-----------+--------------+--------------+-------+--------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df.show(10)"
]
},
{
"cell_type": "markdown",
"id": "35fb6e82",
"metadata": {
"papermill": {
"duration": 0.262869,
"end_time": "2022-04-18T16:06:44.869220",
"exception": false,
"start_time": "2022-04-18T16:06:44.606351",
"status": "completed"
},
"tags": []
},
"source": [
"Sometimes the DataFrame does not fit the screen. We can adjust the parameters of the `show` method, or split the DataFrame into groups of columns and show each group separately."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c536f73e",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:45.376296Z",
"iopub.status.busy": "2022-04-18T16:06:45.375621Z",
"iopub.status.idle": "2022-04-18T16:06:46.094423Z",
"shell.execute_reply": "2022-04-18T16:06:46.093542Z",
"shell.execute_reply.started": "2022-04-18T14:26:27.080815Z"
},
"papermill": {
"duration": 0.97618,
"end_time": "2022-04-18T16:06:46.094623",
"exception": false,
"start_time": "2022-04-18T16:06:45.118443",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----+--------+--------+-----------+\n",
"|step| type| amount| nameOrig|\n",
"+----+--------+--------+-----------+\n",
"| 1| PAYMENT| 9839.64|C1231006815|\n",
"| 1| PAYMENT| 1864.28|C1666544295|\n",
"| 1|TRANSFER| 181.0|C1305486145|\n",
"| 1|CASH_OUT| 181.0| C840083671|\n",
"| 1| PAYMENT|11668.14|C2048537720|\n",
"| 1| PAYMENT| 7817.71| C90045638|\n",
"| 1| PAYMENT| 7107.77| C154988899|\n",
"| 1| PAYMENT| 7861.64|C1912850431|\n",
"| 1| PAYMENT| 4024.36|C1265012928|\n",
"| 1| DEBIT| 5337.77| C712410124|\n",
"+----+--------+--------+-----------+\n",
"only showing top 10 rows\n",
"\n",
"+--------------+--------------+-----------+--------------+\n",
"|oldBalanceOrig|newBalanceOrig| nameDest|oldBalanceDest|\n",
"+--------------+--------------+-----------+--------------+\n",
"| 170136.0| 160296.36|M1979787155| 0.0|\n",
"| 21249.0| 19384.72|M2044282225| 0.0|\n",
"| 181.0| 0.0| C553264065| 0.0|\n",
"| 181.0| 0.0| C38997010| 21182.0|\n",
"| 41554.0| 29885.86|M1230701703| 0.0|\n",
"| 53860.0| 46042.29| M573487274| 0.0|\n",
"| 183195.0| 176087.23| M408069119| 0.0|\n",
"| 176087.23| 168225.59| M633326333| 0.0|\n",
"| 2671.0| 0.0|M1176932104| 0.0|\n",
"| 41720.0| 36382.23| C195600860| 41898.0|\n",
"+--------------+--------------+-----------+--------------+\n",
"only showing top 10 rows\n",
"\n",
"+--------------+-------+--------------+\n",
"|newBalanceDest|isFraud|isFlaggedFraud|\n",
"+--------------+-------+--------------+\n",
"| 0.0| 0| 0|\n",
"| 0.0| 0| 0|\n",
"| 0.0| 1| 0|\n",
"| 0.0| 1| 0|\n",
"| 0.0| 0| 0|\n",
"| 0.0| 0| 0|\n",
"| 0.0| 0| 0|\n",
"| 0.0| 0| 0|\n",
"| 0.0| 0| 0|\n",
"| 40348.79| 0| 0|\n",
"+--------------+-------+--------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"def show_split(df, split=-1, n_samples=10):\n",
"    # Show df in groups of `split` columns; split <= 0 shows all columns at once\n",
"    n_cols = len(df.columns)\n",
"    if split <= 0:\n",
"        split = n_cols\n",
"    i = 0\n",
"    j = i + split\n",
"    while i < n_cols:\n",
"        df.select(*df.columns[i:j]).show(n_samples)\n",
"        i = j\n",
"        j = i + split\n",
"\n",
"\n",
"show_split(df, 4, 10)"
]
},
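{
"cell_type": "markdown",
"id": "e1a00003",
"metadata": {
"tags": []
},
"source": [
"Besides splitting by columns, the parameters of `show` mentioned above can help on their own. The sketch below is a minimal illustration on a made-up two-row DataFrame (`toy_df` is not part of the original notebook).\n",
"\n",
"```python\n",
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder.master('local[*]').getOrCreate()\n",
"\n",
"# Hypothetical rows that mimic a few columns of the fraud dataset\n",
"toy_df = spark.createDataFrame(\n",
"    [(1, 'PAYMENT', 9839.64, 'C1231006815'), (1, 'TRANSFER', 181.0, 'C1305486145')],\n",
"    ['step', 'type', 'amount', 'nameOrig'],\n",
")\n",
"\n",
"# truncate=5 cuts every cell to at most 5 characters\n",
"toy_df.show(2, truncate=5)\n",
"\n",
"# vertical=True prints one field per line, which suits wide DataFrames\n",
"toy_df.show(2, truncate=False, vertical=True)\n",
"```"
]
},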
{
"cell_type": "markdown",
"id": "d2dcb89e",
"metadata": {
"papermill": {
"duration": 0.25169,
"end_time": "2022-04-18T16:06:46.712129",
"exception": false,
"start_time": "2022-04-18T16:06:46.460439",
"status": "completed"
},
"tags": []
},
"source": [
"When working with numerical data, looking at a long column of values isn't very useful. We are usually more interested in summary statistics such as count, mean, standard deviation, minimum, and maximum."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d3f108ba",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:47.217691Z",
"iopub.status.busy": "2022-04-18T16:06:47.217017Z",
"iopub.status.idle": "2022-04-18T16:06:54.615194Z",
"shell.execute_reply": "2022-04-18T16:06:54.614521Z",
"shell.execute_reply.started": "2022-04-18T14:26:27.760536Z"
},
"papermill": {
"duration": 7.652204,
"end_time": "2022-04-18T16:06:54.615348",
"exception": false,
"start_time": "2022-04-18T16:06:46.963144",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 6:==============> (1 + 3) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-------+------------------+------------------+\n",
"|summary| step| amount|\n",
"+-------+------------------+------------------+\n",
"| count| 6362620| 6362620|\n",
"| mean|243.39724563151657|179861.90354913412|\n",
"| stddev|142.33197104912588| 603858.2314629498|\n",
"| min| 1| 0.0|\n",
"| max| 743| 9.244551664E7|\n",
"+-------+------------------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"# describe() takes column names as arguments\n",
"df.describe('step', 'amount').show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "796149f9",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:06:55.210204Z",
"iopub.status.busy": "2022-04-18T16:06:55.209232Z",
"iopub.status.idle": "2022-04-18T16:07:14.195945Z",
"shell.execute_reply": "2022-04-18T16:07:14.196698Z",
"shell.execute_reply.started": "2022-04-18T14:26:34.014826Z"
},
"papermill": {
"duration": 19.244938,
"end_time": "2022-04-18T16:07:14.196963",
"exception": false,
"start_time": "2022-04-18T16:06:54.952025",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 9:============================================> (3 + 1) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-------+-----------------+-----------------+------------------+------------------+\n",
"|summary| oldBalanceOrig| newBalanceOrig| oldBalanceDest| newBalanceDest|\n",
"+-------+-----------------+-----------------+------------------+------------------+\n",
"| count| 6362620| 6362620| 6362620| 6362620|\n",
"| min| 0.0| 0.0| 0.0| 0.0|\n",
"| max| 5.958504037E7| 4.958504037E7| 3.5601588935E8| 3.5617927892E8|\n",
"| mean|833883.1040744719|855113.6685785714|1100701.6665196654|1224996.3982019408|\n",
"| 50%| 14211.23| 0.0| 132612.49| 214605.81|\n",
"+-------+-----------------+-----------------+------------------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"# summary() takes the names of the statistics as arguments\n",
"df.select('oldBalanceOrig', 'newBalanceOrig', 'oldBalanceDest', 'newBalanceDest').summary('count', 'min', 'max', 'mean', '50%').show()"
]
},
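{
"cell_type": "markdown",
"id": "e1a00004",
"metadata": {
"tags": []
},
"source": [
"To make the difference between the two calls above concrete, here is a small sketch on a made-up single-column DataFrame (`toy_df` is a placeholder). It shows that `summary` can be called without arguments and that percentiles are requested as strings.\n",
"\n",
"```python\n",
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder.master('local[*]').getOrCreate()\n",
"\n",
"# Hypothetical values, only for illustration\n",
"toy_df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,)], ['amount'])\n",
"\n",
"# Without arguments, summary() reports count, mean, stddev, min, the quartiles and max\n",
"toy_df.summary().show()\n",
"\n",
"# A percentile is requested as a string such as '90%'\n",
"toy_df.summary('min', '90%', 'max').show()\n",
"```"
]
},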
{
"cell_type": "markdown",
"id": "2cecec1b",
"metadata": {
"papermill": {
"duration": 0.263635,
"end_time": "2022-04-18T16:07:14.762155",
"exception": false,
"start_time": "2022-04-18T16:07:14.498520",
"status": "completed"
},
"tags": []
},
"source": [
"## Explore"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4515be3f",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:15.279378Z",
"iopub.status.busy": "2022-04-18T16:07:15.278723Z",
"iopub.status.idle": "2022-04-18T16:07:15.282736Z",
"shell.execute_reply": "2022-04-18T16:07:15.283317Z",
"shell.execute_reply.started": "2022-04-18T14:26:50.226832Z"
},
"papermill": {
"duration": 0.269074,
"end_time": "2022-04-18T16:07:15.283489",
"exception": false,
"start_time": "2022-04-18T16:07:15.014415",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from pyspark.sql import functions as F"
]
},
{
"cell_type": "markdown",
"id": "1058119e",
"metadata": {
"papermill": {
"duration": 0.258828,
"end_time": "2022-04-18T16:07:15.800147",
"exception": false,
"start_time": "2022-04-18T16:07:15.541319",
"status": "completed"
},
"tags": []
},
"source": [
"### Querying columns and rows"
]
},
{
"cell_type": "markdown",
"id": "d0338ba7",
"metadata": {
"papermill": {
"duration": 0.250465,
"end_time": "2022-04-18T16:07:16.303802",
"exception": false,
"start_time": "2022-04-18T16:07:16.053337",
"status": "completed"
},
"tags": []
},
"source": [
"PySpark borrows a lot of vocabulary from the SQL world, but it does not need to follow the strict SQL clause order (select *what* from *where* where *condition met* ...). Each step in a PySpark chain returns a `DataFrame` or `GroupedData` object that we can keep working with seamlessly."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e049e111",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:16.825664Z",
"iopub.status.busy": "2022-04-18T16:07:16.824998Z",
"iopub.status.idle": "2022-04-18T16:07:17.146094Z",
"shell.execute_reply": "2022-04-18T16:07:17.145027Z",
"shell.execute_reply.started": "2022-04-18T14:26:50.246607Z"
},
"papermill": {
"duration": 0.593217,
"end_time": "2022-04-18T16:07:17.146320",
"exception": false,
"start_time": "2022-04-18T16:07:16.553103",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+---------+\n",
"| type| amount|\n",
"+--------+---------+\n",
"|CASH_OUT| 181.0|\n",
"|CASH_OUT|229133.94|\n",
"|CASH_OUT|110414.71|\n",
"|CASH_OUT| 56953.9|\n",
"|CASH_OUT| 5346.89|\n",
"|CASH_OUT| 23261.3|\n",
"|CASH_OUT| 82940.31|\n",
"|CASH_OUT| 47458.86|\n",
"|CASH_OUT|136872.92|\n",
"|CASH_OUT| 94253.33|\n",
"+--------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df.where(df['type']=='CASH_OUT').select(df.type, F.col('amount')).show(10)"
]
},
{
"cell_type": "markdown",
"id": "00d18fca",
"metadata": {
"papermill": {
"duration": 0.261561,
"end_time": "2022-04-18T16:07:17.696166",
"exception": false,
"start_time": "2022-04-18T16:07:17.434605",
"status": "completed"
},
"tags": []
},
"source": [
"The code above shows three different ways to reference a PySpark column: `df['type']`, `df.type`, and `F.col('amount')`."
]
},
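{
"cell_type": "markdown",
"id": "e1a00005",
"metadata": {
"tags": []
},
"source": [
"The three styles are interchangeable in most situations. The following sketch checks this on a made-up two-row DataFrame (`toy_df` and its values are placeholders, not data from the notebook).\n",
"\n",
"```python\n",
"from pyspark.sql import SparkSession, functions as F\n",
"\n",
"spark = SparkSession.builder.master('local[*]').getOrCreate()\n",
"\n",
"# Hypothetical rows, only for illustration\n",
"toy_df = spark.createDataFrame(\n",
"    [('PAYMENT', 9839.64), ('CASH_OUT', 181.0)],\n",
"    ['type', 'amount'],\n",
")\n",
"\n",
"# All three expressions refer to the same column\n",
"by_item = toy_df.where(toy_df['type'] == 'CASH_OUT')\n",
"by_attr = toy_df.where(toy_df.type == 'CASH_OUT')\n",
"by_name = toy_df.where(F.col('type') == 'CASH_OUT')\n",
"\n",
"# Each filter keeps the same single row\n",
"print(by_item.collect() == by_attr.collect() == by_name.collect())\n",
"```\n",
"\n",
"Attribute access is the most fragile of the three: a column named `count`, for example, is shadowed by the `DataFrame.count` method, so `df['count']` or `F.col('count')` is the safer choice there."
]
},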
{
"cell_type": "markdown",
"id": "6a72deab",
"metadata": {
"papermill": {
"duration": 0.25219,
"end_time": "2022-04-18T16:07:18.204287",
"exception": false,
"start_time": "2022-04-18T16:07:17.952097",
"status": "completed"
},
"tags": []
},
"source": [
"We can also run SQL inside PySpark, but we must first register the `DataFrame` as a view in the Spark SQL environment."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "603ae074",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:18.711105Z",
"iopub.status.busy": "2022-04-18T16:07:18.710431Z",
"iopub.status.idle": "2022-04-18T16:07:19.010606Z",
"shell.execute_reply": "2022-04-18T16:07:19.009745Z",
"shell.execute_reply.started": "2022-04-18T14:26:50.708957Z"
},
"papermill": {
"duration": 0.554706,
"end_time": "2022-04-18T16:07:19.010754",
"exception": false,
"start_time": "2022-04-18T16:07:18.456048",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+---------+\n",
"| type| amount|\n",
"+--------+---------+\n",
"|CASH_OUT| 181.0|\n",
"|CASH_OUT|229133.94|\n",
"|CASH_OUT|110414.71|\n",
"|CASH_OUT| 56953.9|\n",
"|CASH_OUT| 5346.89|\n",
"|CASH_OUT| 23261.3|\n",
"|CASH_OUT| 82940.31|\n",
"|CASH_OUT| 47458.86|\n",
"|CASH_OUT|136872.92|\n",
"|CASH_OUT| 94253.33|\n",
"+--------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"df.createOrReplaceTempView('df')\n",
"spark.sql('''\n",
" SELECT type, amount FROM df\n",
" WHERE type = \"CASH_OUT\" \n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "dad22034",
"metadata": {
"papermill": {
"duration": 0.262527,
"end_time": "2022-04-18T16:07:19.546975",
"exception": false,
"start_time": "2022-04-18T16:07:19.284448",
"status": "completed"
},
"tags": []
},
"source": [
"I prefer the PySpark way, but in some cases writing SQL is beneficial because a query string can be more readable than a lot of boilerplate code."
]
},
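{
"cell_type": "markdown",
"id": "e1a00006",
"metadata": {
"tags": []
},
"source": [
"The two styles can also be mixed. As a rough sketch of that middle ground, `selectExpr` and `where` accept SQL expression strings, so short SQL fragments can be used without registering a view. The DataFrame `toy_df` and the column name `doubleAmount` are made up for this example.\n",
"\n",
"```python\n",
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder.master('local[*]').getOrCreate()\n",
"\n",
"# Hypothetical rows, only for illustration\n",
"toy_df = spark.createDataFrame(\n",
"    [('PAYMENT', 9839.64), ('CASH_OUT', 181.0)],\n",
"    ['type', 'amount'],\n",
")\n",
"\n",
"# SQL expressions inside select\n",
"toy_df.selectExpr('type', 'amount * 2 AS doubleAmount').show()\n",
"\n",
"# A SQL condition string inside where\n",
"toy_df.where(\"type = 'CASH_OUT' AND amount > 100\").show()\n",
"```"
]
},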
{
"cell_type": "markdown",
"id": "6b08f3fa",
"metadata": {
"papermill": {
"duration": 0.251924,
"end_time": "2022-04-18T16:07:20.058978",
"exception": false,
"start_time": "2022-04-18T16:07:19.807054",
"status": "completed"
},
"tags": []
},
"source": [
"### Grouping"
]
},
{
"cell_type": "markdown",
"id": "0f7b4353",
"metadata": {
"papermill": {
"duration": 0.262019,
"end_time": "2022-04-18T16:07:20.572548",
"exception": false,
"start_time": "2022-04-18T16:07:20.310529",
"status": "completed"
},
"tags": []
},
"source": [
"PySpark provides the `Column.alias` method to rename a column in the result."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1f39409e",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:21.085222Z",
"iopub.status.busy": "2022-04-18T16:07:21.084165Z",
"iopub.status.idle": "2022-04-18T16:07:29.237917Z",
"shell.execute_reply": "2022-04-18T16:07:29.238837Z",
"shell.execute_reply.started": "2022-04-18T14:26:51.004866Z"
},
"papermill": {
"duration": 8.414507,
"end_time": "2022-04-18T16:07:29.239133",
"exception": false,
"start_time": "2022-04-18T16:07:20.824626",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 14:==============> (1 + 3) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+------------------+\n",
"| type| avgAmount|\n",
"+--------+------------------+\n",
"| DEBIT| 5483.665313767128|\n",
"| PAYMENT|13057.604660187604|\n",
"| CASH_IN| 168920.2420040954|\n",
"|CASH_OUT|176273.96434613998|\n",
"|TRANSFER| 910647.0096454868|\n",
"+--------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"# Sometimes we can pass a column name directly to a PySpark function\n",
"df.select('type', 'amount').groupBy('type').agg(F.mean('amount').alias('avgAmount')).orderBy('avgAmount').show(10)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "754ecaf9",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:29.897379Z",
"iopub.status.busy": "2022-04-18T16:07:29.896184Z",
"iopub.status.idle": "2022-04-18T16:07:38.222453Z",
"shell.execute_reply": "2022-04-18T16:07:38.221443Z",
"shell.execute_reply.started": "2022-04-18T14:26:57.324339Z"
},
"papermill": {
"duration": 8.589937,
"end_time": "2022-04-18T16:07:38.222670",
"exception": false,
"start_time": "2022-04-18T16:07:29.632733",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 17:===========================================> (3 + 1) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+------------------+\n",
"| type| avgAmount|\n",
"+--------+------------------+\n",
"| DEBIT| 5483.665313767128|\n",
"| PAYMENT|13057.604660187604|\n",
"| CASH_IN| 168920.2420040954|\n",
"|CASH_OUT|176273.96434613998|\n",
"|TRANSFER| 910647.0096454868|\n",
"+--------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"spark.sql('''\n",
" SELECT type, AVG(amount) avgAmount FROM df\n",
" GROUP BY type\n",
" ORDER BY 2\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "68027a44",
"metadata": {
"papermill": {
"duration": 0.279305,
"end_time": "2022-04-18T16:07:38.815925",
"exception": false,
"start_time": "2022-04-18T16:07:38.536620",
"status": "completed"
},
"tags": []
},
"source": [
"To filter the result after grouping, we can simply apply `where` or `filter` to the resulting `DataFrame`, or follow the SQL approach with the `HAVING` keyword."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c9dd68d3",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:39.337683Z",
"iopub.status.busy": "2022-04-18T16:07:39.336980Z",
"iopub.status.idle": "2022-04-18T16:07:49.235799Z",
"shell.execute_reply": "2022-04-18T16:07:49.236575Z",
"shell.execute_reply.started": "2022-04-18T14:27:03.222215Z"
},
"papermill": {
"duration": 10.161799,
"end_time": "2022-04-18T16:07:49.236769",
"exception": false,
"start_time": "2022-04-18T16:07:39.074970",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 20:> (0 + 4) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---------+\n",
"| nameOrig|sumAmount|\n",
"+-----------+---------+\n",
"| C551314014|301050.58|\n",
"| C661668091|323789.56|\n",
"| C228994633|517946.01|\n",
"|C1591008292|558254.22|\n",
"|C2100435651|357988.09|\n",
"| C624052656|476735.47|\n",
"| C948681098|353759.28|\n",
"| C50682517|386128.82|\n",
"|C1579521009|684561.18|\n",
"|C1871922377|394317.12|\n",
"+-----------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
" df.where(df['type']=='CASH_OUT')\n",
" .groupBy('nameOrig')\n",
" .agg(F.sum('amount').alias('sumAmount'))\n",
" .where(F.col('sumAmount') > 300000)\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "70ebd597",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:49.792716Z",
"iopub.status.busy": "2022-04-18T16:07:49.791618Z",
"iopub.status.idle": "2022-04-18T16:07:58.216999Z",
"shell.execute_reply": "2022-04-18T16:07:58.216427Z",
"shell.execute_reply.started": "2022-04-18T14:27:11.428928Z"
},
"papermill": {
"duration": 8.71375,
"end_time": "2022-04-18T16:07:58.217148",
"exception": false,
"start_time": "2022-04-18T16:07:49.503398",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 23:==============> (1 + 3) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---------+\n",
"| nameOrig|sumAmount|\n",
"+-----------+---------+\n",
"| C551314014|301050.58|\n",
"| C661668091|323789.56|\n",
"| C228994633|517946.01|\n",
"|C1591008292|558254.22|\n",
"|C2100435651|357988.09|\n",
"| C624052656|476735.47|\n",
"| C948681098|353759.28|\n",
"| C50682517|386128.82|\n",
"|C1579521009|684561.18|\n",
"|C1871922377|394317.12|\n",
"+-----------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"spark.sql('''\n",
" SELECT nameOrig, SUM(amount) sumAmount FROM df\n",
" WHERE type = \"CASH_OUT\"\n",
" GROUP BY 1\n",
" HAVING sumAmount > 300000\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "e231781f",
"metadata": {
"papermill": {
"duration": 0.255761,
"end_time": "2022-04-18T16:07:58.752136",
"exception": false,
"start_time": "2022-04-18T16:07:58.496375",
"status": "completed"
},
"tags": []
},
"source": [
"### Create table/views"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "816647f1",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:07:59.276355Z",
"iopub.status.busy": "2022-04-18T16:07:59.275696Z",
"iopub.status.idle": "2022-04-18T16:07:59.278299Z",
"shell.execute_reply": "2022-04-18T16:07:59.278817Z",
"shell.execute_reply.started": "2022-04-18T14:27:18.299459Z"
},
"papermill": {
"duration": 0.26778,
"end_time": "2022-04-18T16:07:59.279039",
"exception": false,
"start_time": "2022-04-18T16:07:59.011259",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# CREATE OR REPLACE TEMP VIEW name_of_view AS SELECT ... FROM ..."
]
},
{
"cell_type": "markdown",
"id": "7a1b1ba7",
"metadata": {
"papermill": {
"duration": 0.258798,
"end_time": "2022-04-18T16:07:59.793576",
"exception": false,
"start_time": "2022-04-18T16:07:59.534778",
"status": "completed"
},
"tags": []
},
"source": [
"### Uinon and Intersect"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "cd97ddd1",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:08:00.317960Z",
"iopub.status.busy": "2022-04-18T16:08:00.317283Z",
"iopub.status.idle": "2022-04-18T16:08:04.926473Z",
"shell.execute_reply": "2022-04-18T16:08:04.927303Z",
"shell.execute_reply.started": "2022-04-18T14:27:18.305533Z"
},
"papermill": {
"duration": 4.873564,
"end_time": "2022-04-18T16:08:04.927557",
"exception": false,
"start_time": "2022-04-18T16:08:00.053993",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"data": {
"text/plain": [
"12725240"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.select('nameOrig').union(df.select('nameDest')).count()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "23e6f44c",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:08:05.655396Z",
"iopub.status.busy": "2022-04-18T16:08:05.654333Z",
"iopub.status.idle": "2022-04-18T16:08:33.968916Z",
"shell.execute_reply": "2022-04-18T16:08:33.967981Z",
"shell.execute_reply.started": "2022-04-18T14:27:21.586558Z"
},
"papermill": {
"duration": 28.602295,
"end_time": "2022-04-18T16:08:33.969166",
"exception": false,
"start_time": "2022-04-18T16:08:05.366871",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"data": {
"text/plain": [
"9073900"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"spark.sql('''\n",
" SELECT nameOrig from df\n",
" UNION\n",
" SELECT nameDest from df\n",
"''').count()"
]
},
{
"cell_type": "markdown",
"id": "31da8ad8",
"metadata": {
"papermill": {
"duration": 0.260411,
"end_time": "2022-04-18T16:08:34.520400",
"exception": false,
"start_time": "2022-04-18T16:08:34.259989",
"status": "completed"
},
"tags": []
},
"source": [
"We can see the difference here. The reason is `union()` function in pyspark will keep the duplicate samples from two sets. It is equivalent to `UNION ALL` in SQL. Remove duplication is a costly process."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0fa9cb8b",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:08:35.100012Z",
"iopub.status.busy": "2022-04-18T16:08:35.098953Z",
"iopub.status.idle": "2022-04-18T16:09:01.432422Z",
"shell.execute_reply": "2022-04-18T16:09:01.431381Z",
"shell.execute_reply.started": "2022-04-18T14:27:46.539315Z"
},
"papermill": {
"duration": 26.645472,
"end_time": "2022-04-18T16:09:01.432650",
"exception": false,
"start_time": "2022-04-18T16:08:34.787178",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"data": {
"text/plain": [
"9073900"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.select('nameOrig').union(df.select('nameDest')).dropDuplicates().count()"
]
},
{
"cell_type": "markdown",
"id": "257172c2",
"metadata": {
"papermill": {
"duration": 0.266071,
"end_time": "2022-04-18T16:09:01.964831",
"exception": false,
"start_time": "2022-04-18T16:09:01.698760",
"status": "completed"
},
"tags": []
},
"source": [
"Unioning can be beneficial when we read data from multiple files. We can read them one by one and union them."
]
},
{
"cell_type": "markdown",
"id": "5259772c",
"metadata": {
"papermill": {
"duration": 0.264503,
"end_time": "2022-04-18T16:09:02.500225",
"exception": false,
"start_time": "2022-04-18T16:09:02.235722",
"status": "completed"
},
"tags": []
},
"source": [
"### Joining"
]
},
{
"cell_type": "markdown",
"id": "587a310e",
"metadata": {
"papermill": {
"duration": 0.264936,
"end_time": "2022-04-18T16:09:03.030265",
"exception": false,
"start_time": "2022-04-18T16:09:02.765329",
"status": "completed"
},
"tags": []
},
"source": [
"Sometimes, we can mix SQL code with pyspark to gain readability. Functions support: `selectExpr`, `where`, `filter`, `expr`,..."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "85943200",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:09:03.569786Z",
"iopub.status.busy": "2022-04-18T16:09:03.568718Z",
"iopub.status.idle": "2022-04-18T16:09:36.071105Z",
"shell.execute_reply": "2022-04-18T16:09:36.070352Z",
"shell.execute_reply.started": "2022-04-18T14:28:08.367603Z"
},
"papermill": {
"duration": 32.773863,
"end_time": "2022-04-18T16:09:36.071370",
"exception": false,
"start_time": "2022-04-18T16:09:03.297507",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 45:> (0 + 4) / 5]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---+------------------+------------------+\n",
"| name|occ| avgChangeOrig| avgChangeDest|\n",
"+-----------+---+------------------+------------------+\n",
"|C1552859894| 43|193711.30000000005| 763241.1652380949|\n",
"|C1819271729| 37| 278937.79|283626.17805555544|\n",
"|C1692434834| 37|177369.73000000045| 438853.7616666666|\n",
"| C889762313| 32| 132731.31|211437.18741935486|\n",
"|C1868986147| 32| 120594.03|249840.37709677417|\n",
"| C55305556| 28|319860.45999999903|225565.42111111112|\n",
"| C636092700| 26|217273.86000000004|201888.05279999998|\n",
"|C1713505653| 25| 278622.8400000003|186625.34916666665|\n",
"|C2029542508| 24| 235760.1200000001|231022.98217391354|\n",
"| C699906968| 23| 177813.3799999999| 183054.3072727272|\n",
"+-----------+---+------------------+------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
" df.where('type = \"CASH_IN\" OR type = \"CASH_OUT\"')\n",
" .selectExpr('nameOrig', 'ABS(newBalanceOrig - oldBalanceOrig) changeOrig')\n",
" .groupBy('nameOrig')\n",
" .agg(\n",
" F.mean(F.col('changeOrig')).alias('avgChangeOrig'),\n",
" F.count('*').alias('occOrig')\n",
" )\n",
" .where('avgChangeOrig > 100000')\n",
" .join((\n",
" df.where('type = \"CASH_IN\" OR type = \"CASH_OUT\"')\n",
" .selectExpr('nameDest', 'ABS(newBalanceDest - oldBalanceDest) changeDest')\n",
" .groupBy('nameDest')\n",
" .agg(\n",
" F.mean(F.col('changeDest')).alias('avgChangeDest'),\n",
" F.count('*').alias('occDest')\n",
" )\n",
" .where('avgChangeDest > 100000')\n",
" ), on=F.col('nameOrig')==F.col('nameDest'), how='inner')\n",
" .selectExpr('nameOrig name', 'occOrig + occDest occ', 'avgChangeOrig', 'avgChangeDest')\n",
" .orderBy('occ', ascending=False)\n",
").show(10)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "3ce7bdc8",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:09:36.650919Z",
"iopub.status.busy": "2022-04-18T16:09:36.650223Z",
"iopub.status.idle": "2022-04-18T16:09:59.622409Z",
"shell.execute_reply": "2022-04-18T16:09:59.623244Z",
"shell.execute_reply.started": "2022-04-18T14:28:27.953111Z"
},
"papermill": {
"duration": 23.265558,
"end_time": "2022-04-18T16:09:59.623496",
"exception": false,
"start_time": "2022-04-18T16:09:36.357938",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 50:> (0 + 4) / 5]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---+------------------+------------------+\n",
"| name|occ| avgChangeOrig| avgChangeDest|\n",
"+-----------+---+------------------+------------------+\n",
"|C1552859894| 43|193711.30000000005| 763241.1652380949|\n",
"|C1819271729| 37| 278937.79|283626.17805555544|\n",
"|C1692434834| 37|177369.73000000045| 438853.7616666666|\n",
"| C889762313| 32| 132731.31|211437.18741935486|\n",
"|C1868986147| 32| 120594.03|249840.37709677417|\n",
"| C55305556| 28|319860.45999999903|225565.42111111112|\n",
"| C636092700| 26|217273.86000000004|201888.05279999998|\n",
"|C1713505653| 25| 278622.8400000003|186625.34916666665|\n",
"|C2029542508| 24| 235760.1200000001|231022.98217391354|\n",
"| C699906968| 23| 177813.3799999999| 183054.3072727272|\n",
"+-----------+---+------------------+------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"spark.sql('''\n",
" SELECT nameOrig name, occOrig + occDest occ, avgChangeOrig, avgChangeDest\n",
" FROM\n",
" (\n",
" SELECT nameOrig, AVG(ABS(newBalanceOrig - oldBalanceOrig)) avgChangeOrig, COUNT(*) occOrig\n",
" FROM df\n",
" WHERE type = \"CASH_IN\" OR type = \"CASH_OUT\"\n",
" GROUP BY nameOrig\n",
" HAVING avgChangeOrig > 100000\n",
" )\n",
" INNER JOIN\n",
" (\n",
" SELECT nameDest, AVG(ABS(newBalanceDest - oldBalanceDest)) avgChangeDest, COUNT(*) occDest\n",
" FROM df\n",
" WHERE type = \"CASH_IN\" OR type = \"CASH_OUT\"\n",
" GROUP BY nameDest\n",
" HAVING avgChangeDest > 100000\n",
" )\n",
" ON nameOrig = nameDest\n",
" ORDER BY occ DESC\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "0f089874",
"metadata": {
"papermill": {
"duration": 0.279977,
"end_time": "2022-04-18T16:10:00.207625",
"exception": false,
"start_time": "2022-04-18T16:09:59.927648",
"status": "completed"
},
"tags": []
},
"source": [
"There are several join method: `inner`, `left`, `right`, `cross`, `outer`, `left_outer`, `right_outer`, `left_semi`, `left_anti`, `right_semi`, `right_anti`, ..."
]
},
{
"cell_type": "markdown",
"id": "783cd337",
"metadata": {
"papermill": {
"duration": 0.281242,
"end_time": "2022-04-18T16:10:00.762668",
"exception": false,
"start_time": "2022-04-18T16:10:00.481426",
"status": "completed"
},
"tags": []
},
"source": [
"### Subqueries"
]
},
{
"cell_type": "markdown",
"id": "2c783ed6",
"metadata": {
"papermill": {
"duration": 0.275711,
"end_time": "2022-04-18T16:10:01.312404",
"exception": false,
"start_time": "2022-04-18T16:10:01.036693",
"status": "completed"
},
"tags": []
},
"source": [
"In a nutshell, subqueries allow you to create tables local to your query. In Python, this is similar to using the `with` statement or using a function block to limit the scope of a query. It is very helpful in keeping our code clean"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "60b58296",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:10:01.902252Z",
"iopub.status.busy": "2022-04-18T16:10:01.901503Z",
"iopub.status.idle": "2022-04-18T16:10:24.597708Z",
"shell.execute_reply": "2022-04-18T16:10:24.596710Z",
"shell.execute_reply.started": "2022-04-18T14:28:45.130235Z"
},
"papermill": {
"duration": 23.009323,
"end_time": "2022-04-18T16:10:24.597962",
"exception": false,
"start_time": "2022-04-18T16:10:01.588639",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 55:> (0 + 4) / 5]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---+------------------+------------------+\n",
"| name|occ| avgChangeOrig| avgChangeDest|\n",
"+-----------+---+------------------+------------------+\n",
"|C1552859894| 43|193711.30000000005| 763241.1652380949|\n",
"|C1819271729| 37| 278937.79|283626.17805555544|\n",
"|C1692434834| 37|177369.73000000045| 438853.7616666666|\n",
"| C889762313| 32| 132731.31|211437.18741935486|\n",
"|C1868986147| 32| 120594.03|249840.37709677417|\n",
"| C55305556| 28|319860.45999999903|225565.42111111112|\n",
"| C636092700| 26|217273.86000000004|201888.05279999998|\n",
"|C1713505653| 25| 278622.8400000003|186625.34916666665|\n",
"|C2029542508| 24| 235760.1200000001|231022.98217391354|\n",
"| C699906968| 23| 177813.3799999999| 183054.3072727272|\n",
"+-----------+---+------------------+------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"def temp_var_not_avail_outside_function():\n",
" orig_df = (\n",
" df.where((df.type=='CASH_IN') | (df.type=='CASH_OUT'))\n",
" .select('nameOrig', F.abs(df.newBalanceOrig - df.oldBalanceOrig).alias('changeOrig'))\n",
" .groupBy('nameOrig')\n",
" .agg(\n",
" F.mean(F.col('changeOrig')).alias('avgChangeOrig'),\n",
" F.count('*').alias('occOrig')\n",
" )\n",
" .where(F.col('avgChangeOrig') > 100000)\n",
" )\n",
" dest_df = (\n",
" df.where((df.type=='CASH_IN') | (df.type=='CASH_OUT'))\n",
" .select('nameDest', F.abs(df.newBalanceDest - df.oldBalanceDest).alias('changeDest'))\n",
" .groupBy('nameDest')\n",
" .agg(\n",
" F.mean(F.col('changeDest')).alias('avgChangeDest'),\n",
" F.count('*').alias('occDest')\n",
" )\n",
" .where(F.col('avgChangeDest') > 100000)\n",
" )\n",
" # Main query\n",
" (\n",
" orig_df.join(dest_df, on=orig_df.nameOrig==dest_df.nameDest, how='inner')\n",
" .select(\n",
" F.col('nameOrig').alias('name'), \n",
" (F.col('occOrig') + F.col('occDest')).alias('occ'),\n",
" 'avgChangeOrig', 'avgChangeDest' \n",
" )\n",
" .orderBy('occ', ascending=False)\n",
" .show(10)\n",
" )\n",
" \n",
"\n",
"temp_var_not_avail_outside_function()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "e0a79ad8",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:10:25.177415Z",
"iopub.status.busy": "2022-04-18T16:10:25.176240Z",
"iopub.status.idle": "2022-04-18T16:10:48.027842Z",
"shell.execute_reply": "2022-04-18T16:10:48.026869Z",
"shell.execute_reply.started": "2022-04-18T14:29:01.985654Z"
},
"papermill": {
"duration": 23.131135,
"end_time": "2022-04-18T16:10:48.028078",
"exception": false,
"start_time": "2022-04-18T16:10:24.896943",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 60:> (0 + 4) / 5]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---+------------------+------------------+\n",
"| name|occ| avgChangeOrig| avgChangeDest|\n",
"+-----------+---+------------------+------------------+\n",
"|C1552859894| 43|193711.30000000005| 763241.1652380949|\n",
"|C1692434834| 37|177369.73000000045| 438853.7616666666|\n",
"|C1819271729| 37| 278937.79|283626.17805555544|\n",
"| C889762313| 32| 132731.31|211437.18741935486|\n",
"|C1868986147| 32| 120594.03|249840.37709677417|\n",
"| C55305556| 28|319860.45999999903|225565.42111111112|\n",
"| C636092700| 26|217273.86000000004|201888.05279999998|\n",
"|C1713505653| 25| 278622.8400000003|186625.34916666665|\n",
"|C2029542508| 24| 235760.1200000001|231022.98217391354|\n",
"| C699906968| 23| 177813.3799999999| 183054.3072727272|\n",
"+-----------+---+------------------+------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"spark.sql('''\n",
" WITH \n",
" origDf as (\n",
" SELECT nameOrig, AVG(ABS(newBalanceOrig - oldBalanceOrig)) avgChangeOrig, COUNT(*) occOrig\n",
" FROM df\n",
" WHERE type = \"CASH_IN\" OR type = \"CASH_OUT\"\n",
" GROUP BY nameOrig\n",
" HAVING avgChangeOrig > 100000\n",
" ),\n",
" destDf as (\n",
" SELECT nameDest, AVG(ABS(newBalanceDest - oldBalanceDest)) avgChangeDest, COUNT(*) occDest\n",
" FROM df\n",
" WHERE type = \"CASH_IN\" OR type = \"CASH_OUT\"\n",
" GROUP BY nameDest\n",
" HAVING avgChangeDest > 100000\n",
" )\n",
" SELECT nameOrig name, occOrig + occDest occ, avgChangeOrig, avgChangeDest\n",
" FROM origDf INNER JOIN destDf ON nameOrig = nameDest\n",
" ORDER BY occ DESC\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "1caa0b73",
"metadata": {
"papermill": {
"duration": 0.275635,
"end_time": "2022-04-18T16:10:48.585268",
"exception": false,
"start_time": "2022-04-18T16:10:48.309633",
"status": "completed"
},
"tags": []
},
"source": [
"## UDF"
]
},
{
"cell_type": "markdown",
"id": "aed6ecff",
"metadata": {
"papermill": {
"duration": 0.277879,
"end_time": "2022-04-18T16:10:49.139463",
"exception": false,
"start_time": "2022-04-18T16:10:48.861584",
"status": "completed"
},
"tags": []
},
"source": [
"What comes in is a regular Python function, and what goes out is a function promoted to work on PySpark columns.\n"
]
},
{
"cell_type": "markdown",
"id": "5ed788bc",
"metadata": {
"papermill": {
"duration": 0.307955,
"end_time": "2022-04-18T16:10:49.727618",
"exception": false,
"start_time": "2022-04-18T16:10:49.419663",
"status": "completed"
},
"tags": []
},
"source": [
"PySpark provides a udf() function in the pyspark.sql.functions module to promote Python functions to their UDF equivalents. The function takes two parameters:\n",
"- The function you want to promote\n",
"- The return type of the generated UDF\n",
"\n",
"In order to make sure the input and output types are compatible. It is recommended to explicitly specify input-output types for python functions."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "752b4e2c",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:10:50.362023Z",
"iopub.status.busy": "2022-04-18T16:10:50.361031Z",
"iopub.status.idle": "2022-04-18T16:10:50.365537Z",
"shell.execute_reply": "2022-04-18T16:10:50.364624Z",
"shell.execute_reply.started": "2022-04-18T14:29:17.932933Z"
},
"papermill": {
"duration": 0.299517,
"end_time": "2022-04-18T16:10:50.365732",
"exception": false,
"start_time": "2022-04-18T16:10:50.066215",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- step: integer (nullable = true)\n",
" |-- type: string (nullable = true)\n",
" |-- amount: double (nullable = true)\n",
" |-- nameOrig: string (nullable = true)\n",
" |-- oldBalanceOrig: double (nullable = true)\n",
" |-- newBalanceOrig: double (nullable = true)\n",
" |-- nameDest: string (nullable = true)\n",
" |-- oldBalanceDest: double (nullable = true)\n",
" |-- newBalanceDest: double (nullable = true)\n",
" |-- isFraud: integer (nullable = true)\n",
" |-- isFlaggedFraud: integer (nullable = true)\n",
"\n"
]
}
],
"source": [
"df.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "699a6a65",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:10:50.938677Z",
"iopub.status.busy": "2022-04-18T16:10:50.937948Z",
"iopub.status.idle": "2022-04-18T16:10:50.939680Z",
"shell.execute_reply": "2022-04-18T16:10:50.940267Z",
"shell.execute_reply.started": "2022-04-18T14:29:17.943460Z"
},
"papermill": {
"duration": 0.285892,
"end_time": "2022-04-18T16:10:50.940446",
"exception": false,
"start_time": "2022-04-18T16:10:50.654554",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"def classify_tier(amount:float) -> int:\n",
" if amount < 500:\n",
" return 0\n",
" if amount < 10000:\n",
" return 1\n",
" if amount < 100000:\n",
" return 2\n",
" if amount < 1000000:\n",
" return 3\n",
" return 4\n",
"\n",
"\n",
"classifyTier = F.udf(classify_tier, T.ByteType())"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "fcd36e57",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:10:51.499152Z",
"iopub.status.busy": "2022-04-18T16:10:51.498463Z",
"iopub.status.idle": "2022-04-18T16:11:11.741140Z",
"shell.execute_reply": "2022-04-18T16:11:11.741655Z",
"shell.execute_reply.started": "2022-04-18T14:29:17.956293Z"
},
"papermill": {
"duration": 20.522851,
"end_time": "2022-04-18T16:11:11.741825",
"exception": false,
"start_time": "2022-04-18T16:10:51.218974",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 61:===========================================> (3 + 1) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+----+\n",
"| nameOrig|tier|\n",
"+-----------+----+\n",
"|C1495608502| 4|\n",
"|C1321115948| 4|\n",
"| C476579021| 4|\n",
"|C1520267010| 4|\n",
"| C106297322| 4|\n",
"|C1464177809| 4|\n",
"| C355885103| 4|\n",
"|C1057507014| 4|\n",
"|C1419332030| 4|\n",
"|C2007599722| 4|\n",
"+-----------+----+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"df.select('nameOrig', classifyTier(df.amount).alias('tier')).orderBy('tier', ascending=False).show(10)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "e08fdb53",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:11:12.409479Z",
"iopub.status.busy": "2022-04-18T16:11:12.408716Z",
"iopub.status.idle": "2022-04-18T16:11:31.526739Z",
"shell.execute_reply": "2022-04-18T16:11:31.527275Z",
"shell.execute_reply.started": "2022-04-18T14:29:33.564186Z"
},
"papermill": {
"duration": 19.406816,
"end_time": "2022-04-18T16:11:31.527451",
"exception": false,
"start_time": "2022-04-18T16:11:12.120635",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 62:==============> (1 + 3) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+----+\n",
"| nameOrig|tier|\n",
"+-----------+----+\n",
"| C263860433| 4|\n",
"| C306269750| 4|\n",
"|C1611915976| 4|\n",
"|C1387188921| 4|\n",
"| C300262358| 4|\n",
"| C389879985| 4|\n",
"|C1907016309| 4|\n",
"|C1046638041| 4|\n",
"|C1543404166| 4|\n",
"|C1155108056| 4|\n",
"+-----------+----+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"spark.udf.register('classifyTier', classify_tier)\n",
"spark.sql('''\n",
" SELECT nameOrig, classifyTier(amount) tier\n",
" FROM df\n",
" ORDER BY tier DESC \n",
"''').show(10)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "83739696",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:11:32.109699Z",
"iopub.status.busy": "2022-04-18T16:11:32.109009Z",
"iopub.status.idle": "2022-04-18T16:11:49.432511Z",
"shell.execute_reply": "2022-04-18T16:11:49.431650Z",
"shell.execute_reply.started": "2022-04-18T14:29:46.658573Z"
},
"papermill": {
"duration": 17.618618,
"end_time": "2022-04-18T16:11:49.432749",
"exception": false,
"start_time": "2022-04-18T16:11:31.814131",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 63:=============================> (2 + 2) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+------------------+\n",
"| type| trueRatio|\n",
"+--------+------------------+\n",
"|TRANSFER|0.9923420321293129|\n",
"| CASH_IN| 1.0|\n",
"|CASH_OUT|0.9981604469273743|\n",
"| PAYMENT| 1.0|\n",
"| DEBIT| 1.0|\n",
"+--------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"@F.udf(T.BooleanType())\n",
"def isTrueFlag(predicted:int, truth:int) -> bool:\n",
" return predicted == truth\n",
"\n",
"(\n",
" df.withColumn('trueFlag', isTrueFlag(df.isFraud, df.isFlaggedFraud))\n",
" .groupBy('type')\n",
" .agg(\n",
" F.mean(F.col('trueFlag').cast('int')).alias('trueRatio')\n",
" )\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "66d19152",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:11:50.047986Z",
"iopub.status.busy": "2022-04-18T16:11:50.046951Z",
"iopub.status.idle": "2022-04-18T16:12:07.657097Z",
"shell.execute_reply": "2022-04-18T16:12:07.656303Z",
"shell.execute_reply.started": "2022-04-18T14:29:59.737086Z"
},
"papermill": {
"duration": 17.902481,
"end_time": "2022-04-18T16:12:07.657305",
"exception": false,
"start_time": "2022-04-18T16:11:49.754824",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 66:==============> (1 + 3) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+------------------+\n",
"| type| mean(trueFlag)|\n",
"+--------+------------------+\n",
"|TRANSFER|0.9923420321293129|\n",
"| CASH_IN| 1.0|\n",
"|CASH_OUT|0.9981604469273743|\n",
"| PAYMENT| 1.0|\n",
"| DEBIT| 1.0|\n",
"+--------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"spark.udf.register('isTrueFlag', isTrueFlag)\n",
"spark.sql('''\n",
" WITH trueDf as (\n",
" SELECT type, isTrueFlag(isFraud, isFlaggedFraud) trueFlag FROM df\n",
" )\n",
" SELECT type, MEAN(INT(trueFlag)) FROM trueDf\n",
" GROUP BY type\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "d87b9414",
"metadata": {
"papermill": {
"duration": 0.290615,
"end_time": "2022-04-18T16:12:08.292062",
"exception": false,
"start_time": "2022-04-18T16:12:08.001447",
"status": "completed"
},
"tags": []
},
"source": [
"## Pandas UDF"
]
},
{
"cell_type": "markdown",
"id": "2c972974",
"metadata": {
"papermill": {
"duration": 0.286788,
"end_time": "2022-04-18T16:12:08.861746",
"exception": false,
"start_time": "2022-04-18T16:12:08.574958",
"status": "completed"
},
"tags": []
},
"source": [
"### Series to Series"
]
},
{
"cell_type": "markdown",
"id": "f88e2275",
"metadata": {
"papermill": {
"duration": 0.285137,
"end_time": "2022-04-18T16:12:09.435603",
"exception": false,
"start_time": "2022-04-18T16:12:09.150466",
"status": "completed"
},
"tags": []
},
"source": [
"In a Python UDF, when you pass column objects to your UDF, PySpark will unpack each value, perform the computation, and then return the value for each record. In a Scalar UDF, PySpark will serialize (through a library called PyArrow) each partitioned column into a pandas `Series` object. You then perform the operations on the `Series` object directly,returning a Series of the same dimension from your UDF. From an end user perspective, they are the same functionally. Because pandas is optimized for rapid data manipulation, it is preferable to use a Series to Series UDF when you can instead of using a regular Python UDF, as it’ll be much faster."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "bb803e93",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:10.013394Z",
"iopub.status.busy": "2022-04-18T16:12:10.012348Z",
"iopub.status.idle": "2022-04-18T16:12:10.014476Z",
"shell.execute_reply": "2022-04-18T16:12:10.014997Z",
"shell.execute_reply.started": "2022-04-18T14:30:12.264830Z"
},
"papermill": {
"duration": 0.29371,
"end_time": "2022-04-18T16:12:10.015174",
"exception": false,
"start_time": "2022-04-18T16:12:09.721464",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"# We can also promote the function to pandas as GetUserType = F.pandas_udf(get_user_type, T.StringType())\n",
"@F.pandas_udf(T.StringType())\n",
"def getUserType(name: pd.Series) -> pd.Series:\n",
" return name.str[0]\n",
"\n",
"# Unfortunately, we currently can't use pandas_udf with pure SQL"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "c044d112",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:10.594652Z",
"iopub.status.busy": "2022-04-18T16:12:10.593472Z",
"iopub.status.idle": "2022-04-18T16:12:25.631865Z",
"shell.execute_reply": "2022-04-18T16:12:25.632636Z",
"shell.execute_reply.started": "2022-04-18T14:30:12.273390Z"
},
"papermill": {
"duration": 15.331301,
"end_time": "2022-04-18T16:12:25.632905",
"exception": false,
"start_time": "2022-04-18T16:12:10.301604",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 69:=============================> (2 + 2) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------+------------------+-------+\n",
"|userTypeDest| avgAmount| n|\n",
"+------------+------------------+-------+\n",
"| C| 265083.4571810173|4211125|\n",
"| M|13057.604660187604|2151495|\n",
"+------------+------------------+-------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
" df.select(getUserType(df.nameDest).alias('userTypeDest'), df.amount)\n",
" .groupBy('userTypeDest')\n",
" .agg(\n",
" F.mean('amount').alias('avgAmount'),\n",
" F.count('*').alias('n')\n",
" )\n",
" .orderBy('avgAmount', ascending=False)\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "ee3dd896",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:26.272722Z",
"iopub.status.busy": "2022-04-18T16:12:26.272024Z",
"iopub.status.idle": "2022-04-18T16:12:26.275335Z",
"shell.execute_reply": "2022-04-18T16:12:26.274829Z",
"shell.execute_reply.started": "2022-04-18T14:30:23.272323Z"
},
"papermill": {
"duration": 0.298723,
"end_time": "2022-04-18T16:12:26.275543",
"exception": false,
"start_time": "2022-04-18T16:12:25.976820",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"def get_change(old: pd.Series, new: pd.Series) -> pd.Series:\n",
" return (new - old).abs()\n",
"\n",
"getChange = F.pandas_udf(get_change, T.DoubleType())"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "e5a0e701",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:26.869531Z",
"iopub.status.busy": "2022-04-18T16:12:26.868606Z",
"iopub.status.idle": "2022-04-18T16:12:39.649569Z",
"shell.execute_reply": "2022-04-18T16:12:39.648729Z",
"shell.execute_reply.started": "2022-04-18T14:30:23.279933Z"
},
"papermill": {
"duration": 13.08578,
"end_time": "2022-04-18T16:12:39.649788",
"exception": false,
"start_time": "2022-04-18T16:12:26.564008",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 72:===========================================> (3 + 1) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+--------------------+\n",
"| type| equalRatio|\n",
"+--------+--------------------+\n",
"|TRANSFER|0.010253157668570056|\n",
"| CASH_IN| 0.11754011337226754|\n",
"|CASH_OUT|0.030618994413407822|\n",
"| PAYMENT| 0.3598637226672616|\n",
"| DEBIT| 0.08978567290982815|\n",
"+--------+--------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
" df.select(df.type, \n",
" getChange(df.oldBalanceOrig, df.newBalanceOrig).alias('changeOrig'),\n",
" getChange(df.oldBalanceDest, df.newBalanceDest).alias('changeDest'))\n",
" .withColumn('equal', F.col('changeOrig') == F.col('changeDest'))\n",
" .groupBy('type')\n",
" .agg(\n",
" F.mean(F.col('equal').cast('int')).alias('equalRatio')\n",
" )\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "87f9a87e",
"metadata": {
"papermill": {
"duration": 0.286999,
"end_time": "2022-04-18T16:12:40.287154",
"exception": false,
"start_time": "2022-04-18T16:12:40.000155",
"status": "completed"
},
"tags": []
},
"source": [
"### Iterator of (mutiple) Series to Iterator of Series"
]
},
{
"cell_type": "markdown",
"id": "d9b1ee57",
"metadata": {
"papermill": {
"duration": 0.292316,
"end_time": "2022-04-18T16:12:40.868711",
"exception": false,
"start_time": "2022-04-18T16:12:40.576395",
"status": "completed"
},
"tags": []
},
"source": [
"Due to the distributed nature of Spark, the whole `Series` won't be fed to the udf at once, but instead, each cluster will call the udf on its own batch of data, and then aggregate the result. Iterator of Series UDFs are very useful when we have an expensive *cold start* operation you need to perform once at the beginning of the processing step."
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "858bbfe5",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:41.458627Z",
"iopub.status.busy": "2022-04-18T16:12:41.454453Z",
"iopub.status.idle": "2022-04-18T16:12:41.460172Z",
"shell.execute_reply": "2022-04-18T16:12:41.460763Z",
"shell.execute_reply.started": "2022-04-18T14:30:33.103250Z"
},
"papermill": {
"duration": 0.298345,
"end_time": "2022-04-18T16:12:41.460951",
"exception": false,
"start_time": "2022-04-18T16:12:41.162606",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from time import sleep\n",
"from typing import Iterator, Tuple\n",
"\n",
"\n",
"@F.pandas_udf(T.ByteType())\n",
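"def getNameIdLength(name: Iterator[pd.Series]) -> Iterator[pd.Series]:\n",
"    # Expensive one-time setup goes here; it runs once per iterator, not once per batch\n",
"    # sleep(5)\n",
"\n",
"    for name_batch in name:\n",
"        name_len = name_batch.str.len()\n",
"        # IDs that start with a non-numeric prefix (e.g. 'C...') should not count that character\n",
"        name_len[~name_batch.str[0].str.isnumeric()] -= 1\n",
"        yield name_len"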
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "ffda19f7",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:42.045251Z",
"iopub.status.busy": "2022-04-18T16:12:42.044566Z",
"iopub.status.idle": "2022-04-18T16:12:57.737153Z",
"shell.execute_reply": "2022-04-18T16:12:57.736340Z",
"shell.execute_reply.started": "2022-04-18T14:30:33.114190Z"
},
"papermill": {
"duration": 15.986047,
"end_time": "2022-04-18T16:12:57.737350",
"exception": false,
"start_time": "2022-04-18T16:12:41.751303",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 75:=============================> (2 + 2) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----+------------------+\n",
"|idLen|         avgAmount|\n",
"+-----+------------------+\n",
"|    4|155070.73742857145|\n",
"|    7|177477.50726081585|\n",
"|   10| 179702.4408980949|\n",
"|    9|179898.05510125632|\n",
"|    8| 181572.2097899971|\n",
"|    6|197756.81529433408|\n",
"|    5|199594.79368029739|\n",
"+-----+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
"    df.select(getNameIdLength(df.nameOrig).alias('idLen'), 'amount')\n",
"    .groupBy('idLen')\n",
"    .agg(F.mean('amount').alias('avgAmount'))\n",
"    .orderBy('avgAmount')\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "7fa6b3c1",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:58.376889Z",
"iopub.status.busy": "2022-04-18T16:12:58.376106Z",
"iopub.status.idle": "2022-04-18T16:12:58.378866Z",
"shell.execute_reply": "2022-04-18T16:12:58.378359Z",
"shell.execute_reply.started": "2022-04-18T14:30:44.671010Z"
},
"papermill": {
"duration": 0.303693,
"end_time": "2022-04-18T16:12:58.379031",
"exception": false,
"start_time": "2022-04-18T16:12:58.075338",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"def amount_mismatch(values: Iterator[Tuple[pd.Series, pd.Series, pd.Series, pd.Series]]) -> Iterator[pd.Series]:\n",
"    # Heavy task\n",
"    # ...\n",
"\n",
"    for oldOrig, newOrig, oldDest, newDest in values:\n",
"        yield abs(abs(newOrig - oldOrig) - abs(newDest - oldDest))\n",
"\n",
"amountMismatch = F.pandas_udf(amount_mismatch, T.DoubleType())"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "c3634ea2",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:12:58.968598Z",
"iopub.status.busy": "2022-04-18T16:12:58.965657Z",
"iopub.status.idle": "2022-04-18T16:13:11.699122Z",
"shell.execute_reply": "2022-04-18T16:13:11.698473Z",
"shell.execute_reply.started": "2022-04-18T14:30:44.680669Z"
},
"papermill": {
"duration": 13.03029,
"end_time": "2022-04-18T16:13:11.699285",
"exception": false,
"start_time": "2022-04-18T16:12:58.668995",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 78:===========================================> (3 + 1) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+------------------+\n",
"|    type|       avgMismatch|\n",
"+--------+------------------+\n",
"|TRANSFER| 968056.4538892006|\n",
"|CASH_OUT|170539.39652580014|\n",
"| CASH_IN| 50038.95466155722|\n",
"|   DEBIT| 25567.53969902471|\n",
"| PAYMENT| 6378.936662041953|\n",
"+--------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
"    df.select(\n",
"        df.type,\n",
"        amountMismatch(df.oldBalanceOrig, df.newBalanceOrig, df.oldBalanceDest, df.newBalanceDest).alias('mismatch')\n",
"    )\n",
"    .groupBy('type')\n",
"    .agg(\n",
"        F.mean('mismatch').alias('avgMismatch')\n",
"    )\n",
"    .orderBy('avgMismatch', ascending=False)\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "eda2c20a",
"metadata": {
"papermill": {
"duration": 0.292788,
"end_time": "2022-04-18T16:13:12.314441",
"exception": false,
"start_time": "2022-04-18T16:13:12.021653",
"status": "completed"
},
"tags": []
},
"source": [
"### Group aggregate UDF"
]
},
{
"cell_type": "markdown",
"id": "eec16af6",
"metadata": {
"papermill": {
"duration": 0.296477,
"end_time": "2022-04-18T16:13:12.906654",
"exception": false,
"start_time": "2022-04-18T16:13:12.610177",
"status": "completed"
},
"tags": []
},
"source": [
"Group aggregate UDF, also known as the Series to Scalar UDF, distills the `Series` received as input to a single value."
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "1163b66b",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:13:13.499975Z",
"iopub.status.busy": "2022-04-18T16:13:13.499285Z",
"iopub.status.idle": "2022-04-18T16:13:13.502097Z",
"shell.execute_reply": "2022-04-18T16:13:13.501495Z",
"shell.execute_reply.started": "2022-04-18T14:30:53.777697Z"
},
"papermill": {
"duration": 0.302711,
"end_time": "2022-04-18T16:13:13.502251",
"exception": false,
"start_time": "2022-04-18T16:13:13.199540",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"@F.pandas_udf(T.DoubleType())\n",
"def GetStdDeviation(series: pd.Series) -> float:\n",
"    return series.std()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "085b9c8a",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:13:14.087578Z",
"iopub.status.busy": "2022-04-18T16:13:14.086917Z",
"iopub.status.idle": "2022-04-18T16:13:25.748590Z",
"shell.execute_reply": "2022-04-18T16:13:25.747643Z",
"shell.execute_reply.started": "2022-04-18T14:30:53.784736Z"
},
"papermill": {
"duration": 11.956113,
"end_time": "2022-04-18T16:13:25.748814",
"exception": false,
"start_time": "2022-04-18T16:13:13.792701",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 83:=============================> (2 + 2) / 4]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+------------------+\n",
"|    type|               std|\n",
"+--------+------------------+\n",
"|TRANSFER|1879573.5289080725|\n",
"|CASH_OUT|175329.74448347004|\n",
"| CASH_IN|126508.25527180695|\n",
"|   DEBIT|13318.535518284714|\n",
"| PAYMENT|12556.450185716356|\n",
"+--------+------------------+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"(\n",
"    df.groupBy('type')\n",
"    .agg(\n",
"        GetStdDeviation(df.amount).alias('std')\n",
"    )\n",
"    .orderBy('std', ascending=False)\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c8590d78",
"metadata": {
"papermill": {
"duration": 0.297739,
"end_time": "2022-04-18T16:13:26.350083",
"exception": false,
"start_time": "2022-04-18T16:13:26.052344",
"status": "completed"
},
"tags": []
},
"source": [
"### Group map UDF"
]
},
{
"cell_type": "markdown",
"id": "c515f735",
"metadata": {
"papermill": {
"duration": 0.29299,
"end_time": "2022-04-18T16:13:26.938589",
"exception": false,
"start_time": "2022-04-18T16:13:26.645599",
"status": "completed"
},
"tags": []
},
"source": [
"Unlike the group aggregate UDF, which returns a scalar value for each group, the grouped map UDF maps over each group and returns a (pandas) data frame; the results are combined back into a single (Spark) data frame.\n",
"\n",
"Just like with the group aggregate UDF, we use `groupBy()` to split a data frame into manageable groups, but then pass our function to the `applyInPandas()` method. The method takes a function as its first argument and a schema as its second."
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "86a98259",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:13:27.543480Z",
"iopub.status.busy": "2022-04-18T16:13:27.542394Z",
"iopub.status.idle": "2022-04-18T16:13:43.921433Z",
"shell.execute_reply": "2022-04-18T16:13:43.920512Z",
"shell.execute_reply.started": "2022-04-18T14:31:02.877682Z"
},
"papermill": {
"duration": 16.681494,
"end_time": "2022-04-18T16:13:43.921681",
"exception": false,
"start_time": "2022-04-18T16:13:27.240187",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 86:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+---------+--------------------+\n",
"|    type|   amount|          amountNorm|\n",
"+--------+---------+--------------------+\n",
"|TRANSFER|    181.0|1.929785364412691...|\n",
"|TRANSFER| 215310.3| 0.00232902269229461|\n",
"|TRANSFER|311685.89|0.003371535041334062|\n",
"|TRANSFER|  62610.8|6.772443276469881E-4|\n",
"|TRANSFER| 42712.39|4.619995945019032E-4|\n",
"|TRANSFER| 77957.68|8.432543299642404E-4|\n",
"|TRANSFER| 17231.46|1.863677235062513...|\n",
"|TRANSFER| 78766.03|8.519983994671721E-4|\n",
"|TRANSFER|224606.64|0.002429582898990...|\n",
"|TRANSFER|125872.53|0.001361558008596...|\n",
"+--------+---------+--------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"def normalize_by_type(data: pd.DataFrame) -> pd.DataFrame:\n",
"    result = data[['type', 'amount']].copy()\n",
"    maxVal = result['amount'].max()\n",
"    minVal = result['amount'].min()\n",
"    if maxVal == minVal:\n",
"        result['amountNorm'] = 0.5\n",
"    else:\n",
"        result['amountNorm'] = (result['amount'] - minVal) / (maxVal - minVal)\n",
"    return result\n",
"\n",
"# The schema can also be written as a DDL string:\n",
"# schema = 'type string, amount double, amountNorm double'\n",
"schema = T.StructType([\n",
"    T.StructField('type', T.StringType()),\n",
"    T.StructField('amount', T.DoubleType()),\n",
"    T.StructField('amountNorm', T.DoubleType())\n",
"])\n",
"\n",
"(\n",
"    df.groupBy('type')\n",
"    .applyInPandas(normalize_by_type, schema)\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c58c06b5",
"metadata": {
"papermill": {
"duration": 0.300063,
"end_time": "2022-04-18T16:13:44.529623",
"exception": false,
"start_time": "2022-04-18T16:13:44.229560",
"status": "completed"
},
"tags": []
},
"source": [
"We don't need to promote our function to a UDF explicitly; PySpark handles that when we call `applyInPandas`. `applyInPandas` was introduced in PySpark 3. In older versions, we would wrap the function with `pandas_udf()`, passing the return schema as an argument, and then call the `apply()` method instead."
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "8134b7df",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:13:45.144998Z",
"iopub.status.busy": "2022-04-18T16:13:45.144310Z",
"iopub.status.idle": "2022-04-18T16:14:00.270053Z",
"shell.execute_reply": "2022-04-18T16:14:00.269144Z",
"shell.execute_reply.started": "2022-04-18T14:31:15.984659Z"
},
"papermill": {
"duration": 15.437474,
"end_time": "2022-04-18T16:14:00.270271",
"exception": false,
"start_time": "2022-04-18T16:13:44.832797",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/pyspark/sql/pandas/group_ops.py:84: UserWarning: It is preferred to use 'applyInPandas' over this API. This API will be deprecated in the future releases. See SPARK-28264 for more details.\n",
" \"more details.\", UserWarning)\n",
"[Stage 89:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------+---------+--------------------+\n",
"|    type|   amount|          amountNorm|\n",
"+--------+---------+--------------------+\n",
"|TRANSFER|    181.0|1.929785364412691...|\n",
"|TRANSFER| 215310.3| 0.00232902269229461|\n",
"|TRANSFER|311685.89|0.003371535041334062|\n",
"|TRANSFER|  62610.8|6.772443276469881E-4|\n",
"|TRANSFER| 42712.39|4.619995945019032E-4|\n",
"|TRANSFER| 77957.68|8.432543299642404E-4|\n",
"|TRANSFER| 17231.46|1.863677235062513...|\n",
"|TRANSFER| 78766.03|8.519983994671721E-4|\n",
"|TRANSFER|224606.64|0.002429582898990...|\n",
"|TRANSFER|125872.53|0.001361558008596...|\n",
"+--------+---------+--------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"# In Spark 3, 'applyInPandas' is preferred over this API.\n",
"# This API will be deprecated in future releases. See SPARK-28264 for more details.\n",
"# Because it is slated for deprecation, type hint inference is not supported, so we have to specify PandasUDFType explicitly.\n",
"NormalizeByType = F.pandas_udf(normalize_by_type, schema, F.PandasUDFType.GROUPED_MAP)\n",
"(\n",
"    df.groupBy('type')\n",
"    .apply(NormalizeByType)\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "869cfd60",
"metadata": {
"papermill": {
"duration": 0.299906,
"end_time": "2022-04-18T16:14:00.878261",
"exception": false,
"start_time": "2022-04-18T16:14:00.578355",
"status": "completed"
},
"tags": []
},
"source": [
"### Iterator of DataFrame to Iterator of DataFrame"
]
},
{
"cell_type": "markdown",
"id": "22c743b5",
"metadata": {
"papermill": {
"duration": 0.30191,
"end_time": "2022-04-18T16:14:01.483172",
"exception": false,
"start_time": "2022-04-18T16:14:01.181262",
"status": "completed"
},
"tags": []
},
"source": [
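"This variant combines ideas from the Iterator of multiple Series to Iterator of Series UDF and the group map UDF: the function receives an iterator of pandas `DataFrame` batches and yields pandas `DataFrame`s. It is useful when we want to process the whole `DataFrame` instead of just some columns."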
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "0aea46bb",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:02.099006Z",
"iopub.status.busy": "2022-04-18T16:14:02.098042Z",
"iopub.status.idle": "2022-04-18T16:14:02.474912Z",
"shell.execute_reply": "2022-04-18T16:14:02.475667Z",
"shell.execute_reply.started": "2022-04-18T14:31:28.201972Z"
},
"papermill": {
"duration": 0.691102,
"end_time": "2022-04-18T16:14:02.475904",
"exception": false,
"start_time": "2022-04-18T16:14:01.784802",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+-----------+--------+--------+------------------+\n",
"|   nameOrig|   nameDest|    type|  amount|              diff|\n",
"+-----------+-----------+--------+--------+------------------+\n",
"|C1231006815|M1979787155| PAYMENT| 9839.64| 9839.640000000014|\n",
"|C1666544295|M2044282225| PAYMENT| 1864.28|1864.2799999999988|\n",
"|C1305486145| C553264065|TRANSFER|   181.0|             181.0|\n",
"| C840083671|  C38997010|CASH_OUT|   181.0|           21001.0|\n",
"|C2048537720|M1230701703| PAYMENT|11668.14|          11668.14|\n",
"|  C90045638| M573487274| PAYMENT| 7817.71| 7817.709999999999|\n",
"| C154988899| M408069119| PAYMENT| 7107.77|7107.7699999999895|\n",
"|C1912850431| M633326333| PAYMENT| 7861.64| 7861.640000000014|\n",
"|C1265012928|M1176932104| PAYMENT| 4024.36|            2671.0|\n",
"| C712410124| C195600860|   DEBIT| 5337.77|3788.5599999999977|\n",
"+-----------+-----------+--------+--------+------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"def process_data(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:\n",
"    for data in batches:\n",
"        data['changeOrig'] = (data['newBalanceOrig'] - data['oldBalanceOrig']).abs()\n",
"        data['changeDest'] = (data['newBalanceDest'] - data['oldBalanceDest']).abs()\n",
"        data['diff'] = (data['changeOrig'] - data['changeDest']).abs()\n",
"        yield data[['nameOrig', 'nameDest', 'type', 'amount', 'diff']]\n",
"\n",
"schema = 'nameOrig STRING, nameDest STRING, type STRING, amount DOUBLE, diff DOUBLE'\n",
"# Just like applyInPandas, mapInPandas doesn't require us to promote the function explicitly\n",
"df.mapInPandas(process_data, schema).show(10)"
]
},
{
"cell_type": "markdown",
"id": "17a1c9a3",
"metadata": {
"papermill": {
"duration": 0.304045,
"end_time": "2022-04-18T16:14:03.091864",
"exception": false,
"start_time": "2022-04-18T16:14:02.787819",
"status": "completed"
},
"tags": []
},
"source": [
"### Co-grouped map UDF"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee9ec5e1",
"metadata": {
"papermill": {
"duration": 0.332712,
"end_time": "2022-04-18T16:14:03.740655",
"exception": false,
"start_time": "2022-04-18T16:14:03.407943",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "c3d69218",
"metadata": {
"papermill": {
"duration": 0.317087,
"end_time": "2022-04-18T16:14:04.382574",
"exception": false,
"start_time": "2022-04-18T16:14:04.065487",
"status": "completed"
},
"tags": []
},
"source": [
"## Window function"
]
},
{
"cell_type": "markdown",
"id": "7e873583",
"metadata": {
"papermill": {
"duration": 0.316281,
"end_time": "2022-04-18T16:14:05.027445",
"exception": false,
"start_time": "2022-04-18T16:14:04.711164",
"status": "completed"
},
"tags": []
},
"source": [
"Window functions enable users to perform calculations against partitions. Unlike traditional aggregation functions, which return only a single value for each group defined in the query, window functions return a single value for each input row."
]
},
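{
"cell_type": "markdown",
"id": "f2b5d8e1",
"metadata": {},
"source": [
"The difference between the two result shapes can be sketched in plain Python before turning to Spark. This is only an illustration under simplifying assumptions: `rows`, `group_max` and `windowed` are hypothetical names, and a list of tuples stands in for a data frame.\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"rows = [('a', 1), ('a', 3), ('b', 2)]  # (partition key, value) pairs\n",
"\n",
"# Group aggregation: one output value per group\n",
"group_max = defaultdict(lambda: float('-inf'))\n",
"for key, value in rows:\n",
"    group_max[key] = max(group_max[key], value)\n",
"\n",
"# Window-style result: every input row is kept and annotated\n",
"# with the aggregate of its own partition\n",
"windowed = [(key, value, group_max[key]) for key, value in rows]\n",
"```\n",
"\n",
"The aggregation collapses three rows into two values, while the window-style result still has three rows, each carrying its partition's maximum."
]
},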
{
"cell_type": "markdown",
"id": "227b12ad",
"metadata": {
"papermill": {
"duration": 0.330329,
"end_time": "2022-04-18T16:14:05.677856",
"exception": false,
"start_time": "2022-04-18T16:14:05.347527",
"status": "completed"
},
"tags": []
},
"source": [
"To demonstrate window functions, we will switch to another dataset: a time series of Apple stock prices."
]
},
{
"cell_type": "markdown",
"id": "82ff9ff2",
"metadata": {
"papermill": {
"duration": 0.30554,
"end_time": "2022-04-18T16:14:06.308747",
"exception": false,
"start_time": "2022-04-18T16:14:06.003207",
"status": "completed"
},
"tags": []
},
"source": [
"### Window function basics"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "21e09a94",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:06.981709Z",
"iopub.status.busy": "2022-04-18T16:14:06.980549Z",
"iopub.status.idle": "2022-04-18T16:14:07.050796Z",
"shell.execute_reply": "2022-04-18T16:14:07.051581Z",
"shell.execute_reply.started": "2022-04-18T14:31:28.650341Z"
},
"papermill": {
"duration": 0.429987,
"end_time": "2022-04-18T16:14:07.051830",
"exception": false,
"start_time": "2022-04-18T16:14:06.621843",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- date: string (nullable = true)\n",
" |-- open: double (nullable = true)\n",
" |-- high: double (nullable = true)\n",
" |-- low: double (nullable = true)\n",
" |-- close: double (nullable = true)\n",
" |-- adj_close: double (nullable = true)\n",
" |-- volume: double (nullable = true)\n",
"\n"
]
}
],
"source": [
"apple_data_path = '../input/apple-stock-data-updated-till-22jun2021/AAPL.csv'\n",
"apple_schema = T.StructType([\n",
"    T.StructField('Date', T.StringType()),\n",
"    T.StructField('Open', T.DoubleType()),\n",
"    T.StructField('High', T.DoubleType()),\n",
"    T.StructField('Low', T.DoubleType()),\n",
"    T.StructField('Close', T.DoubleType()),\n",
"    T.StructField('Adj Close', T.DoubleType()),\n",
"    T.StructField('Volume', T.DoubleType()),\n",
"])\n",
"apple_df = spark.read.csv(apple_data_path, header=True, schema=apple_schema)\n",
"\n",
"for col in apple_df.columns:\n",
"    new_col = col.lower().replace(' ', '_')\n",
"    apple_df = apple_df.withColumnRenamed(col, new_col)\n",
"\n",
"apple_df.createOrReplaceTempView('apple_df')\n",
"apple_df.printSchema()"
]
},
{
"cell_type": "markdown",
"id": "39d90f53",
"metadata": {
"papermill": {
"duration": 0.306095,
"end_time": "2022-04-18T16:14:07.687517",
"exception": false,
"start_time": "2022-04-18T16:14:07.381422",
"status": "completed"
},
"tags": []
},
"source": [
"For example, suppose we want the highest adjusted close price of each year, along with the other information from the same day. Without window functions, we have to group the data and then self-join.\n",
"\n",
"While it’s not technically wrong, it can be slow and makes the code look more complex than it needs to be. It also looks a little odd. Joining tables makes sense when you want to link data contained in two or more tables. Joining a table with itself feels redundant."
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "402b085b",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:08.304313Z",
"iopub.status.busy": "2022-04-18T16:14:08.303514Z",
"iopub.status.idle": "2022-04-18T16:14:08.994379Z",
"shell.execute_reply": "2022-04-18T16:14:08.993678Z",
"shell.execute_reply.started": "2022-04-18T14:31:28.726232Z"
},
"papermill": {
"duration": 1.003199,
"end_time": "2022-04-18T16:14:08.994536",
"exception": false,
"start_time": "2022-04-18T16:14:07.991337",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+--------+--------+--------+--------+---------+\n",
"|      date|    open|   close|    high|     low|adj_close|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"|1980-12-29|0.160714|0.160714|0.161272|0.160714| 0.125622|\n",
"|1981-01-02|0.154018|0.154018|0.155134|0.154018| 0.120388|\n",
"|1982-12-07|0.149554|0.151228|0.154576|0.146205| 0.118207|\n",
"|1983-06-06|0.273996|0.280134|0.280134|0.273996| 0.218966|\n",
"|1984-05-01|0.141741|0.148438|0.148438|0.141741| 0.116026|\n",
"|1985-01-14|0.136719|0.136719|0.137835|0.136719| 0.106866|\n",
"|1986-12-05| 0.19029|0.195313|0.195313|0.189732| 0.152666|\n",
"|1987-10-05|0.522321|0.529018|0.533482|0.515625| 0.414671|\n",
"|1988-07-05|0.415179|0.421875|0.421875| 0.41183| 0.332719|\n",
"|1989-06-14|  0.4375| 0.44308|0.448661|0.430804| 0.352764|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"# substring() positions are 1-based, so the year is characters 1-4 of the date\n",
"tdf = apple_df.withColumn('year', F.substring(apple_df.date, 1, 4).cast(T.IntegerType()))\n",
"\n",
"(\n",
"    tdf\n",
"    .groupBy('year')\n",
"    .agg(\n",
"        F.max(tdf.adj_close).alias('max_adj_close'),\n",
"    )\n",
"    .alias('left')\n",
"    .join(tdf.alias('right'),\n",
"          on=(F.col('max_adj_close') == tdf.adj_close) & (F.col('left.year') == F.col('right.year')),\n",
"          how='left')\n",
"    .orderBy(F.col('left.year'))\n",
"    .select('date', 'open', 'close', 'high', 'low', 'adj_close')\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "215b2757",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:09.625316Z",
"iopub.status.busy": "2022-04-18T16:14:09.624204Z",
"iopub.status.idle": "2022-04-18T16:14:10.121768Z",
"shell.execute_reply": "2022-04-18T16:14:10.120865Z",
"shell.execute_reply.started": "2022-04-18T14:31:29.397026Z"
},
"papermill": {
"duration": 0.813748,
"end_time": "2022-04-18T16:14:10.122009",
"exception": false,
"start_time": "2022-04-18T16:14:09.308261",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+--------+--------+--------+--------+---------+\n",
"|      date|    open|   close|    high|     low|adj_close|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"|1980-12-29|0.160714|0.160714|0.161272|0.160714| 0.125622|\n",
"|1981-01-02|0.154018|0.154018|0.155134|0.154018| 0.120388|\n",
"|1982-12-07|0.149554|0.151228|0.154576|0.146205| 0.118207|\n",
"|1983-06-06|0.273996|0.280134|0.280134|0.273996| 0.218966|\n",
"|1984-05-01|0.141741|0.148438|0.148438|0.141741| 0.116026|\n",
"|1985-01-14|0.136719|0.136719|0.137835|0.136719| 0.106866|\n",
"|1986-12-05| 0.19029|0.195313|0.195313|0.189732| 0.152666|\n",
"|1987-10-05|0.522321|0.529018|0.533482|0.515625| 0.414671|\n",
"|1988-07-05|0.415179|0.421875|0.421875| 0.41183| 0.332719|\n",
"|1989-06-14|  0.4375| 0.44308|0.448661|0.430804| 0.352764|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"spark.sql('''\n",
"    WITH\n",
"        tdf AS (\n",
"            SELECT *, CAST(SUBSTRING(date, 1, 4) AS INTEGER) year\n",
"            FROM apple_df\n",
"        ),\n",
"        max_df AS (\n",
"            SELECT year, MAX(adj_close) max_adj_close\n",
"            FROM tdf\n",
"            GROUP BY year\n",
"        )\n",
"    SELECT date, open, close, high, low, adj_close\n",
"    FROM\n",
"        max_df LEFT JOIN tdf ON tdf.year = max_df.year AND tdf.adj_close = max_df.max_adj_close\n",
"    ORDER BY date\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "58d2b151",
"metadata": {
"papermill": {
"duration": 0.304513,
"end_time": "2022-04-18T16:14:10.806151",
"exception": false,
"start_time": "2022-04-18T16:14:10.501638",
"status": "completed"
},
"tags": []
},
"source": [
"Window functions also perform calculations on parts of the data rather than on the whole dataset. But unlike a traditional `group by`, which returns only one value for each group, window functions return one value for each input row."
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "526dc2ee",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:11.505496Z",
"iopub.status.busy": "2022-04-18T16:14:11.504549Z",
"iopub.status.idle": "2022-04-18T16:14:12.067137Z",
"shell.execute_reply": "2022-04-18T16:14:12.068034Z",
"shell.execute_reply.started": "2022-04-18T14:31:29.835649Z"
},
"papermill": {
"duration": 0.951556,
"end_time": "2022-04-18T16:14:12.068395",
"exception": false,
"start_time": "2022-04-18T16:14:11.116839",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+--------+--------+--------+--------+---------+\n",
"|      date|    open|   close|    high|     low|adj_close|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"|1980-12-29|0.160714|0.160714|0.161272|0.160714| 0.125622|\n",
"|1981-01-02|0.154018|0.154018|0.155134|0.154018| 0.120388|\n",
"|1982-12-07|0.149554|0.151228|0.154576|0.146205| 0.118207|\n",
"|1983-06-06|0.273996|0.280134|0.280134|0.273996| 0.218966|\n",
"|1984-05-01|0.141741|0.148438|0.148438|0.141741| 0.116026|\n",
"|1985-01-14|0.136719|0.136719|0.137835|0.136719| 0.106866|\n",
"|1986-12-05| 0.19029|0.195313|0.195313|0.189732| 0.152666|\n",
"|1987-10-05|0.522321|0.529018|0.533482|0.515625| 0.414671|\n",
"|1988-07-05|0.415179|0.421875|0.421875| 0.41183| 0.332719|\n",
"|1989-06-14|  0.4375| 0.44308|0.448661|0.430804| 0.352764|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"from pyspark.sql.window import Window\n",
"\n",
"\n",
"annual_window = Window.partitionBy('year')\n",
"\n",
"(\n",
"    apple_df\n",
"    .withColumn('year', F.substring(apple_df.date, 1, 4).cast(T.IntegerType()))\n",
"    .withColumn('max_adj_close', F.max(apple_df.adj_close).over(annual_window))\n",
"    .where(F.col('adj_close') == F.col('max_adj_close'))\n",
"    .select('date', 'open', 'close', 'high', 'low', 'adj_close')\n",
").show(10)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "fdadbbc7",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:12.737788Z",
"iopub.status.busy": "2022-04-18T16:14:12.736716Z",
"iopub.status.idle": "2022-04-18T16:14:13.038105Z",
"shell.execute_reply": "2022-04-18T16:14:13.037430Z",
"shell.execute_reply.started": "2022-04-18T14:31:30.330854Z"
},
"papermill": {
"duration": 0.63397,
"end_time": "2022-04-18T16:14:13.038301",
"exception": false,
"start_time": "2022-04-18T16:14:12.404331",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+--------+--------+--------+--------+---------+\n",
"|      date|    open|   close|    high|     low|adj_close|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"|1980-12-29|0.160714|0.160714|0.161272|0.160714| 0.125622|\n",
"|1981-01-02|0.154018|0.154018|0.155134|0.154018| 0.120388|\n",
"|1982-12-07|0.149554|0.151228|0.154576|0.146205| 0.118207|\n",
"|1983-06-06|0.273996|0.280134|0.280134|0.273996| 0.218966|\n",
"|1984-05-01|0.141741|0.148438|0.148438|0.141741| 0.116026|\n",
"|1985-01-14|0.136719|0.136719|0.137835|0.136719| 0.106866|\n",
"|1986-12-05| 0.19029|0.195313|0.195313|0.189732| 0.152666|\n",
"|1987-10-05|0.522321|0.529018|0.533482|0.515625| 0.414671|\n",
"|1988-07-05|0.415179|0.421875|0.421875| 0.41183| 0.332719|\n",
"|1989-06-14|  0.4375| 0.44308|0.448661|0.430804| 0.352764|\n",
"+----------+--------+--------+--------+--------+---------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"spark.sql('''\n",
"    WITH\n",
"        tdf AS (\n",
"            SELECT date, open, close, high, low, adj_close,\n",
"                MAX(adj_close) OVER (PARTITION BY CAST(SUBSTRING(date, 1, 4) AS INTEGER)) max_adj_close\n",
"            FROM apple_df\n",
"        )\n",
"    SELECT date, open, close, high, low, adj_close\n",
"    FROM tdf\n",
"    WHERE adj_close = max_adj_close\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "a02d70af",
"metadata": {
"papermill": {
"duration": 0.301766,
"end_time": "2022-04-18T16:14:13.650341",
"exception": false,
"start_time": "2022-04-18T16:14:13.348575",
"status": "completed"
},
"tags": []
},
"source": [
"Taking the `max` (or `min`, `sum`, `mean`, ...) doesn't require the window to be ordered. Other functions are sensitive to the order of rows within a window: ranking functions (`rank`, `dense_rank`, `percent_rank`, `ntile`, `row_number`, ...) and analytic functions (`lag`, `lead`, `cume_dist`, ...)."
]
},
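{
"cell_type": "markdown",
"id": "a3c6e9f2",
"metadata": {},
"source": [
"When a window is ordered by a column with unique values, such as the dates within a year, `rank` and `dense_rank` give the same numbers, so the difference between them only shows up with ties. The plain-Python sketch below illustrates those tie semantics; it is not Spark code, and `rank_and_dense_rank` is a hypothetical helper written for this illustration.\n",
"\n",
"```python\n",
"from typing import List, Tuple\n",
"\n",
"\n",
"def rank_and_dense_rank(values: List[int]) -> List[Tuple[int, int, int]]:\n",
"    # Returns (value, rank, dense_rank) triples for the values in ascending order\n",
"    out = []\n",
"    rank = dense = 0\n",
"    previous = None\n",
"    for position, value in enumerate(sorted(values), start=1):\n",
"        if value != previous:\n",
"            rank = position  # rank jumps to the row position after a tie\n",
"            dense += 1       # dense_rank advances by exactly one\n",
"            previous = value\n",
"        out.append((value, rank, dense))\n",
"    return out\n",
"\n",
"\n",
"ranked = rank_and_dense_rank([10, 20, 20, 30])\n",
"```\n",
"\n",
"In this sketch the two tied values share rank 2; the next value then gets rank 4 but dense rank 3, because `rank` leaves a gap after ties and `dense_rank` does not."
]
},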
{
"cell_type": "code",
"execution_count": 48,
"id": "6db165c8",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:14.257765Z",
"iopub.status.busy": "2022-04-18T16:14:14.256948Z",
"iopub.status.idle": "2022-04-18T16:14:14.745804Z",
"shell.execute_reply": "2022-04-18T16:14:14.745095Z",
"shell.execute_reply.started": "2022-04-18T14:31:30.611435Z"
},
"papermill": {
"duration": 0.798852,
"end_time": "2022-04-18T16:14:14.746032",
"exception": false,
"start_time": "2022-04-18T16:14:13.947180",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+----+----------+--------+--------+\n",
"|      date|adj_close|rank|dense_rank|     lag|    lead|\n",
"+----------+---------+----+----------+--------+--------+\n",
"|1980-12-12| 0.100323|   1|         1|    null|0.095089|\n",
"|1980-12-15| 0.095089|   2|         2|0.100323| 0.08811|\n",
"|1980-12-16|  0.08811|   3|         3|0.095089|0.090291|\n",
"|1980-12-17| 0.090291|   4|         4| 0.08811|0.092908|\n",
"|1980-12-18| 0.092908|   5|         5|0.090291|0.098578|\n",
"|1980-12-19| 0.098578|   6|         6|0.092908|0.103376|\n",
"|1980-12-22| 0.103376|   7|         7|0.098578|0.107739|\n",
"|1980-12-23| 0.107739|   8|         8|0.103376|0.113409|\n",
"|1980-12-24| 0.113409|   9|         9|0.107739|0.123877|\n",
"|1980-12-26| 0.123877|  10|        10|0.113409|0.125622|\n",
"+----------+---------+----+----------+--------+--------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"# Or just: ordered_annual_window = annual_window.orderBy('date')\n",
"ordered_annual_window = Window.partitionBy('year').orderBy('date')\n",
"\n",
"# rank() takes no column argument: the ordering comes from the window's orderBy.\n",
"(\n",
"    apple_df.withColumn('year', F.substring(apple_df.date, 1, 4).cast(T.IntegerType()))\n",
"    .select(\n",
"        'date', 'adj_close',\n",
"        F.rank().over(ordered_annual_window).alias('rank'),\n",
"        F.dense_rank().over(ordered_annual_window).alias('dense_rank'),\n",
"        F.lag('adj_close', 1).over(ordered_annual_window).alias('lag'),\n",
"        F.lead('adj_close', 1).over(ordered_annual_window).alias('lead'),\n",
"    )\n",
"    .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "d6df467e",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:15.398633Z",
"iopub.status.busy": "2022-04-18T16:14:15.397933Z",
"iopub.status.idle": "2022-04-18T16:14:15.703349Z",
"shell.execute_reply": "2022-04-18T16:14:15.702551Z",
"shell.execute_reply.started": "2022-04-18T14:31:30.991164Z"
},
"papermill": {
"duration": 0.614771,
"end_time": "2022-04-18T16:14:15.703545",
"exception": false,
"start_time": "2022-04-18T16:14:15.088774",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+----+----------+--------+--------+\n",
"|      date|adj_close|rank|dense_rank|     lag|    lead|\n",
"+----------+---------+----+----------+--------+--------+\n",
"|1980-12-12| 0.100323|   1|         1|    null|0.095089|\n",
"|1980-12-15| 0.095089|   2|         2|0.100323| 0.08811|\n",
"|1980-12-16|  0.08811|   3|         3|0.095089|0.090291|\n",
"|1980-12-17| 0.090291|   4|         4| 0.08811|0.092908|\n",
"|1980-12-18| 0.092908|   5|         5|0.090291|0.098578|\n",
"|1980-12-19| 0.098578|   6|         6|0.092908|0.103376|\n",
"|1980-12-22| 0.103376|   7|         7|0.098578|0.107739|\n",
"|1980-12-23| 0.107739|   8|         8|0.103376|0.113409|\n",
"|1980-12-24| 0.113409|   9|         9|0.107739|0.123877|\n",
"|1980-12-26| 0.123877|  10|        10|0.113409|0.125622|\n",
"+----------+---------+----+----------+--------+--------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"spark.sql('''\n",
" WITH tdf AS (\n",
" SELECT *, CAST(SUBSTRING(date, 0, 4) AS INTEGER) year FROM apple_df\n",
" )\n",
" SELECT date, adj_close, \n",
" RANK() OVER (PARTITION BY year ORDER BY date) rank,\n",
" DENSE_RANK() OVER (PARTITION BY year ORDER BY date) dense_rank,\n",
" LAG(adj_close) OVER (PARTITION BY year ORDER BY date) lag,\n",
" LEAD(adj_close) OVER (PARTITION BY year ORDER BY date) lead\n",
" FROM tdf\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "990a6b17",
"metadata": {
"papermill": {
"duration": 0.307054,
"end_time": "2022-04-18T16:14:16.349583",
"exception": false,
"start_time": "2022-04-18T16:14:16.042529",
"status": "completed"
},
"tags": []
},
"source": [
"### Flexible window functions with boundaries"
]
},
{
"cell_type": "markdown",
"id": "aeb97bdc",
"metadata": {
"papermill": {
"duration": 0.301615,
"end_time": "2022-04-18T16:14:16.951050",
"exception": false,
"start_time": "2022-04-18T16:14:16.649435",
"status": "completed"
},
"tags": []
},
"source": [
"We can fine-tune the boundaries of a window frame relative to the current row, building static, growing, or unbounded frames based on either rows or ranges.\n",
"\n",
"Spark provides `Window.unboundedPreceding`, `Window.unboundedFollowing`, and `Window.currentRow` to refer to the first row of the partition, the last row of the partition, and the current row respectively."
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "853376e4",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:17.562689Z",
"iopub.status.busy": "2022-04-18T16:14:17.562035Z",
"iopub.status.idle": "2022-04-18T16:14:17.806791Z",
"shell.execute_reply": "2022-04-18T16:14:17.807627Z",
"shell.execute_reply.started": "2022-04-18T14:31:31.226376Z"
},
"papermill": {
"duration": 0.553081,
"end_time": "2022-04-18T16:14:17.807907",
"exception": false,
"start_time": "2022-04-18T16:14:17.254826",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+-------------------+\n",
"| date|adj_close| ma10|\n",
"+----------+---------+-------------------+\n",
"|1980-12-12| 0.100323| 0.100323|\n",
"|1980-12-15| 0.095089| 0.097706|\n",
"|1980-12-16| 0.08811|0.09450733333333333|\n",
"|1980-12-17| 0.090291| 0.09345325|\n",
"|1980-12-18| 0.092908| 0.0933442|\n",
"|1980-12-19| 0.098578| 0.0942165|\n",
"|1980-12-22| 0.103376| 0.095525|\n",
"|1980-12-23| 0.107739| 0.09705175|\n",
"|1980-12-24| 0.113409|0.09886922222222222|\n",
"|1980-12-26| 0.123877| 0.10137|\n",
"+----------+---------+-------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/04/18 16:14:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:17 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n"
]
}
],
"source": [
"(\n",
" apple_df.select('date', 'adj_close',\n",
" F.mean('adj_close').over(Window.orderBy('date').rowsBetween(-9, Window.currentRow)).alias('ma10'))\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "2ea60dd8",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:18.482209Z",
"iopub.status.busy": "2022-04-18T16:14:18.481517Z",
"iopub.status.idle": "2022-04-18T16:14:18.670611Z",
"shell.execute_reply": "2022-04-18T16:14:18.669953Z",
"shell.execute_reply.started": "2022-04-18T14:31:31.434380Z"
},
"papermill": {
"duration": 0.492437,
"end_time": "2022-04-18T16:14:18.670755",
"exception": false,
"start_time": "2022-04-18T16:14:18.178318",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+-------------------+\n",
"| date|adj_close| ma10|\n",
"+----------+---------+-------------------+\n",
"|1980-12-12| 0.100323| 0.100323|\n",
"|1980-12-15| 0.095089| 0.097706|\n",
"|1980-12-16| 0.08811|0.09450733333333333|\n",
"|1980-12-17| 0.090291| 0.09345325|\n",
"|1980-12-18| 0.092908| 0.0933442|\n",
"|1980-12-19| 0.098578| 0.0942165|\n",
"|1980-12-22| 0.103376| 0.095525|\n",
"|1980-12-23| 0.107739| 0.09705175|\n",
"|1980-12-24| 0.113409|0.09886922222222222|\n",
"|1980-12-26| 0.123877| 0.10137|\n",
"+----------+---------+-------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/04/18 16:14:18 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:18 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:18 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n"
]
}
],
"source": [
"spark.sql('''\n",
" SELECT date, adj_close,\n",
" MEAN(adj_close) OVER (ORDER BY date ROWS BETWEEN 9 PRECEDING AND CURRENT ROW) ma10\n",
" FROM apple_df\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "65d94f2b",
"metadata": {
"papermill": {
"duration": 0.300691,
"end_time": "2022-04-18T16:14:19.275558",
"exception": false,
"start_time": "2022-04-18T16:14:18.974867",
"status": "completed"
},
"tags": []
},
"source": [
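"The difference between `rowsBetween` and `rangeBetween` is that `rowsBetween` counts physical rows relative to the current row, while `rangeBetween` selects the rows whose `orderBy` value falls within the given offsets of the current row's value."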
]
},
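{
"cell_type": "markdown",
"id": "rows-vs-range-sketch",
"metadata": {},
"source": [
"To make the distinction concrete, here is a minimal plain-Python sketch (not Spark code) that mimics both frame types on the first seven rows of the Apple data. The helpers `rows_frame_mean` and `range_frame_mean` are hypothetical names introduced only for this illustration.\n",
"\n",
"```python\n",
"from datetime import date\n",
"\n",
"# First seven (date, adj_close) rows, copied from the outputs in this section.\n",
"prices = [\n",
"    (date(1980, 12, 12), 0.100323),\n",
"    (date(1980, 12, 15), 0.095089),\n",
"    (date(1980, 12, 16), 0.08811),\n",
"    (date(1980, 12, 17), 0.090291),\n",
"    (date(1980, 12, 18), 0.092908),\n",
"    (date(1980, 12, 19), 0.098578),\n",
"    (date(1980, 12, 22), 0.103376),\n",
"]\n",
"\n",
"def rows_frame_mean(rows, i, preceding):\n",
"    # Like rowsBetween(-preceding, currentRow): count physical rows.\n",
"    window = rows[max(0, i - preceding): i + 1]\n",
"    return sum(v for _, v in window) / len(window)\n",
"\n",
"def range_frame_mean(rows, i, preceding_days):\n",
"    # Like rangeBetween(-preceding_days, currentRow) on a day-number column:\n",
"    # keep rows whose date lies within preceding_days of the current date.\n",
"    current = rows[i][0]\n",
"    window = [v for d, v in rows if 0 <= (current - d).days <= preceding_days]\n",
"    return sum(window) / len(window)\n",
"\n",
"rows_mean = rows_frame_mean(prices, 6, 9)    # uses all 7 rows\n",
"range_mean = range_frame_mean(prices, 6, 9)  # drops 1980-12-12 (10 days back)\n",
"print(rows_mean, range_mean)\n",
"```\n",
"\n",
"For 1980-12-22 the two results differ: the row-based frame still reaches back to 1980-12-12, while the 9-day range frame does not. They should agree with the `ma10` and `true_ma10` values printed by the neighboring cells."
]
},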
{
"cell_type": "code",
"execution_count": 52,
"id": "e114a30e",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:19.887171Z",
"iopub.status.busy": "2022-04-18T16:14:19.886463Z",
"iopub.status.idle": "2022-04-18T16:14:20.244473Z",
"shell.execute_reply": "2022-04-18T16:14:20.243795Z",
"shell.execute_reply.started": "2022-04-18T14:31:31.600217Z"
},
"papermill": {
"duration": 0.667651,
"end_time": "2022-04-18T16:14:20.244619",
"exception": false,
"start_time": "2022-04-18T16:14:19.576968",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/04/18 16:14:19 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:19 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:20 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+-------------------+\n",
"| date|adj_close| true_ma10|\n",
"+----------+---------+-------------------+\n",
"|1980-12-12| 0.100323| 0.100323|\n",
"|1980-12-15| 0.095089| 0.097706|\n",
"|1980-12-16| 0.08811|0.09450733333333333|\n",
"|1980-12-17| 0.090291| 0.09345325|\n",
"|1980-12-18| 0.092908| 0.0933442|\n",
"|1980-12-19| 0.098578| 0.0942165|\n",
"|1980-12-22| 0.103376|0.09472533333333333|\n",
"|1980-12-23| 0.107739|0.09658442857142857|\n",
"|1980-12-24| 0.113409| 0.0986875|\n",
"|1980-12-26| 0.123877|0.10431114285714285|\n",
"+----------+---------+-------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"# This calculates the mean over a 10-day range (not 10 rows).\n",
"# Remember that 10 days are not equivalent to 10 rows because of missing dates.\n",
"( \n",
" apple_df.withColumn('date_stamp', F.datediff(apple_df.date, F.lit('1980-01-01')))\n",
" .select('date', 'adj_close',\n",
" F.mean('adj_close').over(Window.orderBy('date_stamp').rangeBetween(-9, Window.currentRow)).alias('true_ma10'))\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "9234e18d",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:20.853052Z",
"iopub.status.busy": "2022-04-18T16:14:20.851948Z",
"iopub.status.idle": "2022-04-18T16:14:21.046726Z",
"shell.execute_reply": "2022-04-18T16:14:21.047542Z",
"shell.execute_reply.started": "2022-04-18T14:31:32.016335Z"
},
"papermill": {
"duration": 0.50323,
"end_time": "2022-04-18T16:14:21.047738",
"exception": false,
"start_time": "2022-04-18T16:14:20.544508",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+-------------------+\n",
"| date|adj_close| true_ma10|\n",
"+----------+---------+-------------------+\n",
"|1980-12-12| 0.100323| 0.100323|\n",
"|1980-12-15| 0.095089| 0.097706|\n",
"|1980-12-16| 0.08811|0.09450733333333333|\n",
"|1980-12-17| 0.090291| 0.09345325|\n",
"|1980-12-18| 0.092908| 0.0933442|\n",
"|1980-12-19| 0.098578| 0.0942165|\n",
"|1980-12-22| 0.103376|0.09472533333333333|\n",
"|1980-12-23| 0.107739|0.09658442857142857|\n",
"|1980-12-24| 0.113409| 0.0986875|\n",
"|1980-12-26| 0.123877|0.10431114285714285|\n",
"+----------+---------+-------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/04/18 16:14:20 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:20 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:20 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n"
]
}
],
"source": [
"spark.sql('''\n",
" WITH tdf as (\n",
" SELECT date, adj_close, DATEDIFF(date, \"1980-01-01\") date_stamp FROM apple_df\n",
" )\n",
" SELECT date, adj_close, MEAN(adj_close) OVER (ORDER BY date_stamp RANGE BETWEEN 9 PRECEDING AND CURRENT ROW) true_ma10\n",
" FROM tdf\n",
"''').show(10)"
]
},
{
"cell_type": "markdown",
"id": "aeeb863e",
"metadata": {
"papermill": {
"duration": 0.306702,
"end_time": "2022-04-18T16:14:21.658057",
"exception": false,
"start_time": "2022-04-18T16:14:21.351355",
"status": "completed"
},
"tags": []
},
"source": [
"### Window functions and pandas udf"
]
},
{
"cell_type": "markdown",
"id": "562a101e",
"metadata": {
"papermill": {
"duration": 0.306016,
"end_time": "2022-04-18T16:14:22.272368",
"exception": false,
"start_time": "2022-04-18T16:14:21.966352",
"status": "completed"
},
"tags": []
},
"source": [
"The recipe for applying a pandas UDF over a window is simple:\n",
"- We need a Series to Scalar UDF (that is, a grouped aggregate UDF). PySpark applies the UDF to every window frame (once per record) and uses the returned scalar as that record's result.\n",
"- A UDF over unbounded window frames requires Spark 2.4 or above.\n",
"- A UDF over bounded window frames requires Spark 3.0 or above."
]
},
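{
"cell_type": "markdown",
"id": "udf-per-frame-sketch",
"metadata": {},
"source": [
"As a rough mental model of the first point, the plain-Python sketch below calls an aggregation function once per record on a `rowsBetween(-9, Window.currentRow)`-style frame. It is only an illustration, not how Spark executes the UDF internally, and `ema` and `apply_over_rows_frame` are hypothetical helper names.\n",
"\n",
"```python\n",
"def ema(values, alpha=0.4):\n",
"    # Series-to-scalar style function: many values in, one value out.\n",
"    result = values[0]\n",
"    for v in values[1:]:\n",
"        result = alpha * v + (1 - alpha) * result\n",
"    return result\n",
"\n",
"def apply_over_rows_frame(values, func, preceding):\n",
"    # Build the frame for each record and call func once per record.\n",
"    return [\n",
"        func(values[max(0, i - preceding): i + 1])\n",
"        for i in range(len(values))\n",
"    ]\n",
"\n",
"# First three adj_close values from the Apple data used in this section.\n",
"adj_close = [0.100323, 0.095089, 0.08811]\n",
"ema_values = apply_over_rows_frame(adj_close, ema, preceding=9)\n",
"print(ema_values)\n",
"```\n",
"\n",
"Each output element depends only on the frame that ends at its own record, which is why the function must return a single scalar."
]
},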
{
"cell_type": "code",
"execution_count": 54,
"id": "dca47c1c",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:22.887083Z",
"iopub.status.busy": "2022-04-18T16:14:22.886415Z",
"iopub.status.idle": "2022-04-18T16:14:23.711006Z",
"shell.execute_reply": "2022-04-18T16:14:23.710423Z",
"shell.execute_reply.started": "2022-04-18T14:31:32.268452Z"
},
"papermill": {
"duration": 1.136672,
"end_time": "2022-04-18T16:14:23.711156",
"exception": false,
"start_time": "2022-04-18T16:14:22.574484",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22/04/18 16:14:22 WARN WindowInPandasExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:22 WARN WindowInPandasExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n",
"22/04/18 16:14:23 WARN WindowInPandasExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----------+---------+-------------------+\n",
"| date|adj_close| ema10|\n",
"+----------+---------+-------------------+\n",
"|1980-12-12| 0.100323| 0.100323|\n",
"|1980-12-15| 0.095089| 0.0982294|\n",
"|1980-12-16| 0.08811|0.09418163999999998|\n",
"|1980-12-17| 0.090291|0.09262538399999998|\n",
"|1980-12-18| 0.092908|0.09273843039999999|\n",
"|1980-12-19| 0.098578|0.09507425823999999|\n",
"|1980-12-22| 0.103376| 0.098394954944|\n",
"|1980-12-23| 0.107739|0.10213257296639999|\n",
"|1980-12-24| 0.113409| 0.10664314377984|\n",
"|1980-12-26| 0.123877| 0.113536686267904|\n",
"+----------+---------+-------------------+\n",
"only showing top 10 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"@F.pandas_udf(T.DoubleType())\n",
"def Ema(values: pd.Series) -> float:\n",
" alpha = 0.4\n",
" ema = values.iloc[0]\n",
" for val in values:\n",
" ema = alpha * val + (1 - alpha) * ema\n",
" return ema\n",
"\n",
"(\n",
" apple_df.select(\n",
" 'date', 'adj_close',\n",
" Ema(apple_df.adj_close).over(Window.orderBy('date').rowsBetween(-9, Window.currentRow)).alias('ema10')\n",
" )\n",
" .show(10)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6724f83a",
"metadata": {
"papermill": {
"duration": 0.307028,
"end_time": "2022-04-18T16:14:24.326415",
"exception": false,
"start_time": "2022-04-18T16:14:24.019387",
"status": "completed"
},
"tags": []
},
"source": [
"## Machine learning with PySpark"
]
},
{
"cell_type": "markdown",
"id": "d39c08e9",
"metadata": {
"papermill": {
"duration": 0.325156,
"end_time": "2022-04-18T16:14:24.961027",
"exception": false,
"start_time": "2022-04-18T16:14:24.635871",
"status": "completed"
},
"tags": []
},
"source": [
"Transformers and estimators are the two main components of ML pipelines.\n",
"- Transformers are objects that, through a `transform()` method, modify a data frame based on a set of `Param`s that drive their behavior. We use a transformer stage when we want to deterministically transform a data frame.\n",
"- Estimators are objects that, through a `fit()` method, take a data frame and return a fully parameterized transformer called a model. We use an estimator stage when we want to transform a data frame using a data-dependent transformer."
]
},
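{
"cell_type": "markdown",
"id": "estimator-sketch",
"metadata": {},
"source": [
"The relationship between the two can be sketched in plain Python. The classes below are hypothetical stand-ins written for illustration and are not part of `pyspark.ml`; they operate on lists instead of data frames and only mirror the `fit()` / `transform()` contract described above.\n",
"\n",
"```python\n",
"class MeanCentererModel:\n",
"    '''Transformer: fully parameterized, so transform() is deterministic.'''\n",
"\n",
"    def __init__(self, mean):\n",
"        self.mean = mean\n",
"\n",
"    def transform(self, values):\n",
"        return [v - self.mean for v in values]\n",
"\n",
"\n",
"class MeanCenterer:\n",
"    '''Estimator: learns its parameter (the mean) from the data in fit().'''\n",
"\n",
"    def fit(self, values):\n",
"        return MeanCentererModel(sum(values) / len(values))\n",
"\n",
"\n",
"model = MeanCenterer().fit([1.0, 2.0, 3.0])  # data-dependent step\n",
"centered = model.transform([1.0, 2.0, 3.0])  # deterministic step\n",
"print(model.mean, centered)\n",
"```\n",
"\n",
"The estimator holds no learned state itself; everything it learns lives in the model returned by `fit()`."
]
},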
{
"cell_type": "markdown",
"id": "7205e24c",
"metadata": {
"papermill": {
"duration": 0.303981,
"end_time": "2022-04-18T16:14:25.578020",
"exception": false,
"start_time": "2022-04-18T16:14:25.274039",
"status": "completed"
},
"tags": []
},
"source": [
"In this section, we will use the customer classification tabular dataset, a simple dataset for building a machine learning model. Note that this is a brief PySpark tutorial, not a machine learning course, so we won't dive deep into machine learning."
]
},
{
"cell_type": "markdown",
"id": "399b6e89",
"metadata": {
"papermill": {
"duration": 0.311064,
"end_time": "2022-04-18T16:14:26.224275",
"exception": false,
"start_time": "2022-04-18T16:14:25.913211",
"status": "completed"
},
"tags": []
},
"source": [
"### Data overview"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "35390c65",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:26.840088Z",
"iopub.status.busy": "2022-04-18T16:14:26.839257Z",
"iopub.status.idle": "2022-04-18T16:14:27.072784Z",
"shell.execute_reply": "2022-04-18T16:14:27.071945Z",
"shell.execute_reply.started": "2022-04-18T14:31:33.484917Z"
},
"papermill": {
"duration": 0.542997,
"end_time": "2022-04-18T16:14:27.073019",
"exception": false,
"start_time": "2022-04-18T16:14:26.530022",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- region: integer (nullable = true)\n",
" |-- tenure: integer (nullable = true)\n",
" |-- age: integer (nullable = true)\n",
" |-- income: integer (nullable = true)\n",
" |-- marital: integer (nullable = true)\n",
" |-- address: integer (nullable = true)\n",
" |-- ed: integer (nullable = true)\n",
" |-- employ: integer (nullable = true)\n",
" |-- retire: integer (nullable = true)\n",
" |-- gender: integer (nullable = true)\n",
" |-- reside: integer (nullable = true)\n",
" |-- custcat: string (nullable = true)\n",
"\n"
]
}
],
"source": [
"cust_data_path = '../input/customersegmentation/Telecust1.csv'\n",
"\n",
"cust_df = spark.read.csv(cust_data_path, inferSchema=True, header=True)\n",
"cust_df = cust_df.toDF(*[col.lower().replace('-', '_') for col in cust_df.columns])\n",
"cust_df.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "91499a26",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:27.719854Z",
"iopub.status.busy": "2022-04-18T16:14:27.718857Z",
"iopub.status.idle": "2022-04-18T16:14:27.863448Z",
"shell.execute_reply": "2022-04-18T16:14:27.862678Z",
"shell.execute_reply.started": "2022-04-18T14:31:33.744726Z"
},
"papermill": {
"duration": 0.465208,
"end_time": "2022-04-18T16:14:27.863617",
"exception": false,
"start_time": "2022-04-18T16:14:27.398409",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+-------+---+------+------+------+------+-------+\n",
"|region|tenure|age|income|marital|address| ed|employ|retire|gender|reside|custcat|\n",
"+------+------+---+------+-------+-------+---+------+------+------+------+-------+\n",
"| 2| 13| 44| 64| 1| 9| 4| 5| 0| 0| 2| A|\n",
"| 3| 11| 33| 136| 1| 7| 5| 5| 0| 0| 6| D|\n",
"| 3| 68| 52| 116| 1| 24| 1| 29| 0| 1| 2| C|\n",
"| 2| 33| 33| 33| 0| 12| 2| 0| 0| 1| 1| A|\n",
"| 2| 23| 30| 30| 1| 9| 1| 2| 0| 0| 4| C|\n",
"| 2| 41| 39| 78| 0| 17| 2| 16| 0| 1| 1| C|\n",
"| 3| 45| 22| 19| 1| 2| 2| 4| 0| 1| 5| B|\n",
"| 2| 38| 35| 76| 0| 5| 2| 10| 0| 0| 3| D|\n",
"| 3| 45| 59| 166| 1| 7| 4| 31| 0| 0| 5| C|\n",
"| 1| 68| 41| 72| 1| 21| 1| 22| 0| 0| 3| B|\n",
"+------+------+---+------+-------+-------+---+------+------+------+------+-------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"cust_df.show(10)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "6697ada2",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:28.510849Z",
"iopub.status.busy": "2022-04-18T16:14:28.510025Z",
"iopub.status.idle": "2022-04-18T16:14:29.029443Z",
"shell.execute_reply": "2022-04-18T16:14:29.028727Z",
"shell.execute_reply.started": "2022-04-18T14:31:34.087672Z"
},
"papermill": {
"duration": 0.833548,
"end_time": "2022-04-18T16:14:29.029604",
"exception": false,
"start_time": "2022-04-18T16:14:28.196056",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+-------+---+------+------+------+------+-------+\n",
"|region|tenure|age|income|marital|address| ed|employ|retire|gender|reside|custcat|\n",
"+------+------+---+------+-------+-------+---+------+------+------+------+-------+\n",
"| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|\n",
"+------+------+---+------+-------+-------+---+------+------+------+------+-------+\n",
"\n"
]
}
],
"source": [
"# Number of null or NaN values per column\n",
"null_df = cust_df.select(*[F.count(F.when(F.col(c).isNull() | F.isnan(F.col(c)), c)).alias(c) for c in cust_df.columns])\n",
"null_df.show()"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "aba287c2",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:29.657279Z",
"iopub.status.busy": "2022-04-18T16:14:29.656305Z",
"iopub.status.idle": "2022-04-18T16:14:29.658776Z",
"shell.execute_reply": "2022-04-18T16:14:29.658198Z",
"shell.execute_reply.started": "2022-04-18T14:31:34.504774Z"
},
"papermill": {
"duration": 0.314727,
"end_time": "2022-04-18T16:14:29.658944",
"exception": false,
"start_time": "2022-04-18T16:14:29.344217",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"categorical_cols = ['region', 'marital', 'address', 'ed', 'retire', 'gender', 'reside', 'custcat']\n",
"numeric_cols = ['tenure', 'age', 'income', 'employ']"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "66e0df1d",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:30.280453Z",
"iopub.status.busy": "2022-04-18T16:14:30.279722Z",
"iopub.status.idle": "2022-04-18T16:14:30.716483Z",
"shell.execute_reply": "2022-04-18T16:14:30.717330Z",
"shell.execute_reply.started": "2022-04-18T14:31:34.514284Z"
},
"papermill": {
"duration": 0.749524,
"end_time": "2022-04-18T16:14:30.717595",
"exception": false,
"start_time": "2022-04-18T16:14:29.968071",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+-------+-------+---+------+------+------+-------+\n",
"|region|marital|address| ed|retire|gender|reside|custcat|\n",
"+------+-------+-------+---+------+------+------+-------+\n",
"| 3| 2| 50| 5| 2| 2| 8| 4|\n",
"+------+-------+-------+---+------+------+------+-------+\n",
"\n"
]
}
],
"source": [
"# Number of unique values\n",
"unique_cnt_df = cust_df.select(*[F.countDistinct(c).alias(c) for c in categorical_cols])\n",
"unique_cnt_df.show()"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "05fcdabb",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:31.403341Z",
"iopub.status.busy": "2022-04-18T16:14:31.402628Z",
"iopub.status.idle": "2022-04-18T16:14:32.521522Z",
"shell.execute_reply": "2022-04-18T16:14:32.522071Z",
"shell.execute_reply.started": "2022-04-18T14:31:34.931084Z"
},
"papermill": {
"duration": 1.44068,
"end_time": "2022-04-18T16:14:32.522272",
"exception": false,
"start_time": "2022-04-18T16:14:31.081592",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"# Plot data distribution of a column\n",
"from matplotlib import pyplot as plt\n",
"import seaborn as sns\n",
"from itertools import product\n",
"%matplotlib inline\n",
"plt.rcParams['figure.dpi'] = 150\n",
"\n",
"\n",
"def plot_column_dist(df, col, ax=None):\n",
" if ax is None:\n",
" ax = plt.gca()\n",
" agg_df = df.groupBy(col).agg(F.count('*').alias('count')).toPandas()\n",
" agg_df = agg_df.sort_values('count', ascending=False)\n",
" sns.barplot(data=agg_df, x=col, y='count',\n",
" color='#7db0bc', ax=ax)\n",
" \n",
"def plot_columns_dist(df, cols, figrows, figcols, size=(12, 6)):\n",
" assert len(cols) <= figrows * figcols\n",
" fig, axs = plt.subplots(figrows, figcols, figsize=size)\n",
" if figrows == 1:\n",
" axs = [axs]\n",
" if figcols == 1:\n",
" axs = [[ax] for ax in axs]\n",
" for i, j in product(range(figrows), range(figcols)):\n",
" if i*figcols + j >= len(cols):\n",
" continue\n",
" col = cols[i*figcols + j]\n",
" plot_column_dist(df, col, ax=axs[i][j])\n",
" axs[i][j].set_xticklabels(axs[i][j].get_xticklabels(), rotation=30)\n",
" return fig"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "9a1ec38d",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:33.152183Z",
"iopub.status.busy": "2022-04-18T16:14:33.151512Z",
"iopub.status.idle": "2022-04-18T16:14:36.950418Z",
"shell.execute_reply": "2022-04-18T16:14:36.949852Z",
"shell.execute_reply.started": "2022-04-18T14:31:35.857651Z"
},
"papermill": {
"duration": 4.119302,
"end_time": "2022-04-18T16:14:36.950561",
"exception": false,
"start_time": "2022-04-18T16:14:32.831259",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1800x900 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1800x900 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1800x900 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1800x900 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the categorical columns two per figure\n",
"for i in range(0, len(categorical_cols), 2):\n",
"    cols = categorical_cols[i:i+2]\n",
"    plot_columns_dist(cust_df, cols, 1, 2).show()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "a84416ed",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:37.593252Z",
"iopub.status.busy": "2022-04-18T16:14:37.592327Z",
"iopub.status.idle": "2022-04-18T16:14:40.997794Z",
"shell.execute_reply": "2022-04-18T16:14:40.997244Z",
"shell.execute_reply.started": "2022-04-18T14:31:39.363870Z"
},
"papermill": {
"duration": 3.73122,
"end_time": "2022-04-18T16:14:40.997980",
"exception": false,
"start_time": "2022-04-18T16:14:37.266760",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+-----+-----+-----+-----+\n",
"|region|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 1| 75| 78| 95| 74|\n",
"| 2| 92| 69| 92| 81|\n",
"| 3| 99| 70| 94| 81|\n",
"+------+-----+-----+-----+-----+\n",
"\n",
"+------+-----+-----+-----+-----+\n",
"|tenure|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 1| 10| 1| 2| 0|\n",
"| 2| 3| 0| 2| 2|\n",
"| 3| 11| 1| 1| 7|\n",
"| 4| 10| 0| 5| 4|\n",
"| 5| 6| 2| 7| 4|\n",
"| 6| 5| 2| 6| 2|\n",
"| 7| 12| 0| 2| 4|\n",
"| 8| 5| 2| 3| 4|\n",
"| 9| 11| 1| 1| 2|\n",
"| 10| 6| 4| 2| 6|\n",
"+------+-----+-----+-----+-----+\n",
"only showing top 10 rows\n",
"\n",
"+---+-----+-----+-----+-----+\n",
"|age|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+---+-----+-----+-----+-----+\n",
"| 18| 0| 0| 1| 0|\n",
"| 19| 2| 2| 0| 0|\n",
"| 20| 4| 1| 2| 3|\n",
"| 21| 4| 2| 2| 0|\n",
"| 22| 4| 5| 4| 2|\n",
"| 23| 4| 1| 5| 6|\n",
"| 24| 9| 1| 7| 3|\n",
"| 25| 9| 4| 5| 5|\n",
"| 26| 9| 2| 3| 7|\n",
"| 27| 7| 7| 0| 10|\n",
"+---+-----+-----+-----+-----+\n",
"only showing top 10 rows\n",
"\n",
"+------+-----+-----+-----+-----+\n",
"|income|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 9| 2| 2| 3| 0|\n",
"| 10| 1| 0| 2| 0|\n",
"| 11| 2| 0| 0| 0|\n",
"| 12| 1| 0| 1| 1|\n",
"| 13| 0| 1| 1| 0|\n",
"| 14| 2| 1| 2| 2|\n",
"| 15| 1| 2| 3| 2|\n",
"| 16| 2| 1| 0| 1|\n",
"| 17| 5| 3| 4| 1|\n",
"| 18| 7| 2| 1| 5|\n",
"+------+-----+-----+-----+-----+\n",
"only showing top 10 rows\n",
"\n",
"+-------+-----+-----+-----+-----+\n",
"|marital|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+-------+-----+-----+-----+-----+\n",
"| 0| 155| 102| 142| 106|\n",
"| 1| 111| 115| 139| 130|\n",
"+-------+-----+-----+-----+-----+\n",
"\n",
"+-------+-----+-----+-----+-----+\n",
"|address|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+-------+-----+-----+-----+-----+\n",
"| 0| 10| 12| 17| 17|\n",
"| 1| 21| 17| 12| 18|\n",
"| 2| 21| 14| 15| 16|\n",
"| 3| 24| 12| 13| 12|\n",
"| 4| 18| 10| 18| 15|\n",
"| 5| 18| 9| 12| 11|\n",
"| 6| 10| 7| 10| 9|\n",
"| 7| 21| 5| 15| 12|\n",
"| 8| 8| 12| 9| 10|\n",
"| 9| 19| 7| 7| 8|\n",
"+-------+-----+-----+-----+-----+\n",
"only showing top 10 rows\n",
"\n",
"+---+-----+-----+-----+-----+\n",
"| ed|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+---+-----+-----+-----+-----+\n",
"| 1| 75| 29| 91| 9|\n",
"| 2| 83| 54| 98| 52|\n",
"| 3| 53| 53| 54| 49|\n",
"| 4| 46| 59| 34| 95|\n",
"| 5| 9| 22| 4| 31|\n",
"+---+-----+-----+-----+-----+\n",
"\n",
"+------+-----+-----+-----+-----+\n",
"|employ|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 0| 34| 24| 24| 24|\n",
"| 1| 19| 14| 13| 20|\n",
"| 2| 23| 9| 12| 15|\n",
"| 3| 18| 12| 5| 15|\n",
"| 4| 13| 12| 13| 14|\n",
"| 5| 20| 9| 14| 11|\n",
"| 6| 11| 12| 10| 11|\n",
"| 7| 16| 10| 12| 10|\n",
"| 8| 9| 7| 11| 11|\n",
"| 9| 7| 11| 11| 10|\n",
"+------+-----+-----+-----+-----+\n",
"only showing top 10 rows\n",
"\n",
"+------+-----+-----+-----+-----+\n",
"|retire|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 0| 255| 210| 259| 229|\n",
"| 1| 11| 7| 22| 7|\n",
"+------+-----+-----+-----+-----+\n",
"\n",
"+------+-----+-----+-----+-----+\n",
"|gender|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 0| 131| 98| 139| 115|\n",
"| 1| 135| 119| 142| 121|\n",
"+------+-----+-----+-----+-----+\n",
"\n",
"+------+-----+-----+-----+-----+\n",
"|reside|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+------+-----+-----+-----+-----+\n",
"| 1| 121| 76| 106| 72|\n",
"| 2| 59| 63| 85| 65|\n",
"| 3| 32| 29| 39| 38|\n",
"| 4| 30| 31| 31| 28|\n",
"| 5| 16| 14| 13| 17|\n",
"| 6| 6| 4| 6| 13|\n",
"| 7| 2| 0| 0| 2|\n",
"| 8| 0| 0| 1| 1|\n",
"+------+-----+-----+-----+-----+\n",
"\n",
"+-------+-----+-----+-----+-----+\n",
"|custcat|cnt_A|cnt_B|cnt_C|cnt_D|\n",
"+-------+-----+-----+-----+-----+\n",
"| A| 266| 0| 0| 0|\n",
"| B| 0| 217| 0| 0|\n",
"| C| 0| 0| 281| 0|\n",
"| D| 0| 0| 0| 236|\n",
"+-------+-----+-----+-----+-----+\n",
"\n"
]
}
],
"source": [
"for c in cust_df.columns:\n",
" agg_df = (\n",
" cust_df.groupBy(c)\n",
" .agg(\n",
" *[F.count(F.when(F.col('custcat') == category, True)).alias('cnt_'+category) for category in ['A', 'B', 'C', 'D']]\n",
" )\n",
" .orderBy(c)\n",
" )\n",
" agg_df.show(10)"
]
},
{
"cell_type": "markdown",
"id": "c41b74e4",
"metadata": {
"papermill": {
"duration": 0.323555,
"end_time": "2022-04-18T16:14:41.647934",
"exception": false,
"start_time": "2022-04-18T16:14:41.324379",
"status": "completed"
},
"tags": []
},
"source": [
"### Data preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "c4e5c164",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:42.356984Z",
"iopub.status.busy": "2022-04-18T16:14:42.355917Z",
"iopub.status.idle": "2022-04-18T16:14:42.371044Z",
"shell.execute_reply": "2022-04-18T16:14:42.370392Z",
"shell.execute_reply.started": "2022-04-18T14:31:42.216590Z"
},
"papermill": {
"duration": 0.342222,
"end_time": "2022-04-18T16:14:42.371188",
"exception": false,
"start_time": "2022-04-18T16:14:42.028966",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root\n",
" |-- region: integer (nullable = true)\n",
" |-- tenure: integer (nullable = true)\n",
" |-- age: integer (nullable = true)\n",
" |-- income: integer (nullable = true)\n",
" |-- marital: integer (nullable = true)\n",
" |-- ed: integer (nullable = true)\n",
" |-- employ: integer (nullable = true)\n",
" |-- retire: integer (nullable = true)\n",
" |-- gender: integer (nullable = true)\n",
" |-- custcat: string (nullable = true)\n",
"\n"
]
}
],
"source": [
"# Drop 'reside' and 'address' because of their unclear meaning. \n",
"cust_df = cust_df.drop('reside', 'address')\n",
"cust_df.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "dc87b185",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:43.023441Z",
"iopub.status.busy": "2022-04-18T16:14:43.022373Z",
"iopub.status.idle": "2022-04-18T16:14:43.024767Z",
"shell.execute_reply": "2022-04-18T16:14:43.025388Z",
"shell.execute_reply.started": "2022-04-18T14:31:42.236951Z"
},
"papermill": {
"duration": 0.331901,
"end_time": "2022-04-18T16:14:43.025565",
"exception": false,
"start_time": "2022-04-18T16:14:42.693664",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"categorical_cols = ['region', 'marital', 'ed', 'retire', 'gender']\n",
"numeric_cols = ['tenure', 'age', 'income', 'employ']"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "7faddeb5",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:43.667000Z",
"iopub.status.busy": "2022-04-18T16:14:43.665988Z",
"iopub.status.idle": "2022-04-18T16:14:45.153643Z",
"shell.execute_reply": "2022-04-18T16:14:45.152666Z",
"shell.execute_reply.started": "2022-04-18T14:31:42.243656Z"
},
"papermill": {
"duration": 1.80945,
"end_time": "2022-04-18T16:14:45.153875",
"exception": false,
"start_time": "2022-04-18T16:14:43.344425",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+---+------+------+------+-------+----------+\n",
"|region|tenure|age|income|marital| ed|employ|retire|gender|custcat|region_idx|\n",
"+------+------+---+------+-------+---+------+------+------+-------+----------+\n",
"| 2| 13| 44| 64| 1| 4| 5| 0| 0| A| 1.0|\n",
"| 3| 11| 33| 136| 1| 5| 5| 0| 0| D| 0.0|\n",
"| 3| 68| 52| 116| 1| 1| 29| 0| 1| C| 0.0|\n",
"| 2| 33| 33| 33| 0| 2| 0| 0| 1| A| 1.0|\n",
"| 2| 23| 30| 30| 1| 1| 2| 0| 0| C| 1.0|\n",
"| 2| 41| 39| 78| 0| 2| 16| 0| 1| C| 1.0|\n",
"| 3| 45| 22| 19| 1| 2| 4| 0| 1| B| 0.0|\n",
"| 2| 38| 35| 76| 0| 2| 10| 0| 0| D| 1.0|\n",
"| 3| 45| 59| 166| 1| 4| 31| 0| 0| C| 0.0|\n",
"| 1| 68| 41| 72| 1| 1| 22| 0| 0| B| 2.0|\n",
"+------+------+---+------+-------+---+------+------+------+-------+----------+\n",
"only showing top 10 rows\n",
"\n",
"+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"|marital_idx|ed_idx|retire_idx|gender_idx| region_vec| marital_vec| ed_vec| retire_vec| gender_vec|\n",
"+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"| 1.0| 1.0| 0.0| 1.0|(3,[1],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 1.0| 4.0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[4],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 1.0| 3.0| 0.0| 0.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 3.0| 0.0| 1.0|(3,[1],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 0.0| 0.0| 1.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 1.0| 1.0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 1.0| 3.0| 0.0| 1.0|(3,[2],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"# One hot encoding\n",
"from pyspark.ml.feature import StringIndexer, OneHotEncoder\n",
"\n",
"\n",
"def modify_cols(cols, prefix='', postfix=''):\n",
" return [prefix + col + postfix for col in cols]\n",
"\n",
"indexer = StringIndexer(inputCols=categorical_cols, outputCols=modify_cols(categorical_cols, postfix='_idx'))\n",
"indexed_df = indexer.fit(cust_df).transform(cust_df)\n",
"\n",
"encoder = OneHotEncoder(inputCols=modify_cols(categorical_cols, postfix='_idx'), \n",
" outputCols=modify_cols(categorical_cols, postfix='_vec'),\n",
" dropLast=False)\n",
"encoded_df = encoder.fit(indexed_df).transform(indexed_df)\n",
"\n",
"show_split(encoded_df, 11, 10)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "c4aa7162",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:45.923678Z",
"iopub.status.busy": "2022-04-18T16:14:45.922999Z",
"iopub.status.idle": "2022-04-18T16:14:46.929276Z",
"shell.execute_reply": "2022-04-18T16:14:46.927533Z",
"shell.execute_reply.started": "2022-04-18T14:31:43.638305Z"
},
"papermill": {
"duration": 1.354128,
"end_time": "2022-04-18T16:14:46.929555",
"exception": false,
"start_time": "2022-04-18T16:14:45.575427",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+---+------+------+\n",
"|region|tenure|age|income|marital| ed|employ|retire|\n",
"+------+------+---+------+-------+---+------+------+\n",
"| 2| 13| 44| 64| 1| 4| 5| 0|\n",
"| 3| 11| 33| 136| 1| 5| 5| 0|\n",
"| 3| 68| 52| 116| 1| 1| 29| 0|\n",
"| 2| 33| 33| 33| 0| 2| 0| 0|\n",
"| 2| 23| 30| 30| 1| 1| 2| 0|\n",
"| 2| 41| 39| 78| 0| 2| 16| 0|\n",
"| 3| 45| 22| 19| 1| 2| 4| 0|\n",
"| 2| 38| 35| 76| 0| 2| 10| 0|\n",
"| 3| 45| 59| 166| 1| 4| 31| 0|\n",
"| 1| 68| 41| 72| 1| 1| 22| 0|\n",
"+------+------+---+------+-------+---+------+------+\n",
"only showing top 10 rows\n",
"\n",
"+------+-------+----------+-----------+------+----------+----------+-------------+\n",
"|gender|custcat|region_idx|marital_idx|ed_idx|retire_idx|gender_idx| region_vec|\n",
"+------+-------+----------+-----------+------+----------+----------+-------------+\n",
"| 0| A| 1.0| 1.0| 1.0| 0.0| 1.0|(3,[1],[1.0])|\n",
"| 0| D| 0.0| 1.0| 4.0| 0.0| 1.0|(3,[0],[1.0])|\n",
"| 1| C| 0.0| 1.0| 3.0| 0.0| 0.0|(3,[0],[1.0])|\n",
"| 1| A| 1.0| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|\n",
"| 0| C| 1.0| 1.0| 3.0| 0.0| 1.0|(3,[1],[1.0])|\n",
"| 1| C| 1.0| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|\n",
"| 1| B| 0.0| 1.0| 0.0| 0.0| 0.0|(3,[0],[1.0])|\n",
"| 0| D| 1.0| 0.0| 0.0| 0.0| 1.0|(3,[1],[1.0])|\n",
"| 0| C| 0.0| 1.0| 1.0| 0.0| 1.0|(3,[0],[1.0])|\n",
"| 0| B| 2.0| 1.0| 3.0| 0.0| 1.0|(3,[2],[1.0])|\n",
"+------+-------+----------+-----------+------+----------+----------+-------------+\n",
"only showing top 10 rows\n",
"\n",
"+-------------+-------------+-------------+-------------+--------------------+--------------------+\n",
"| marital_vec| ed_vec| retire_vec| gender_vec| numeric| numericScaled|\n",
"+-------------+-------------+-------------+-------------+--------------------+--------------------+\n",
"|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|[13.0,44.0,64.0,5.0]|[0.16901408450704...|\n",
"|(2,[1],[1.0])|(5,[4],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|[11.0,33.0,136.0,...|[0.14084507042253...|\n",
"|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|[68.0,52.0,116.0,...|[0.94366197183098...|\n",
"|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|[33.0,33.0,33.0,0.0]|[0.45070422535211...|\n",
"|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|[23.0,30.0,30.0,2.0]|[0.30985915492957...|\n",
"|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|[41.0,39.0,78.0,1...|[0.56338028169014...|\n",
"|(2,[1],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|[45.0,22.0,19.0,4.0]|[0.61971830985915...|\n",
"|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|[38.0,35.0,76.0,1...|[0.52112676056338...|\n",
"|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|[45.0,59.0,166.0,...|[0.61971830985915...|\n",
"|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|[68.0,41.0,72.0,2...|[0.94366197183098...|\n",
"+-------------+-------------+-------------+-------------+--------------------+--------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"# Minmax scaling\n",
"from pyspark.ml.feature import VectorAssembler, MinMaxScaler\n",
"\n",
"\n",
"numeric_assembler = VectorAssembler(inputCols=numeric_cols, outputCol='numeric')\n",
"assembled_df = numeric_assembler.transform(encoded_df)\n",
"\n",
"scaler = MinMaxScaler(inputCol='numeric', outputCol='numericScaled')\n",
"scaled_df = scaler.fit(assembled_df).transform(assembled_df)\n",
"show_split(scaled_df, 8, 10)"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "5637a50b",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:47.886908Z",
"iopub.status.busy": "2022-04-18T16:14:47.884252Z",
"iopub.status.idle": "2022-04-18T16:14:48.478609Z",
"shell.execute_reply": "2022-04-18T16:14:48.477673Z",
"shell.execute_reply.started": "2022-04-18T14:31:44.487799Z"
},
"papermill": {
"duration": 0.978968,
"end_time": "2022-04-18T16:14:48.478819",
"exception": false,
"start_time": "2022-04-18T16:14:47.499851",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"|region|tenure|age|income|marital| ed|employ|retire|gender|custcat|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"| 2| 13| 44| 64| 1| 4| 5| 0| 0| A|\n",
"| 3| 11| 33| 136| 1| 5| 5| 0| 0| D|\n",
"| 3| 68| 52| 116| 1| 1| 29| 0| 1| C|\n",
"| 2| 33| 33| 33| 0| 2| 0| 0| 1| A|\n",
"| 2| 23| 30| 30| 1| 1| 2| 0| 0| C|\n",
"| 2| 41| 39| 78| 0| 2| 16| 0| 1| C|\n",
"| 3| 45| 22| 19| 1| 2| 4| 0| 1| B|\n",
"| 2| 38| 35| 76| 0| 2| 10| 0| 0| D|\n",
"| 3| 45| 59| 166| 1| 4| 31| 0| 0| C|\n",
"| 1| 68| 41| 72| 1| 1| 22| 0| 0| B|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"only showing top 10 rows\n",
"\n",
"+----------+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"|region_idx|marital_idx|ed_idx|retire_idx|gender_idx| region_vec| marital_vec| ed_vec| retire_vec| gender_vec|\n",
"+----------+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"| 1.0| 1.0| 1.0| 0.0| 1.0|(3,[1],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 1.0| 4.0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[4],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 1.0| 3.0| 0.0| 0.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 1.0| 3.0| 0.0| 1.0|(3,[1],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 0.0| 0.0| 0.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 1.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 1.0| 1.0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 2.0| 1.0| 3.0| 0.0| 1.0|(3,[2],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"+----------+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"only showing top 10 rows\n",
"\n",
"+--------------------+--------------------+--------------------+\n",
"| numeric| numericScaled| feature|\n",
"+--------------------+--------------------+--------------------+\n",
"|[13.0,44.0,64.0,5.0]|[0.16901408450704...|(18,[1,4,6,10,13,...|\n",
"|[11.0,33.0,136.0,...|[0.14084507042253...|(18,[0,4,9,10,13,...|\n",
"|[68.0,52.0,116.0,...|[0.94366197183098...|(18,[0,4,8,10,12,...|\n",
"|[33.0,33.0,33.0,0.0]|[0.45070422535211...|(18,[1,3,5,10,12,...|\n",
"|[23.0,30.0,30.0,2.0]|[0.30985915492957...|(18,[1,4,8,10,13,...|\n",
"|[41.0,39.0,78.0,1...|[0.56338028169014...|(18,[1,3,5,10,12,...|\n",
"|[45.0,22.0,19.0,4.0]|[0.61971830985915...|(18,[0,4,5,10,12,...|\n",
"|[38.0,35.0,76.0,1...|[0.52112676056338...|(18,[1,3,5,10,13,...|\n",
"|[45.0,59.0,166.0,...|[0.61971830985915...|(18,[0,4,6,10,13,...|\n",
"|[68.0,41.0,72.0,2...|[0.94366197183098...|(18,[2,4,8,10,13,...|\n",
"+--------------------+--------------------+--------------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"# Gather essential cols to build the features\n",
"feature_cols = ['region_vec', 'marital_vec', 'ed_vec', 'retire_vec', 'gender_vec', 'numericScaled']\n",
"features_assembler = VectorAssembler(inputCols=feature_cols, outputCol='feature')\n",
"features_df = features_assembler.transform(scaled_df)\n",
"show_split(features_df, 10, 10)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "546f2131",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:49.201929Z",
"iopub.status.busy": "2022-04-18T16:14:49.198192Z",
"iopub.status.idle": "2022-04-18T16:14:49.875920Z",
"shell.execute_reply": "2022-04-18T16:14:49.875231Z",
"shell.execute_reply.started": "2022-04-18T14:31:45.087771Z"
},
"papermill": {
"duration": 1.00758,
"end_time": "2022-04-18T16:14:49.876088",
"exception": false,
"start_time": "2022-04-18T16:14:48.868508",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"|region|tenure|age|income|marital| ed|employ|retire|gender|custcat|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"| 2| 13| 44| 64| 1| 4| 5| 0| 0| A|\n",
"| 3| 11| 33| 136| 1| 5| 5| 0| 0| D|\n",
"| 3| 68| 52| 116| 1| 1| 29| 0| 1| C|\n",
"| 2| 33| 33| 33| 0| 2| 0| 0| 1| A|\n",
"| 2| 23| 30| 30| 1| 1| 2| 0| 0| C|\n",
"| 2| 41| 39| 78| 0| 2| 16| 0| 1| C|\n",
"| 3| 45| 22| 19| 1| 2| 4| 0| 1| B|\n",
"| 2| 38| 35| 76| 0| 2| 10| 0| 0| D|\n",
"| 3| 45| 59| 166| 1| 4| 31| 0| 0| C|\n",
"| 1| 68| 41| 72| 1| 1| 22| 0| 0| B|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"only showing top 10 rows\n",
"\n",
"+----------+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"|region_idx|marital_idx|ed_idx|retire_idx|gender_idx| region_vec| marital_vec| ed_vec| retire_vec| gender_vec|\n",
"+----------+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"| 1.0| 1.0| 1.0| 0.0| 1.0|(3,[1],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 1.0| 4.0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[4],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 1.0| 3.0| 0.0| 0.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 1.0| 3.0| 0.0| 1.0|(3,[1],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 0.0| 0.0| 0.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 1.0|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 0.0| 1.0| 1.0| 0.0| 1.0|(3,[0],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"| 2.0| 1.0| 3.0| 0.0| 1.0|(3,[2],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|(2,[1],[1.0])|\n",
"+----------+-----------+------+----------+----------+-------------+-------------+-------------+-------------+-------------+\n",
"only showing top 10 rows\n",
"\n",
"+--------------------+--------------------+--------------------+-----------+-------------+\n",
"| numeric| numericScaled| feature|custcat_idx| custcat_vec|\n",
"+--------------------+--------------------+--------------------+-----------+-------------+\n",
"|[13.0,44.0,64.0,5.0]|[0.16901408450704...|(18,[1,4,6,10,13,...| 1.0|(4,[1],[1.0])|\n",
"|[11.0,33.0,136.0,...|[0.14084507042253...|(18,[0,4,9,10,13,...| 2.0|(4,[2],[1.0])|\n",
"|[68.0,52.0,116.0,...|[0.94366197183098...|(18,[0,4,8,10,12,...| 0.0|(4,[0],[1.0])|\n",
"|[33.0,33.0,33.0,0.0]|[0.45070422535211...|(18,[1,3,5,10,12,...| 1.0|(4,[1],[1.0])|\n",
"|[23.0,30.0,30.0,2.0]|[0.30985915492957...|(18,[1,4,8,10,13,...| 0.0|(4,[0],[1.0])|\n",
"|[41.0,39.0,78.0,1...|[0.56338028169014...|(18,[1,3,5,10,12,...| 0.0|(4,[0],[1.0])|\n",
"|[45.0,22.0,19.0,4.0]|[0.61971830985915...|(18,[0,4,5,10,12,...| 3.0|(4,[3],[1.0])|\n",
"|[38.0,35.0,76.0,1...|[0.52112676056338...|(18,[1,3,5,10,13,...| 2.0|(4,[2],[1.0])|\n",
"|[45.0,59.0,166.0,...|[0.61971830985915...|(18,[0,4,6,10,13,...| 0.0|(4,[0],[1.0])|\n",
"|[68.0,41.0,72.0,2...|[0.94366197183098...|(18,[2,4,8,10,13,...| 3.0|(4,[3],[1.0])|\n",
"+--------------------+--------------------+--------------------+-----------+-------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"label_indexer = StringIndexer(inputCol='custcat', outputCol='custcat_idx')\n",
"train_df = label_indexer.fit(features_df).transform(features_df)\n",
"label_encoder = OneHotEncoder(inputCol='custcat_idx', outputCol='custcat_vec', dropLast=False)\n",
"train_df = label_encoder.fit(train_df).transform(train_df)\n",
"show_split(train_df, 10, 10)"
]
},
{
"cell_type": "markdown",
"id": "1e5ff53a",
"metadata": {
"papermill": {
"duration": 0.318151,
"end_time": "2022-04-18T16:14:50.519819",
"exception": false,
"start_time": "2022-04-18T16:14:50.201668",
"status": "completed"
},
"tags": []
},
"source": [
"## Machine learning pipeline"
]
},
{
"cell_type": "markdown",
"id": "e21c0aec",
"metadata": {
"papermill": {
"duration": 0.331986,
"end_time": "2022-04-18T16:14:51.177806",
"exception": false,
"start_time": "2022-04-18T16:14:50.845820",
"status": "completed"
},
"tags": []
},
"source": [
"We can combine all transformers and estimators into a single pipeline"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "a37657ce",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:51.821500Z",
"iopub.status.busy": "2022-04-18T16:14:51.819621Z",
"iopub.status.idle": "2022-04-18T16:14:53.353205Z",
"shell.execute_reply": "2022-04-18T16:14:53.352302Z",
"shell.execute_reply.started": "2022-04-18T14:31:45.775228Z"
},
"papermill": {
"duration": 1.85841,
"end_time": "2022-04-18T16:14:53.353436",
"exception": false,
"start_time": "2022-04-18T16:14:51.495026",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"|region|tenure|age|income|marital| ed|employ|retire|gender|custcat|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"| 2| 13| 44| 64| 1| 4| 5| 0| 0| A|\n",
"| 3| 11| 33| 136| 1| 5| 5| 0| 0| D|\n",
"| 3| 68| 52| 116| 1| 1| 29| 0| 1| C|\n",
"| 2| 33| 33| 33| 0| 2| 0| 0| 1| A|\n",
"| 2| 23| 30| 30| 1| 1| 2| 0| 0| C|\n",
"| 2| 41| 39| 78| 0| 2| 16| 0| 1| C|\n",
"| 3| 45| 22| 19| 1| 2| 4| 0| 1| B|\n",
"| 2| 38| 35| 76| 0| 2| 10| 0| 0| D|\n",
"| 3| 45| 59| 166| 1| 4| 31| 0| 0| C|\n",
"| 1| 68| 41| 72| 1| 1| 22| 0| 0| B|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"only showing top 10 rows\n",
"\n",
"+----------+-----------+------+----------+----------+--------------------+-------------+-------------+-------------+-------------+\n",
"|region_idx|marital_idx|ed_idx|retire_idx|gender_idx| numeric| region_vec| marital_vec| ed_vec| retire_vec|\n",
"+----------+-----------+------+----------+----------+--------------------+-------------+-------------+-------------+-------------+\n",
"| 1.0| 1.0| 1.0| 0.0| 1.0|[13.0,44.0,64.0,5.0]|(3,[1],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 4.0| 0.0| 1.0|[11.0,33.0,136.0,...|(3,[0],[1.0])|(2,[1],[1.0])|(5,[4],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 3.0| 0.0| 0.0|[68.0,52.0,116.0,...|(3,[0],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|[33.0,33.0,33.0,0.0]|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 1.0| 3.0| 0.0| 1.0|[23.0,30.0,30.0,2.0]|(3,[1],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|[41.0,39.0,78.0,1...|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 0.0| 0.0| 0.0|[45.0,22.0,19.0,4.0]|(3,[0],[1.0])|(2,[1],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 1.0|[38.0,35.0,76.0,1...|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 1.0| 0.0| 1.0|[45.0,59.0,166.0,...|(3,[0],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|\n",
"| 2.0| 1.0| 3.0| 0.0| 1.0|[68.0,41.0,72.0,2...|(3,[2],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|\n",
"+----------+-----------+------+----------+----------+--------------------+-------------+-------------+-------------+-------------+\n",
"only showing top 10 rows\n",
"\n",
"+-------------+--------------------+--------------------+-----------+-------------+\n",
"| gender_vec| numericScaled| feature|custcat_idx| custcat_vec|\n",
"+-------------+--------------------+--------------------+-----------+-------------+\n",
"|(2,[1],[1.0])|[0.16901408450704...|(18,[1,4,6,10,13,...| 1.0|(4,[1],[1.0])|\n",
"|(2,[1],[1.0])|[0.14084507042253...|(18,[0,4,9,10,13,...| 2.0|(4,[2],[1.0])|\n",
"|(2,[0],[1.0])|[0.94366197183098...|(18,[0,4,8,10,12,...| 0.0|(4,[0],[1.0])|\n",
"|(2,[0],[1.0])|[0.45070422535211...|(18,[1,3,5,10,12,...| 1.0|(4,[1],[1.0])|\n",
"|(2,[1],[1.0])|[0.30985915492957...|(18,[1,4,8,10,13,...| 0.0|(4,[0],[1.0])|\n",
"|(2,[0],[1.0])|[0.56338028169014...|(18,[1,3,5,10,12,...| 0.0|(4,[0],[1.0])|\n",
"|(2,[0],[1.0])|[0.61971830985915...|(18,[0,4,5,10,12,...| 3.0|(4,[3],[1.0])|\n",
"|(2,[1],[1.0])|[0.52112676056338...|(18,[1,3,5,10,13,...| 2.0|(4,[2],[1.0])|\n",
"|(2,[1],[1.0])|[0.61971830985915...|(18,[0,4,6,10,13,...| 0.0|(4,[0],[1.0])|\n",
"|(2,[1],[1.0])|[0.94366197183098...|(18,[2,4,8,10,13,...| 3.0|(4,[3],[1.0])|\n",
"+-------------+--------------------+--------------------+-----------+-------------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"from pyspark.ml import Pipeline\n",
"\n",
"ml_pipeline = Pipeline(stages=[indexer, numeric_assembler, encoder, scaler, features_assembler, label_indexer, label_encoder])\n",
"model = ml_pipeline.fit(cust_df)\n",
"train_df = model.transform(cust_df)\n",
"show_split(train_df, 10, 10)"
]
},
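{
"cell_type": "markdown",
"id": "3f9a1c7e",
"metadata": {},
"source": [
"As a rough illustration of what `Pipeline.fit` does, the sketch below chains a `StringIndexer` and a `OneHotEncoder` on a toy DataFrame. The `color` column, the `toy_` variable names and the local `SparkSession` settings are assumptions made for this example only; they are not part of the customer dataset.\n",
"\n",
"```python\n",
"from pyspark.sql import SparkSession\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.ml.feature import StringIndexer, OneHotEncoder\n",
"\n",
"# Hypothetical local session and toy data, for illustration only\n",
"spark = SparkSession.builder.master('local[1]').appName('pipeline-sketch').getOrCreate()\n",
"toy_df = spark.createDataFrame([('red',), ('blue',), ('red',), ('green',)], ['color'])\n",
"\n",
"toy_indexer = StringIndexer(inputCol='color', outputCol='color_idx')\n",
"toy_encoder = OneHotEncoder(inputCols=['color_idx'], outputCols=['color_vec'], dropLast=False)\n",
"\n",
"# fit() fits each estimator in order and feeds the transformed data to the next stage\n",
"toy_model = Pipeline(stages=[toy_indexer, toy_encoder]).fit(toy_df)\n",
"toy_model.transform(toy_df).show()\n",
"```\n",
"\n",
"The fitted `toy_model` is a `PipelineModel`, so it can be reused to transform other DataFrames that have the same input column."
]
},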
{
"cell_type": "markdown",
"id": "99e4bc63",
"metadata": {
"papermill": {
"duration": 0.321398,
"end_time": "2022-04-18T16:14:54.014785",
"exception": false,
"start_time": "2022-04-18T16:14:53.693387",
"status": "completed"
},
"tags": []
},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"id": "0721f396",
"metadata": {
"papermill": {
"duration": 0.319905,
"end_time": "2022-04-18T16:14:54.656939",
"exception": false,
"start_time": "2022-04-18T16:14:54.337034",
"status": "completed"
},
"tags": []
},
"source": [
"Machine learning model is just an estimator in PySpark."
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "052d977e",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:14:55.318252Z",
"iopub.status.busy": "2022-04-18T16:14:55.317436Z",
"iopub.status.idle": "2022-04-18T16:14:59.798448Z",
"shell.execute_reply": "2022-04-18T16:14:59.797320Z",
"shell.execute_reply.started": "2022-04-18T14:31:47.124457Z"
},
"papermill": {
"duration": 4.812172,
"end_time": "2022-04-18T16:14:59.798793",
"exception": false,
"start_time": "2022-04-18T16:14:54.986621",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"|region|tenure|age|income|marital| ed|employ|retire|gender|custcat|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"| 2| 13| 44| 64| 1| 4| 5| 0| 0| A|\n",
"| 3| 11| 33| 136| 1| 5| 5| 0| 0| D|\n",
"| 3| 68| 52| 116| 1| 1| 29| 0| 1| C|\n",
"| 2| 33| 33| 33| 0| 2| 0| 0| 1| A|\n",
"| 2| 23| 30| 30| 1| 1| 2| 0| 0| C|\n",
"| 2| 41| 39| 78| 0| 2| 16| 0| 1| C|\n",
"| 3| 45| 22| 19| 1| 2| 4| 0| 1| B|\n",
"| 2| 38| 35| 76| 0| 2| 10| 0| 0| D|\n",
"| 3| 45| 59| 166| 1| 4| 31| 0| 0| C|\n",
"| 1| 68| 41| 72| 1| 1| 22| 0| 0| B|\n",
"+------+------+---+------+-------+---+------+------+------+-------+\n",
"only showing top 10 rows\n",
"\n",
"+----------+-----------+------+----------+----------+--------------------+-------------+-------------+-------------+-------------+\n",
"|region_idx|marital_idx|ed_idx|retire_idx|gender_idx| numeric| region_vec| marital_vec| ed_vec| retire_vec|\n",
"+----------+-----------+------+----------+----------+--------------------+-------------+-------------+-------------+-------------+\n",
"| 1.0| 1.0| 1.0| 0.0| 1.0|[13.0,44.0,64.0,5.0]|(3,[1],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 4.0| 0.0| 1.0|[11.0,33.0,136.0,...|(3,[0],[1.0])|(2,[1],[1.0])|(5,[4],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 3.0| 0.0| 0.0|[68.0,52.0,116.0,...|(3,[0],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|[33.0,33.0,33.0,0.0]|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 1.0| 3.0| 0.0| 1.0|[23.0,30.0,30.0,2.0]|(3,[1],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 0.0|[41.0,39.0,78.0,1...|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 0.0| 0.0| 0.0|[45.0,22.0,19.0,4.0]|(3,[0],[1.0])|(2,[1],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 1.0| 0.0| 0.0| 0.0| 1.0|[38.0,35.0,76.0,1...|(3,[1],[1.0])|(2,[0],[1.0])|(5,[0],[1.0])|(2,[0],[1.0])|\n",
"| 0.0| 1.0| 1.0| 0.0| 1.0|[45.0,59.0,166.0,...|(3,[0],[1.0])|(2,[1],[1.0])|(5,[1],[1.0])|(2,[0],[1.0])|\n",
"| 2.0| 1.0| 3.0| 0.0| 1.0|[68.0,41.0,72.0,2...|(3,[2],[1.0])|(2,[1],[1.0])|(5,[3],[1.0])|(2,[0],[1.0])|\n",
"+----------+-----------+------+----------+----------+--------------------+-------------+-------------+-------------+-------------+\n",
"only showing top 10 rows\n",
"\n",
"+-------------+--------------------+--------------------+-----------+-------------+--------------------+--------------------+----------+\n",
"| gender_vec| numericScaled| feature|custcat_idx| custcat_vec| rawPrediction| probability|prediction|\n",
"+-------------+--------------------+--------------------+-----------+-------------+--------------------+--------------------+----------+\n",
"|(2,[1],[1.0])|[0.16901408450704...|(18,[1,4,6,10,13,...| 1.0|(4,[1],[1.0])|[2.69883097174239...|[0.13494154858711...| 2.0|\n",
"|(2,[1],[1.0])|[0.14084507042253...|(18,[0,4,9,10,13,...| 2.0|(4,[2],[1.0])|[1.25867963894834...|[0.06293398194741...| 2.0|\n",
"|(2,[0],[1.0])|[0.94366197183098...|(18,[0,4,8,10,12,...| 0.0|(4,[0],[1.0])|[9.71293038529637...|[0.48564651926481...| 0.0|\n",
"|(2,[0],[1.0])|[0.45070422535211...|(18,[1,3,5,10,12,...| 1.0|(4,[1],[1.0])|[5.48212592252259...|[0.27410629612612...| 1.0|\n",
"|(2,[1],[1.0])|[0.30985915492957...|(18,[1,4,8,10,13,...| 0.0|(4,[0],[1.0])|[9.14844309573381...|[0.45742215478669...| 0.0|\n",
"|(2,[0],[1.0])|[0.56338028169014...|(18,[1,3,5,10,12,...| 0.0|(4,[0],[1.0])|[6.28417748455353...|[0.31420887422767...| 0.0|\n",
"|(2,[0],[1.0])|[0.61971830985915...|(18,[0,4,5,10,12,...| 3.0|(4,[3],[1.0])|[5.63792417136418...|[0.28189620856820...| 1.0|\n",
"|(2,[1],[1.0])|[0.52112676056338...|(18,[1,3,5,10,13,...| 2.0|(4,[2],[1.0])|[5.34298824100371...|[0.26714941205018...| 1.0|\n",
"|(2,[1],[1.0])|[0.61971830985915...|(18,[0,4,6,10,13,...| 0.0|(4,[0],[1.0])|[5.89253225861025...|[0.29462661293051...| 2.0|\n",
"|(2,[1],[1.0])|[0.94366197183098...|(18,[2,4,8,10,13,...| 3.0|(4,[3],[1.0])|[9.06482369427615...|[0.45324118471380...| 0.0|\n",
"+-------------+--------------------+--------------------+-----------+-------------+--------------------+--------------------+----------+\n",
"only showing top 10 rows\n",
"\n"
]
}
],
"source": [
"from pyspark.ml.classification import RandomForestClassifier\n",
"\n",
"lr = RandomForestClassifier(featuresCol=\"feature\", labelCol=\"custcat_idx\", predictionCol=\"prediction\")\n",
"ml_pipeline = Pipeline(stages=[indexer, numeric_assembler, encoder, scaler, features_assembler, \n",
" label_indexer, label_encoder, lr])\n",
"ml_model = ml_pipeline.fit(cust_df)\n",
"result_df = ml_model.transform(cust_df)\n",
"show_split(result_df, 10, 10)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "f190a7f3",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:00.485043Z",
"iopub.status.busy": "2022-04-18T16:15:00.484348Z",
"iopub.status.idle": "2022-04-18T16:15:01.563000Z",
"shell.execute_reply": "2022-04-18T16:15:01.562449Z",
"shell.execute_reply.started": "2022-04-18T14:31:51.446198Z"
},
"papermill": {
"duration": 1.406319,
"end_time": "2022-04-18T16:15:01.563142",
"exception": false,
"start_time": "2022-04-18T16:15:00.156823",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1800x900 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_columns_dist(result_df, ['custcat_idx', 'prediction'], 1, 2).show()"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "2b3c2176",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:02.247826Z",
"iopub.status.busy": "2022-04-18T16:15:02.245614Z",
"iopub.status.idle": "2022-04-18T16:15:03.218265Z",
"shell.execute_reply": "2022-04-18T16:15:03.217304Z",
"shell.execute_reply.started": "2022-04-18T14:31:52.390144Z"
},
"papermill": {
"duration": 1.308965,
"end_time": "2022-04-18T16:15:03.218502",
"exception": false,
"start_time": "2022-04-18T16:15:01.909537",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-----------+---+---+---+---+\n",
"|custcat_idx|0.0|1.0|2.0|3.0|\n",
"+-----------+---+---+---+---+\n",
"| 0.0|153| 87| 28| 13|\n",
"| 1.0| 48|184| 28| 6|\n",
"| 2.0| 54| 53|117| 12|\n",
"| 3.0| 67| 50| 46| 54|\n",
"+-----------+---+---+---+---+\n",
"\n"
]
}
],
"source": [
"# evaluate\n",
"confusion_matrix = result_df.groupBy('custcat_idx').pivot('prediction').count().orderBy('custcat_idx')\n",
"confusion_matrix.show()"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "baadaeb2",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:03.879563Z",
"iopub.status.busy": "2022-04-18T16:15:03.878781Z",
"iopub.status.idle": "2022-04-18T16:15:04.370169Z",
"shell.execute_reply": "2022-04-18T16:15:04.369583Z",
"shell.execute_reply.started": "2022-04-18T14:31:53.280888Z"
},
"papermill": {
"duration": 0.819848,
"end_time": "2022-04-18T16:15:04.370329",
"exception": false,
"start_time": "2022-04-18T16:15:03.550481",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Precision\n",
"0.0 0.475155\n",
"1.0 0.491979\n",
"2.0 0.534247\n",
"3.0 0.635294\n",
"dtype: float64\n",
"Recall\n",
"custcat_idx\n",
"0.0 0.544484\n",
"1.0 0.691729\n",
"2.0 0.495763\n",
"3.0 0.248848\n",
"dtype: float64\n"
]
}
],
"source": [
"# We can calculate precision and recall from confusion matrix or directly with pyspark\n",
"cmat = confusion_matrix.toPandas()\n",
"cmat.index = cmat['custcat_idx']\n",
"cmat = cmat.drop('custcat_idx', axis=1)\n",
"correct = np.diagonal(cmat)\n",
"precision = correct / np.sum(cmat, axis=0)\n",
"recall = correct / np.sum(cmat, axis=1)\n",
"print('Precision')\n",
"print(precision)\n",
"print('Recall')\n",
"print(recall)"
]
},
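{
"cell_type": "markdown",
"id": "b82d4e16",
"metadata": {},
"source": [
"To make the axis convention explicit, here is a small NumPy-only sketch that redoes the calculation on the confusion matrix printed earlier (rows are true labels, columns are predictions). The array literal is copied from that output; the variable names are illustrative.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Confusion matrix copied from the output above: rows = true label, columns = prediction\n",
"cm = np.array([\n",
"    [153,  87,  28, 13],\n",
"    [ 48, 184,  28,  6],\n",
"    [ 54,  53, 117, 12],\n",
"    [ 67,  50,  46, 54],\n",
"])\n",
"\n",
"correct = np.diag(cm)\n",
"precision = correct / cm.sum(axis=0)  # per predicted class (column sums)\n",
"recall = correct / cm.sum(axis=1)     # per true class (row sums)\n",
"accuracy = correct.sum() / cm.sum()   # correct predictions over all rows\n",
"\n",
"print(precision.round(4))\n",
"print(recall.round(4))\n",
"print(accuracy)\n",
"```\n",
"\n",
"The per-class values should agree with the Precision and Recall printed by the cell above."
]
},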
{
"cell_type": "code",
"execution_count": 74,
"id": "c3c0cfaf",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:05.033112Z",
"iopub.status.busy": "2022-04-18T16:15:05.032362Z",
"iopub.status.idle": "2022-04-18T16:15:07.135943Z",
"shell.execute_reply": "2022-04-18T16:15:07.135082Z",
"shell.execute_reply.started": "2022-04-18T14:31:53.781753Z"
},
"papermill": {
"duration": 2.43491,
"end_time": "2022-04-18T16:15:07.136188",
"exception": false,
"start_time": "2022-04-18T16:15:04.701278",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Precision\n",
"0 0.4751552795031056\n",
"1 0.4919786096256685\n",
"2 0.5342465753424658\n",
"3 0.6352941176470588\n",
"Recall\n",
"0 0.5444839857651246\n",
"1 0.6917293233082706\n",
"2 0.4957627118644068\n",
"3 0.2488479262672811\n"
]
}
],
"source": [
"# This code is legacy\n",
"from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
"\n",
"evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='custcat_idx')\n",
"\n",
"print('Precision')\n",
"for i in range(4):\n",
" print(i, evaluator.evaluate(result_df, {evaluator.metricName: \"precisionByLabel\", evaluator.metricLabel: i}))\n",
" \n",
"print('Recall')\n",
"for i in range(4):\n",
" print(i, evaluator.evaluate(result_df, {evaluator.metricName: \"recallByLabel\", evaluator.metricLabel: i}))"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "8065caf9",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:07.823799Z",
"iopub.status.busy": "2022-04-18T16:15:07.822971Z",
"iopub.status.idle": "2022-04-18T16:15:08.031673Z",
"shell.execute_reply": "2022-04-18T16:15:08.030639Z",
"shell.execute_reply.started": "2022-04-18T14:31:55.494138Z"
},
"papermill": {
"duration": 0.542437,
"end_time": "2022-04-18T16:15:08.031948",
"exception": false,
"start_time": "2022-04-18T16:15:07.489511",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"0.49452109250342424"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluator.evaluate(result_df, {evaluator.metricName: \"f1\"})"
]
},
{
"cell_type": "markdown",
"id": "f3aeff89",
"metadata": {
"papermill": {
"duration": 0.331308,
"end_time": "2022-04-18T16:15:08.772817",
"exception": false,
"start_time": "2022-04-18T16:15:08.441509",
"status": "completed"
},
"tags": []
},
"source": [
"## Validation scheme"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "13fdb1e0",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:09.445137Z",
"iopub.status.busy": "2022-04-18T16:15:09.444031Z",
"iopub.status.idle": "2022-04-18T16:15:09.447250Z",
"shell.execute_reply": "2022-04-18T16:15:09.447753Z",
"shell.execute_reply.started": "2022-04-18T14:31:55.696320Z"
},
"papermill": {
"duration": 0.336593,
"end_time": "2022-04-18T16:15:09.447959",
"exception": false,
"start_time": "2022-04-18T16:15:09.111366",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"def accuracy(model, df):\n",
" df = model.transform(df)\n",
" return evaluator.evaluate(result_df, {evaluator.metricName: \"accuracy\"}) "
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "d07c56aa",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:10.141968Z",
"iopub.status.busy": "2022-04-18T16:15:10.140947Z",
"iopub.status.idle": "2022-04-18T16:15:15.119456Z",
"shell.execute_reply": "2022-04-18T16:15:15.118697Z",
"shell.execute_reply.started": "2022-04-18T14:31:55.704443Z"
},
"papermill": {
"duration": 5.316554,
"end_time": "2022-04-18T16:15:15.119660",
"exception": false,
"start_time": "2022-04-18T16:15:09.803106",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"0.508"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# train test split\n",
"train_df, test_df = cust_df.randomSplit([0.7, 0.3], seed=19032000)\n",
"model = ml_pipeline.fit(train_df)\n",
"accuracy(model, test_df)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "bfe349b3",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:15.814787Z",
"iopub.status.busy": "2022-04-18T16:15:15.814085Z",
"iopub.status.idle": "2022-04-18T16:15:27.219904Z",
"shell.execute_reply": "2022-04-18T16:15:27.220768Z",
"shell.execute_reply.started": "2022-04-18T14:31:58.017022Z"
},
"papermill": {
"duration": 11.749372,
"end_time": "2022-04-18T16:15:27.221109",
"exception": false,
"start_time": "2022-04-18T16:15:15.471737",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"[0.3707128486580884]"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cross validation\n",
"# We can use cross validation to do grid search with built in params\n",
"from pyspark.ml.tuning import CrossValidator, ParamGridBuilder\n",
"\n",
"params_grid = ParamGridBuilder().build()\n",
"evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='custcat_idx', metricName=\"accuracy\")\n",
"cv = CrossValidator(estimator=ml_pipeline, estimatorParamMaps=params_grid, evaluator=evaluator, numFolds=4, seed=19032000)\n",
"\n",
"cv_model = cv.fit(cust_df)\n",
"cv_model.avgMetrics"
]
},
{
"cell_type": "markdown",
"id": "f5971936",
"metadata": {
"papermill": {
"duration": 0.336217,
"end_time": "2022-04-18T16:15:27.894379",
"exception": false,
"start_time": "2022-04-18T16:15:27.558162",
"status": "completed"
},
"tags": []
},
"source": [
"## Custom transformer/estimator"
]
},
{
"cell_type": "markdown",
"id": "31738d44",
"metadata": {
"papermill": {
"duration": 0.330387,
"end_time": "2022-04-18T16:15:28.561130",
"exception": false,
"start_time": "2022-04-18T16:15:28.230743",
"status": "completed"
},
"tags": []
},
"source": [
"Behind transformers and estimators, PySpark has the concept of `Param`/`Params`, self-documenting attributes that govern how a transformer or estimator behaves. When creating custom transformers/estimators, we create their `Param` first, and then use them in `transform()`-/`fit()`-like instance attributes. PySpark provides standard Params for frequent use cases in the `pyspark.ml.param.shared` module.\n",
"\n",
"Below is example of a custom transformer and estimator with detailed explanation in comment."
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "7949f689",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:29.237555Z",
"iopub.status.busy": "2022-04-18T16:15:29.236489Z",
"iopub.status.idle": "2022-04-18T16:15:30.011066Z",
"shell.execute_reply": "2022-04-18T16:15:30.011609Z",
"shell.execute_reply.started": "2022-04-18T14:32:08.394183Z"
},
"papermill": {
"duration": 1.116286,
"end_time": "2022-04-18T16:15:30.011798",
"exception": false,
"start_time": "2022-04-18T16:15:28.895512",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----+----+-----+\n",
"| a| b| c|\n",
"+----+----+-----+\n",
"| 1| 4.0| GFG1|\n",
"| 5| 1.2| GFC2|\n",
"| 5|6.25| null|\n",
"| 8|7.11|CLC22|\n",
"|null|0.22| SSv4|\n",
"|null|3.48| MNM0|\n",
"| 0|null| TNT2|\n",
"+----+----+-----+\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"# We will use a meaningless and simple data to test our transformer and estimator\n",
"temp_df = df = spark.createDataFrame([\n",
" (1, 4., 'GFG1'),\n",
" (5, 1.2, 'GFC2'),\n",
" (5, 6.25, None),\n",
" (8, 7.11, 'CLC22'),\n",
" (None, 0.22, 'SSv4'),\n",
" (None, 3.48, 'MNM0'),\n",
" (0, None, 'TNT2'),\n",
"], schema='a long, b double, c string')\n",
"\n",
"temp_df.show()"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "2fdfcd8f",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:30.702079Z",
"iopub.status.busy": "2022-04-18T16:15:30.701177Z",
"iopub.status.idle": "2022-04-18T16:15:30.703946Z",
"shell.execute_reply": "2022-04-18T16:15:30.704539Z",
"shell.execute_reply.started": "2022-04-18T15:39:35.267808Z"
},
"papermill": {
"duration": 0.356356,
"end_time": "2022-04-18T16:15:30.704716",
"exception": false,
"start_time": "2022-04-18T16:15:30.348360",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from pyspark.ml import Transformer\n",
"from pyspark.ml.param import Param, Params, TypeConverters\n",
"from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasInputCols, HasOutputCols\n",
"from pyspark import keyword_only\n",
"\n",
"\n",
"# Commonly used Params are defined in special classes called Mixin under the pyspark.ml.param.shared module.\n",
"# A Mixin will define the param and the getter for us.\n",
"# We will learn how to define a Mixin later\n",
"class ConstantNullFiller(Transformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols):\n",
" \n",
" # A custom param consists of\n",
" # - parent, which carries the value of the transformer once the transformer is instantiated.\n",
" # Every custom Param we create needs to have Params._dummy() as a parent; this ensures that \n",
" # PySpark will be able to copy and change the Params\n",
" # - name, which is the name of our Param. By convention, we set it to the same name as our Param.\n",
" # - doc, which is the documentation of our Param. This allows us to embed documentation for our Param \n",
" # when the transformer will be used.\n",
" # - typeConverter, which governs the type of the Param. This provides a standardized way to convert an \n",
" # input value to the right type. It also gives a relevant error message if, for example, you expect\n",
" # floating-point number, but the user of the transformer provides a string.\n",
" filler = Param(Params._dummy(), 'filler', 'Fill null(s) with this value', TypeConverters.toFloat)\n",
" \n",
" # As the name suggest keyword_only forces keyword arguments\n",
" # The keyword_only decorator provides the_input_kwargs attribute containing a dictionary of the arguments \n",
" # provided to setParams().\n",
" # The params that was not set manually by the programmer will not exist in the dictionary (instead of None)\n",
" @keyword_only\n",
" def __init__(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, filler=None):\n",
" # Call initialization of parent classes, including Mixin(s)\n",
" super(ConstantNullFiller, self).__init__()\n",
" # Initialize param. The other params do not need this because they were already initialized in parent classes\n",
" self._setDefault(filler=None)\n",
" # Set the Params with arguments pass\n",
" self.setParams(**self._input_kwargs)\n",
" \n",
" # Getter of filler Param.\n",
" # Other Params do not need getter as they were implemented in parent classes\n",
" def getFiller(self):\n",
" return self.getOrDefault(self.filler)\n",
" \n",
" # Based on the design of every PySpark transformer we have used so far, the simplest way to create \n",
" # setters is as follows: we first create a general method, setParams(), that allows us to change \n",
" # multiple parameters passed as keyword arguments. Then, creating the setter for any other Param \n",
" # will simply call setParams() with the relevant keyword argument.\n",
" @keyword_only\n",
" def setParams(self, inputCol=None, outputCol=None, inputCols=None, outputCols=None, filler=None):\n",
" # We finally use the _set() method provided by the Transformer class to update every Params\n",
" return self._set(**self._input_kwargs)\n",
" \n",
" def setFiller(self, value):\n",
" return self.setParams(filler=value)\n",
" \n",
" def setInputCol(self, value):\n",
" return self.setParams(inputCol=value)\n",
" \n",
" def setOutputCol(self, value):\n",
" return self.setParams(outputCol=value)\n",
" \n",
" def setInputCols(self, value):\n",
" return self.setParams(inputCols=value)\n",
" \n",
" def setOutputCols(self, value):\n",
" return self.setParams(outputCols=value)\n",
" \n",
" # Transformer class has implemented transform() method that allows for an optional argument, params, \n",
" # in case we want to pass a Param map at transformation time\n",
" def _transform(self, dataset):\n",
" # Check params\n",
" if self.isSet('inputCol') and self.isSet('inputCols'):\n",
" raise ValueError('Only \"inputCol\" or \"inputCols\" should be set')\n",
" if not (self.isSet('inputCol') or self.isSet('inputCols')):\n",
" raise ValueError('At least \"inputCol\" or \"inputCols\" must be set')\n",
" if not (self.isSet('outputCol') or self.isSet('outputCols')):\n",
" raise ValueError('At least \"outputCol\" or \"outputCols\" must be set')\n",
" if self.isSet('inputCols') and (len(self.getInputCols()) != len(self.getOutputCols())):\n",
" raise ValueError('Length of \"iputCols\" and \"outputCols\" must be exactly the same')\n",
" # If inputCols == outputCols, no need to try to create new column\n",
" input_cols = self.getInputCols() if self.isSet('inputCols') else [self.getInputCol()]\n",
" output_cols = self.getOutputCols() if self.isSet('outputCols') else [self.getOutputCol()]\n",
" for input_col, output_col in zip(input_cols, output_cols):\n",
" if input_col == output_col:\n",
" continue\n",
" dataset = dataset.withColumn(output_col, dataset[input_col])\n",
" return dataset.fillna(self.getFiller(), output_cols)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "19aa7888",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:31.374166Z",
"iopub.status.busy": "2022-04-18T16:15:31.373155Z",
"iopub.status.idle": "2022-04-18T16:15:31.566804Z",
"shell.execute_reply": "2022-04-18T16:15:31.565919Z",
"shell.execute_reply.started": "2022-04-18T15:25:18.758968Z"
},
"papermill": {
"duration": 0.530875,
"end_time": "2022-04-18T16:15:31.567036",
"exception": false,
"start_time": "2022-04-18T16:15:31.036161",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---+----+-----+-----+\n",
"| a| b| c|b_new|\n",
"+---+----+-----+-----+\n",
"| 1| 4.0| GFG1| 4.0|\n",
"| 5| 1.2| GFC2| 1.2|\n",
"| 5|6.25| null| 6.25|\n",
"| 8|7.11|CLC22| 7.11|\n",
"| 0|0.22| SSv4| 0.22|\n",
"| 0|3.48| MNM0| 3.48|\n",
"| 0|null| TNT2| 0.0|\n",
"+---+----+-----+-----+\n",
"\n"
]
}
],
"source": [
"const_imputer = ConstantNullFiller(filler=0, inputCols=['a', 'b'], outputCols=['a', 'b_new'])\n",
"const_imputer.transform(temp_df).show()"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "341fc858",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:32.309205Z",
"iopub.status.busy": "2022-04-18T16:15:32.308160Z",
"iopub.status.idle": "2022-04-18T16:15:32.324700Z",
"shell.execute_reply": "2022-04-18T16:15:32.325259Z",
"shell.execute_reply.started": "2022-04-18T16:04:24.142067Z"
},
"papermill": {
"duration": 0.361519,
"end_time": "2022-04-18T16:15:32.325445",
"exception": false,
"start_time": "2022-04-18T16:15:31.963926",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from pyspark.ml import Model, Estimator\n",
"\n",
"\n",
"# Our own Mixin. Define the params for transformer/estimator/model\n",
"# We can modulize this for code reuse. And as we separated params and the main class, we can handle them separately.\n",
"class HasStdScorerParams(HasInputCol, HasOutputCol):\n",
" \n",
" fillNull = Param(Params._dummy(), 'fillNull', 'Fill null(s) with 0', TypeConverters.toBoolean)\n",
" \n",
" def __init__(self):\n",
" super(HasStdScorerParams, self).__init__()\n",
" self._setDefault(fillNull=None)\n",
" \n",
" def getFillNull(self):\n",
" return self.getOrDefault(self.fillNull)\n",
" \n",
" @keyword_only\n",
" def setParams(self, inputCol=None, outputCol=None, fillNull=None):\n",
" return self._set(**self._input_kwargs)\n",
" \n",
" def setInputCol(self, value):\n",
" return self.setParams(inputCol=value)\n",
" \n",
" def setOutputCol(self, value):\n",
" return self.setParams(outputCol=value)\n",
" \n",
" def setFillNull(self, value):\n",
" return self.setParams(fillNull=value)\n",
" \n",
"# A model is what an estimator return when we call fit() method\n",
"# Just like a Transformer, it must override _transform method\n",
"class StdScorerModel(Model, HasStdScorerParams):\n",
" \n",
" # Custom params that is only stay inside the Model (and not the estimator)\n",
" mean = Param(Params._dummy(), 'mean', 'Mean of the column', TypeConverters.toFloat)\n",
" stddev = Param(Params._dummy(), 'stddev', 'Standard deviation of the column', TypeConverters.toFloat)\n",
" \n",
" @keyword_only\n",
" def __init__(self, inputCol=None, outputCol=None, fillNull=None, mean=None, stddev=None):\n",
" super(StdScorerModel, self).__init__()\n",
" self._setDefault(mean=None, stddev=None)\n",
" self.setParams(**self._input_kwargs)\n",
" \n",
" def getMean(self):\n",
" return self.getOrDefault(self.mean)\n",
" \n",
" def getStddev(self):\n",
" return self.getOrDefault(self.stddev)\n",
" \n",
" # setParams method has been implemented in HasStdScorerParams but it did not have enough params\n",
" @keyword_only\n",
" def setParams(self, inputCol=None, outputCol=None, fillNull=None, mean=None, stddev=None):\n",
" return self._set(**self._input_kwargs)\n",
" \n",
" def setMean(self, value):\n",
" return self.setParams(mean=value)\n",
" \n",
" def setStddev(self, value):\n",
" return self.setParams(stddev=value)\n",
" \n",
" def _transform(self, dataset):\n",
" if not (self.isSet('inputCol') and self.isSet('outputCol')):\n",
" raise ValueError('Both \"inputCol\" and \"outputCol\" must be set')\n",
" input_col = self.getInputCol()\n",
" output_col = self.getOutputCol()\n",
" dataset = dataset.withColumn(output_col, (dataset[input_col] - self.getMean()) / self.getStddev())\n",
" if self.getFillNull():\n",
" dataset = dataset.fillna(0, output_col)\n",
" return dataset\n",
" \n",
"class StdScorer(Estimator, HasStdScorerParams):\n",
" \n",
" @keyword_only\n",
" def __init__(self, inputCol=None, outputCol=None, fillNull=None):\n",
" super(StdScorer, self).__init__()\n",
" self.setParams(**self._input_kwargs)\n",
" \n",
" # getters and setters have been implemented in HasStdScorerParams\n",
" \n",
" # In the _fit() method, we will return a configured model\n",
" def _fit(self, dataset):\n",
" input_col = self.getInputCol()\n",
" output_col = self.getOutputCol()\n",
" mean, stddev = dataset.agg(F.mean(input_col), F.stddev(input_col)).head()\n",
" return StdScorerModel(inputCol=input_col, outputCol=output_col, fillNull=self.getFillNull(), \n",
" mean=mean, stddev=stddev) "
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "68d5a3ca",
"metadata": {
"execution": {
"iopub.execute_input": "2022-04-18T16:15:32.994852Z",
"iopub.status.busy": "2022-04-18T16:15:32.994188Z",
"iopub.status.idle": "2022-04-18T16:15:33.286038Z",
"shell.execute_reply": "2022-04-18T16:15:33.285343Z",
"shell.execute_reply.started": "2022-04-18T16:04:26.718585Z"
},
"papermill": {
"duration": 0.62962,
"end_time": "2022-04-18T16:15:33.286188",
"exception": false,
"start_time": "2022-04-18T16:15:32.656568",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+----+----+-----+-------------------+\n",
"| a| b| c| a_new|\n",
"+----+----+-----+-------------------+\n",
"| 1| 4.0| GFG1|-0.8559849767220402|\n",
"| 5| 1.2| GFC2| 0.3668507043094459|\n",
"| 5|6.25| null| 0.3668507043094459|\n",
"| 8|7.11|CLC22| 1.2839774650830604|\n",
"|null|0.22| SSv4| 0.0|\n",
"|null|3.48| MNM0| 0.0|\n",
"| 0|null| TNT2|-1.1616938969799118|\n",
"+----+----+-----+-------------------+\n",
"\n"
]
}
],
"source": [
"z = StdScorer(inputCol='a', outputCol='a_new', fillNull=True)\n",
"zmodel = z.fit(temp_df)\n",
"zmodel.transform(temp_df).show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"papermill": {
"default_parameters": {},
"duration": 616.666529,
"end_time": "2022-04-18T16:15:34.534698",
"environment_variables": {},
"exception": null,
"input_path": "__notebook__.ipynb",
"output_path": "__notebook__.ipynb",
"parameters": {},
"start_time": "2022-04-18T16:05:17.868169",
"version": "2.3.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}