@dannymorris
Created May 20, 2021 21:24
sparkml-arules.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "sparkml-arules.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"authorship_tag": "ABX9TyN56cncOgVhCAuH2dCCmQki",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/dannymorris/65bfd1e920b5c3673d7358cdf9d9753f/sparkml-arules.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lhme3LFAYEj_"
},
"source": [
"# Applying Association Rules to continuous features using PySpark and Spark MLlib\n",
"\n",
"## Overview\n",
"\n",
"This notebook shows how to apply Association Rules to continuous features using PySpark and Spark MLlib's [FP-Growth](https://spark.apache.org/docs/latest/ml-frequent-pattern-mining.html#fp-growth) algorithm. \n",
"\n",
"Because association rule mining expects \"itemsets\" as input (e.g. [milk, eggs]), continuous features stored in columns cannot be supplied to it directly. Transforming continuous features into categorical itemsets is therefore a necessary preprocessing step. In this notebook, each continuous feature value is mapped to a discrete decile with [QuantileDiscretizer](https://spark.apache.org/docs/latest/ml-features.html#quantilediscretizer), and the discrete features are then assembled into an itemset for each entity (row) in the original dataset.\n",
"\n",
"**Visual explanation of the transformation from continuous features to itemsets**\n",
"\n",
"![Transformation from continuous features to itemsets](https://i.imgur.com/wWl7RIvm.png)\n",
"\n",
"## Workflow\n",
"\n",
"1. Load the CSV file into Spark and read it as a DataFrame\n",
"2. Transform continuous features into binned, categorical features (deciles in this example)\n",
"3. Assemble categorical features into itemsets\n",
"4. Fit the FP-Growth algorithm to generate association rules\n",
"5. Investigate rules"
]
},
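{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal, Spark-free sketch of the transformation pictured above, written in pandas for illustration only. The DataFrame `toy` and its two features (`income`, `age_60_plus`) are made up. Because the toy data has only eight rows, `pd.qcut` bins each feature into quartiles rather than deciles. Each row's itemset is then built by joining every feature name with its bucket index. For example, the first row becomes `['income_0', 'age_60_plus_3']`. The Spark pipeline in this notebook uses `QuantileDiscretizer` instead. Its quantiles are approximate (see its `relativeError` parameter), so values near a bucket boundary may be binned differently."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pandas as pd\n",
"\n",
"# hypothetical toy data: 8 entities (rows) with two continuous features\n",
"toy = pd.DataFrame({\n",
"    \"income\": [21000, 35000, 28000, 52000, 19000, 44000, 31000, 60000],\n",
"    \"age_60_plus\": [0.31, 0.18, 0.25, 0.12, 0.40, 0.15, 0.22, 0.10],\n",
"})\n",
"\n",
"# map each continuous value to its quartile index (0 = lowest, 3 = highest)\n",
"binned = toy.apply(lambda s: pd.qcut(s, q=4, labels=False))\n",
"\n",
"# assemble one itemset per row, with one \"<feature>_<bucket>\" item per feature\n",
"itemsets = [\n",
"    [f\"{feature}_{bucket}\" for feature, bucket in row.items()]\n",
"    for _, row in binned.iterrows()\n",
"]\n",
"\n",
"for items in itemsets[:3]:\n",
"    print(items)"
],
"execution_count": null,
"outputs": []
},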
{
"cell_type": "markdown",
"metadata": {
"id": "Uaiz0fDEL1x-"
},
"source": [
"# Install PySpark"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-iOTzB1gd7LE",
"outputId": "a9d4f628-86b7-4ceb-c366-5e356246f563"
},
"source": [
"!pip install pyspark"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting pyspark\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/b0/9d6860891ab14a39d4bddf80ba26ce51c2f9dc4805e5c6978ac0472c120a/pyspark-3.1.1.tar.gz (212.3MB)\n",
"\u001b[K |████████████████████████████████| 212.3MB 67kB/s \n",
"\u001b[?25hCollecting py4j==0.10.9\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)\n",
"\u001b[K |████████████████████████████████| 204kB 19.4MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: pyspark\n",
" Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pyspark: filename=pyspark-3.1.1-py2.py3-none-any.whl size=212767604 sha256=1ccc2374712de1891cf4e16cf2f3b2379518309bdd3e00217fabc9506f6b4dc9\n",
" Stored in directory: /root/.cache/pip/wheels/0b/90/c0/01de724414ef122bd05f056541fb6a0ecf47c7ca655f8b3c0f\n",
"Successfully built pyspark\n",
"Installing collected packages: py4j, pyspark\n",
"Successfully installed py4j-0.10.9 pyspark-3.1.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oeBvkC9UMyok"
},
"source": [
"# Load modules and functions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mLUKdc60jkBz"
},
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark import SparkFiles\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.ml.fpm import FPGrowth\n",
"from pyspark.ml.feature import QuantileDiscretizer\n",
"\n",
"from pyspark.sql.functions import row_number, array, col, explode, lit, concat, \\\n",
" collect_list, struct, array_contains, create_map, size, regexp_replace\n",
" \n",
"from pyspark.sql.window import Window\n",
"\n",
"import pandas as pd\n",
"import json\n",
"from itertools import chain"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GAX4QcY6iuNh"
},
"source": [
"# general-purpose function to obtain a list of the column names\n",
"# in a DataFrame that begin with `pattern` (a string or a tuple of strings)\n",
"def list_cols_starting_with(df, pattern):\n",
"    out = list(filter(lambda x: x.startswith(pattern), df.columns))\n",
"    return out\n",
"\n",
"# general-purpose function to display the first `limit` rows of a\n",
"# Spark DataFrame as a pandas DataFrame, for improved printing\n",
"def display_pandas(df, limit=5):\n",
"    return df.limit(limit).toPandas()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "FMS5JD-kPQsI"
},
"source": [
"# Create SparkSession\n",
"\n",
"`SparkSession` is the single entry point to Spark functionality. In PySpark it is used to create and query DataFrames (the typed Dataset API is available only in Scala and Java)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "9Gz1Z6ici7KB"
},
"source": [
"spark = (\n",
"    SparkSession.builder\n",
"    .appName(\"MLlib Association Rules\")\n",
"    .getOrCreate()\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lg_Hnxw8Pgu1"
},
"source": [
"# Read data\n",
"\n",
"In this example, the data is a CSV file hosted online. `SparkContext.addFile` downloads the file to every node, `SparkFiles.get` returns the local path of the downloaded copy, and `spark.read.csv` reads it as a DataFrame."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 222
},
"id": "M3uGR6qaieqF",
"outputId": "eb3237ec-a339-4b0b-84cc-e5aa064c588a"
},
"source": [
"url = \"https://gist.githubusercontent.com/dannymorris/1bd95ddda1cfe7fd518e5cda01f4ac03/raw/295636f511f0afbd544aece7cbfe771edc01182c/county_data.csv\"\n",
"\n",
"# download the file to every node in the cluster\n",
"spark.sparkContext.addFile(url)\n",
"\n",
"# read the local copy, using the first row as the header and inferring column types\n",
"df = spark.read.csv(\"file://\" + SparkFiles.get(\"county_data.csv\"), header=True, inferSchema=True)\n",
"\n",
"display_pandas(df)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fips</th>\n",
" <th>state</th>\n",
" <th>name</th>\n",
" <th>party_winner</th>\n",
" <th>trump_pct</th>\n",
" <th>margin</th>\n",
" <th>POPULATION_Total</th>\n",
" <th>AGE_18_29</th>\n",
" <th>AGE_30_44</th>\n",
" <th>AGE_45_59</th>\n",
" <th>AGE_60_Plus</th>\n",
" <th>AGE_18_Plus</th>\n",
" <th>RACE_Total__Asian_alone</th>\n",
" <th>RACE_Total__Black_or_African_American_alone</th>\n",
" <th>RACE_Total__Hispanic_or_Latino</th>\n",
" <th>RACE_Total__White_alone</th>\n",
" <th>GINI_Gini_Index</th>\n",
" <th>INCOME_PER_CAPITA_INCOME_IN_THE_PAST_12_MONTHS__IN_2018_INFLATION_ADJUSTED_DOLLARS_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__BLACK_OR_AFRICAN_AMERICAN_ALONE_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__HISPANIC_OR_LATINO_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__WHITE_ALONE_</th>\n",
" <th>UNEMPLOY_Total_16_YEARS_AND_OVER</th>\n",
" <th>EDU_ATTAIN_Total__Bachelor_s_degree_or_higher</th>\n",
" <th>EDU_ATTAIN_Total__High_school_graduate__includes_equivalency_</th>\n",
" <th>EDU_ATTAIN_Total__Less_than_high_school_diploma</th>\n",
" <th>EDU_ATTAIN_Total__Some_college_or_associate_s_degree</th>\n",
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting</th>\n",
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction</th>\n",
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Accommodation_and_food_services</th>\n",
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Arts__entertainment__and_recreation</th>\n",
" <th>INDUSTRY_Total__Construction</th>\n",
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Educational_services</th>\n",
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Health_care_and_social_assistance</th>\n",
" <th>INDUSTRY_Total__Information</th>\n",
" <th>INDUSTRY_Total__Manufacturing</th>\n",
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Administrative_and_support_and_waste_management_services</th>\n",
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Management_of_companies_and_enterprises</th>\n",
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services</th>\n",
" <th>INDUSTRY_Total__Public_administration</th>\n",
" <th>INDUSTRY_Total__Retail_trade</th>\n",
" <th>CITIZEN_Estimate__Total__Not_a_U_S__citizen</th>\n",
" <th>CITIZEN_Estimate__Total__U_S__citizen_by_naturalization</th>\n",
" <th>CITIZEN_Estimate__Total__U_S__citizen__born_in_the_United_States</th>\n",
" <th>HEALTH_INSURANCE_No_health_insurance_coverage</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_direct_purchase_health_insurance_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_employer_based_health_insurance_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicare_coverage_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_TRICARE_military_health_coverage_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_VA_Health_Care_only</th>\n",
" <th>VETERANS_Estimate__Total__Veteran</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>48001</td>\n",
" <td>texas</td>\n",
" <td>Anderson</td>\n",
" <td>republican</td>\n",
" <td>0.786119</td>\n",
" <td>0.580355</td>\n",
" <td>57863</td>\n",
" <td>0.192527</td>\n",
" <td>0.308352</td>\n",
" <td>0.246227</td>\n",
" <td>0.252894</td>\n",
" <td>46648</td>\n",
" <td>0.005513</td>\n",
" <td>0.209823</td>\n",
" <td>0.175276</td>\n",
" <td>0.735340</td>\n",
" <td>0.4225</td>\n",
" <td>16868</td>\n",
" <td>0.0</td>\n",
" <td>0.004287</td>\n",
" <td>0.002272</td>\n",
" <td>0.007760</td>\n",
" <td>0.014320</td>\n",
" <td>0.105299</td>\n",
" <td>0.359608</td>\n",
" <td>0.232636</td>\n",
" <td>0.308481</td>\n",
" <td>0.007846</td>\n",
" <td>0.028490</td>\n",
" <td>0.022359</td>\n",
" <td>0.002165</td>\n",
" <td>0.022380</td>\n",
" <td>0.029905</td>\n",
" <td>0.057473</td>\n",
" <td>0.002037</td>\n",
" <td>0.025424</td>\n",
" <td>0.021373</td>\n",
" <td>0.000000</td>\n",
" <td>0.010912</td>\n",
" <td>0.038608</td>\n",
" <td>0.077860</td>\n",
" <td>0.041546</td>\n",
" <td>0.020946</td>\n",
" <td>0.928400</td>\n",
" <td>0.114324</td>\n",
" <td>0.028040</td>\n",
" <td>0.306080</td>\n",
" <td>0.028040</td>\n",
" <td>0.065833</td>\n",
" <td>0.005638</td>\n",
" <td>0.006088</td>\n",
" <td>0.086006</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>48003</td>\n",
" <td>texas</td>\n",
" <td>Andrews</td>\n",
" <td>republican</td>\n",
" <td>0.843084</td>\n",
" <td>0.698107</td>\n",
" <td>17818</td>\n",
" <td>0.245792</td>\n",
" <td>0.287747</td>\n",
" <td>0.246605</td>\n",
" <td>0.219855</td>\n",
" <td>12299</td>\n",
" <td>0.003536</td>\n",
" <td>0.019811</td>\n",
" <td>0.560052</td>\n",
" <td>0.924009</td>\n",
" <td>0.4506</td>\n",
" <td>31190</td>\n",
" <td>0.0</td>\n",
" <td>0.012196</td>\n",
" <td>0.015611</td>\n",
" <td>0.018863</td>\n",
" <td>0.046670</td>\n",
" <td>0.102529</td>\n",
" <td>0.461338</td>\n",
" <td>0.368810</td>\n",
" <td>0.319457</td>\n",
" <td>0.008537</td>\n",
" <td>0.167900</td>\n",
" <td>0.035125</td>\n",
" <td>0.003659</td>\n",
" <td>0.051955</td>\n",
" <td>0.036751</td>\n",
" <td>0.065696</td>\n",
" <td>0.009513</td>\n",
" <td>0.036100</td>\n",
" <td>0.022766</td>\n",
" <td>0.003252</td>\n",
" <td>0.014798</td>\n",
" <td>0.013172</td>\n",
" <td>0.077323</td>\n",
" <td>0.098047</td>\n",
" <td>0.047368</td>\n",
" <td>0.850208</td>\n",
" <td>0.175055</td>\n",
" <td>0.043256</td>\n",
" <td>0.511342</td>\n",
" <td>0.030084</td>\n",
" <td>0.055940</td>\n",
" <td>0.000000</td>\n",
" <td>0.003334</td>\n",
" <td>0.055696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>48005</td>\n",
" <td>texas</td>\n",
" <td>Angelina</td>\n",
" <td>republican</td>\n",
" <td>0.723981</td>\n",
" <td>0.460148</td>\n",
" <td>87607</td>\n",
" <td>0.208537</td>\n",
" <td>0.245155</td>\n",
" <td>0.258604</td>\n",
" <td>0.287704</td>\n",
" <td>64914</td>\n",
" <td>0.011495</td>\n",
" <td>0.147956</td>\n",
" <td>0.218864</td>\n",
" <td>0.791855</td>\n",
" <td>0.4495</td>\n",
" <td>22322</td>\n",
" <td>0.0</td>\n",
" <td>0.012416</td>\n",
" <td>0.009582</td>\n",
" <td>0.032027</td>\n",
" <td>0.054025</td>\n",
" <td>0.156638</td>\n",
" <td>0.314755</td>\n",
" <td>0.215177</td>\n",
" <td>0.306251</td>\n",
" <td>0.011092</td>\n",
" <td>0.010768</td>\n",
" <td>0.044012</td>\n",
" <td>0.005099</td>\n",
" <td>0.037681</td>\n",
" <td>0.049589</td>\n",
" <td>0.092831</td>\n",
" <td>0.002927</td>\n",
" <td>0.064624</td>\n",
" <td>0.024525</td>\n",
" <td>0.000308</td>\n",
" <td>0.018578</td>\n",
" <td>0.023323</td>\n",
" <td>0.073836</td>\n",
" <td>0.058420</td>\n",
" <td>0.026197</td>\n",
" <td>0.909745</td>\n",
" <td>0.204409</td>\n",
" <td>0.058092</td>\n",
" <td>0.351111</td>\n",
" <td>0.057784</td>\n",
" <td>0.073559</td>\n",
" <td>0.001833</td>\n",
" <td>0.006085</td>\n",
" <td>0.091336</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>48007</td>\n",
" <td>texas</td>\n",
" <td>Aransas</td>\n",
" <td>republican</td>\n",
" <td>0.751811</td>\n",
" <td>0.514525</td>\n",
" <td>24763</td>\n",
" <td>0.147476</td>\n",
" <td>0.171872</td>\n",
" <td>0.252445</td>\n",
" <td>0.428208</td>\n",
" <td>20044</td>\n",
" <td>0.019707</td>\n",
" <td>0.015386</td>\n",
" <td>0.272826</td>\n",
" <td>0.892622</td>\n",
" <td>0.5351</td>\n",
" <td>30939</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.006186</td>\n",
" <td>0.031581</td>\n",
" <td>0.037767</td>\n",
" <td>0.211684</td>\n",
" <td>0.299242</td>\n",
" <td>0.185242</td>\n",
" <td>0.339453</td>\n",
" <td>0.005887</td>\n",
" <td>0.018010</td>\n",
" <td>0.057025</td>\n",
" <td>0.023249</td>\n",
" <td>0.062862</td>\n",
" <td>0.043205</td>\n",
" <td>0.052784</td>\n",
" <td>0.002644</td>\n",
" <td>0.025244</td>\n",
" <td>0.019407</td>\n",
" <td>0.000000</td>\n",
" <td>0.022051</td>\n",
" <td>0.030682</td>\n",
" <td>0.054131</td>\n",
" <td>0.039979</td>\n",
" <td>0.031579</td>\n",
" <td>0.914671</td>\n",
" <td>0.206945</td>\n",
" <td>0.079824</td>\n",
" <td>0.252744</td>\n",
" <td>0.026542</td>\n",
" <td>0.108162</td>\n",
" <td>0.007334</td>\n",
" <td>0.004690</td>\n",
" <td>0.128916</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>48009</td>\n",
" <td>texas</td>\n",
" <td>Archer</td>\n",
" <td>republican</td>\n",
" <td>0.896580</td>\n",
" <td>0.803586</td>\n",
" <td>8789</td>\n",
" <td>0.163153</td>\n",
" <td>0.208085</td>\n",
" <td>0.281809</td>\n",
" <td>0.346954</td>\n",
" <td>6877</td>\n",
" <td>0.005348</td>\n",
" <td>0.009216</td>\n",
" <td>0.082717</td>\n",
" <td>0.948231</td>\n",
" <td>0.4316</td>\n",
" <td>31806</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.002472</td>\n",
" <td>0.020503</td>\n",
" <td>0.022975</td>\n",
" <td>0.210121</td>\n",
" <td>0.305511</td>\n",
" <td>0.096990</td>\n",
" <td>0.310310</td>\n",
" <td>0.034026</td>\n",
" <td>0.028210</td>\n",
" <td>0.014977</td>\n",
" <td>0.008579</td>\n",
" <td>0.046968</td>\n",
" <td>0.042751</td>\n",
" <td>0.123164</td>\n",
" <td>0.006107</td>\n",
" <td>0.049149</td>\n",
" <td>0.021085</td>\n",
" <td>0.000000</td>\n",
" <td>0.027774</td>\n",
" <td>0.033881</td>\n",
" <td>0.080413</td>\n",
" <td>0.019229</td>\n",
" <td>0.006713</td>\n",
" <td>0.971100</td>\n",
" <td>0.127963</td>\n",
" <td>0.077359</td>\n",
" <td>0.422132</td>\n",
" <td>0.025447</td>\n",
" <td>0.073724</td>\n",
" <td>0.011778</td>\n",
" <td>0.000145</td>\n",
" <td>0.090883</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fips ... VETERANS_Estimate__Total__Veteran\n",
"0 48001 ... 0.086006\n",
"1 48003 ... 0.055696\n",
"2 48005 ... 0.091336\n",
"3 48007 ... 0.128916\n",
"4 48009 ... 0.090883\n",
"\n",
"[5 rows x 52 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1bcZqbGVQJYa"
},
"source": [
"# Select features for Association Rules mining\n",
"\n",
"Rather than listing every feature by hand, select the columns whose names begin with one of a set of prefixes (tags)."
]
},
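{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prefix matching in `list_cols_starting_with` relies on Python's `str.startswith`, which accepts a tuple of prefixes and returns `True` if any of them matches. The Spark-free sketch below illustrates this on a hypothetical, shortened column list. Matching is case-sensitive, so each tag must match the capitalization used in the column names."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# hypothetical, shortened list of column names\n",
"columns = [\"fips\", \"state\", \"trump_pct\", \"AGE_18_29\", \"AGE_18_Plus\", \"GINI_Gini_Index\"]\n",
"\n",
"# str.startswith accepts a tuple and is True if ANY of the prefixes matches\n",
"tags = (\"trump\", \"AGE\", \"GINI\")\n",
"selected = [c for c in columns if c.startswith(tags)]\n",
"\n",
"# exclude one matched column, as the next cell does for AGE_18_Plus\n",
"selected = [c for c in selected if c != \"AGE_18_Plus\"]\n",
"print(selected)  # ['trump_pct', 'AGE_18_29', 'GINI_Gini_Index']"
],
"execution_count": null,
"outputs": []
},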
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 222
},
"id": "UDMlKUGrkLjU",
"outputId": "8e181157-414d-47a9-9e81-3a9dd45aa67e"
},
"source": [
"# locate columns whose names begin with one of these prefixes (tags)\n",
"feature_tags = ('trump', 'AGE', 'RACE', 'GINI', 'INCOME', 'UNEMPLOY', 'EDU', 'INDUSTRY', 'CITIZEN', 'HEALTH', 'VETERANS')\n",
"feature_list = list_cols_starting_with(df=df, pattern=feature_tags)\n",
"\n",
"# drop the unused column AGE_18_Plus\n",
"feature_list = [i for i in feature_list if i != \"AGE_18_Plus\"]\n",
"\n",
"# select features\n",
"features_df = df.select(feature_list)\n",
"\n",
"display_pandas(features_df)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>trump_pct</th>\n",
" <th>AGE_18_29</th>\n",
" <th>AGE_30_44</th>\n",
" <th>AGE_45_59</th>\n",
" <th>AGE_60_Plus</th>\n",
" <th>RACE_Total__Asian_alone</th>\n",
" <th>RACE_Total__Black_or_African_American_alone</th>\n",
" <th>RACE_Total__Hispanic_or_Latino</th>\n",
" <th>RACE_Total__White_alone</th>\n",
" <th>GINI_Gini_Index</th>\n",
" <th>INCOME_PER_CAPITA_INCOME_IN_THE_PAST_12_MONTHS__IN_2018_INFLATION_ADJUSTED_DOLLARS_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__BLACK_OR_AFRICAN_AMERICAN_ALONE_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__HISPANIC_OR_LATINO_</th>\n",
" <th>UNEMPLOY__16_YEARS_AND_OVER__WHITE_ALONE_</th>\n",
" <th>UNEMPLOY_Total_16_YEARS_AND_OVER</th>\n",
" <th>EDU_ATTAIN_Total__Bachelor_s_degree_or_higher</th>\n",
" <th>EDU_ATTAIN_Total__High_school_graduate__includes_equivalency_</th>\n",
" <th>EDU_ATTAIN_Total__Less_than_high_school_diploma</th>\n",
" <th>EDU_ATTAIN_Total__Some_college_or_associate_s_degree</th>\n",
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting</th>\n",
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction</th>\n",
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Accommodation_and_food_services</th>\n",
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Arts__entertainment__and_recreation</th>\n",
" <th>INDUSTRY_Total__Construction</th>\n",
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Educational_services</th>\n",
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Health_care_and_social_assistance</th>\n",
" <th>INDUSTRY_Total__Information</th>\n",
" <th>INDUSTRY_Total__Manufacturing</th>\n",
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Administrative_and_support_and_waste_management_services</th>\n",
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Management_of_companies_and_enterprises</th>\n",
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services</th>\n",
" <th>INDUSTRY_Total__Public_administration</th>\n",
" <th>INDUSTRY_Total__Retail_trade</th>\n",
" <th>CITIZEN_Estimate__Total__Not_a_U_S__citizen</th>\n",
" <th>CITIZEN_Estimate__Total__U_S__citizen_by_naturalization</th>\n",
" <th>CITIZEN_Estimate__Total__U_S__citizen__born_in_the_United_States</th>\n",
" <th>HEALTH_INSURANCE_No_health_insurance_coverage</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_direct_purchase_health_insurance_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_employer_based_health_insurance_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicare_coverage_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_TRICARE_military_health_coverage_only</th>\n",
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_VA_Health_Care_only</th>\n",
" <th>VETERANS_Estimate__Total__Veteran</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.786119</td>\n",
" <td>0.192527</td>\n",
" <td>0.308352</td>\n",
" <td>0.246227</td>\n",
" <td>0.252894</td>\n",
" <td>0.005513</td>\n",
" <td>0.209823</td>\n",
" <td>0.175276</td>\n",
" <td>0.735340</td>\n",
" <td>0.4225</td>\n",
" <td>16868</td>\n",
" <td>0.0</td>\n",
" <td>0.004287</td>\n",
" <td>0.002272</td>\n",
" <td>0.007760</td>\n",
" <td>0.014320</td>\n",
" <td>0.105299</td>\n",
" <td>0.359608</td>\n",
" <td>0.232636</td>\n",
" <td>0.308481</td>\n",
" <td>0.007846</td>\n",
" <td>0.028490</td>\n",
" <td>0.022359</td>\n",
" <td>0.002165</td>\n",
" <td>0.022380</td>\n",
" <td>0.029905</td>\n",
" <td>0.057473</td>\n",
" <td>0.002037</td>\n",
" <td>0.025424</td>\n",
" <td>0.021373</td>\n",
" <td>0.000000</td>\n",
" <td>0.010912</td>\n",
" <td>0.038608</td>\n",
" <td>0.077860</td>\n",
" <td>0.041546</td>\n",
" <td>0.020946</td>\n",
" <td>0.928400</td>\n",
" <td>0.114324</td>\n",
" <td>0.028040</td>\n",
" <td>0.306080</td>\n",
" <td>0.028040</td>\n",
" <td>0.065833</td>\n",
" <td>0.005638</td>\n",
" <td>0.006088</td>\n",
" <td>0.086006</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.843084</td>\n",
" <td>0.245792</td>\n",
" <td>0.287747</td>\n",
" <td>0.246605</td>\n",
" <td>0.219855</td>\n",
" <td>0.003536</td>\n",
" <td>0.019811</td>\n",
" <td>0.560052</td>\n",
" <td>0.924009</td>\n",
" <td>0.4506</td>\n",
" <td>31190</td>\n",
" <td>0.0</td>\n",
" <td>0.012196</td>\n",
" <td>0.015611</td>\n",
" <td>0.018863</td>\n",
" <td>0.046670</td>\n",
" <td>0.102529</td>\n",
" <td>0.461338</td>\n",
" <td>0.368810</td>\n",
" <td>0.319457</td>\n",
" <td>0.008537</td>\n",
" <td>0.167900</td>\n",
" <td>0.035125</td>\n",
" <td>0.003659</td>\n",
" <td>0.051955</td>\n",
" <td>0.036751</td>\n",
" <td>0.065696</td>\n",
" <td>0.009513</td>\n",
" <td>0.036100</td>\n",
" <td>0.022766</td>\n",
" <td>0.003252</td>\n",
" <td>0.014798</td>\n",
" <td>0.013172</td>\n",
" <td>0.077323</td>\n",
" <td>0.098047</td>\n",
" <td>0.047368</td>\n",
" <td>0.850208</td>\n",
" <td>0.175055</td>\n",
" <td>0.043256</td>\n",
" <td>0.511342</td>\n",
" <td>0.030084</td>\n",
" <td>0.055940</td>\n",
" <td>0.000000</td>\n",
" <td>0.003334</td>\n",
" <td>0.055696</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.723981</td>\n",
" <td>0.208537</td>\n",
" <td>0.245155</td>\n",
" <td>0.258604</td>\n",
" <td>0.287704</td>\n",
" <td>0.011495</td>\n",
" <td>0.147956</td>\n",
" <td>0.218864</td>\n",
" <td>0.791855</td>\n",
" <td>0.4495</td>\n",
" <td>22322</td>\n",
" <td>0.0</td>\n",
" <td>0.012416</td>\n",
" <td>0.009582</td>\n",
" <td>0.032027</td>\n",
" <td>0.054025</td>\n",
" <td>0.156638</td>\n",
" <td>0.314755</td>\n",
" <td>0.215177</td>\n",
" <td>0.306251</td>\n",
" <td>0.011092</td>\n",
" <td>0.010768</td>\n",
" <td>0.044012</td>\n",
" <td>0.005099</td>\n",
" <td>0.037681</td>\n",
" <td>0.049589</td>\n",
" <td>0.092831</td>\n",
" <td>0.002927</td>\n",
" <td>0.064624</td>\n",
" <td>0.024525</td>\n",
" <td>0.000308</td>\n",
" <td>0.018578</td>\n",
" <td>0.023323</td>\n",
" <td>0.073836</td>\n",
" <td>0.058420</td>\n",
" <td>0.026197</td>\n",
" <td>0.909745</td>\n",
" <td>0.204409</td>\n",
" <td>0.058092</td>\n",
" <td>0.351111</td>\n",
" <td>0.057784</td>\n",
" <td>0.073559</td>\n",
" <td>0.001833</td>\n",
" <td>0.006085</td>\n",
" <td>0.091336</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.751811</td>\n",
" <td>0.147476</td>\n",
" <td>0.171872</td>\n",
" <td>0.252445</td>\n",
" <td>0.428208</td>\n",
" <td>0.019707</td>\n",
" <td>0.015386</td>\n",
" <td>0.272826</td>\n",
" <td>0.892622</td>\n",
" <td>0.5351</td>\n",
" <td>30939</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.006186</td>\n",
" <td>0.031581</td>\n",
" <td>0.037767</td>\n",
" <td>0.211684</td>\n",
" <td>0.299242</td>\n",
" <td>0.185242</td>\n",
" <td>0.339453</td>\n",
" <td>0.005887</td>\n",
" <td>0.018010</td>\n",
" <td>0.057025</td>\n",
" <td>0.023249</td>\n",
" <td>0.062862</td>\n",
" <td>0.043205</td>\n",
" <td>0.052784</td>\n",
" <td>0.002644</td>\n",
" <td>0.025244</td>\n",
" <td>0.019407</td>\n",
" <td>0.000000</td>\n",
" <td>0.022051</td>\n",
" <td>0.030682</td>\n",
" <td>0.054131</td>\n",
" <td>0.039979</td>\n",
" <td>0.031579</td>\n",
" <td>0.914671</td>\n",
" <td>0.206945</td>\n",
" <td>0.079824</td>\n",
" <td>0.252744</td>\n",
" <td>0.026542</td>\n",
" <td>0.108162</td>\n",
" <td>0.007334</td>\n",
" <td>0.004690</td>\n",
" <td>0.128916</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.896580</td>\n",
" <td>0.163153</td>\n",
" <td>0.208085</td>\n",
" <td>0.281809</td>\n",
" <td>0.346954</td>\n",
" <td>0.005348</td>\n",
" <td>0.009216</td>\n",
" <td>0.082717</td>\n",
" <td>0.948231</td>\n",
" <td>0.4316</td>\n",
" <td>31806</td>\n",
" <td>0.0</td>\n",
" <td>0.000000</td>\n",
" <td>0.002472</td>\n",
" <td>0.020503</td>\n",
" <td>0.022975</td>\n",
" <td>0.210121</td>\n",
" <td>0.305511</td>\n",
" <td>0.096990</td>\n",
" <td>0.310310</td>\n",
" <td>0.034026</td>\n",
" <td>0.028210</td>\n",
" <td>0.014977</td>\n",
" <td>0.008579</td>\n",
" <td>0.046968</td>\n",
" <td>0.042751</td>\n",
" <td>0.123164</td>\n",
" <td>0.006107</td>\n",
" <td>0.049149</td>\n",
" <td>0.021085</td>\n",
" <td>0.000000</td>\n",
" <td>0.027774</td>\n",
" <td>0.033881</td>\n",
" <td>0.080413</td>\n",
" <td>0.019229</td>\n",
" <td>0.006713</td>\n",
" <td>0.971100</td>\n",
" <td>0.127963</td>\n",
" <td>0.077359</td>\n",
" <td>0.422132</td>\n",
" <td>0.025447</td>\n",
" <td>0.073724</td>\n",
" <td>0.011778</td>\n",
" <td>0.000145</td>\n",
" <td>0.090883</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" trump_pct ... VETERANS_Estimate__Total__Veteran\n",
"0 0.786119 ... 0.086006\n",
"1 0.843084 ... 0.055696\n",
"2 0.723981 ... 0.091336\n",
"3 0.751811 ... 0.128916\n",
"4 0.896580 ... 0.090883\n",
"\n",
"[5 rows x 45 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rD1yc70LQcsf"
},
"source": [
"# Transform continuous features to categorical\n",
"\n",
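"[QuantileDiscretizer](https://spark.apache.org/docs/latest/ml-features.html#quantilediscretizer) maps continuous feature values to discrete (categorical) deciles. With `numBuckets=10`, each value is replaced by the index of its bucket, from 0.0 (lowest decile) to 9.0 (highest).\n"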
]
},
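{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell bins each feature with its own Pipeline stage. As an alternative, and assuming Spark 3.0 or later (this notebook installs 3.1.1), a single `QuantileDiscretizer` can bin many columns at once through its `inputCols` and `outputCols` parameters. The sketch below reuses `feature_list` and `features_df` from above. `multi_discretizer` and `multi_results` are illustrative names, and `relativeError` is set explicitly to its default. According to the Spark documentation, a column with too few distinct values can receive fewer than `numBuckets` buckets. The last line checks this for the Asian-alone unemployment rate, which is 0.0 in all five rows displayed earlier."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from pyspark.ml.feature import QuantileDiscretizer\n",
"\n",
"# one multi-column discretizer instead of one stage per feature (Spark >= 3.0)\n",
"multi_discretizer = QuantileDiscretizer(\n",
"    numBuckets=10,          # deciles\n",
"    relativeError=0.001,    # default precision of the approximate quantiles\n",
"    inputCols=feature_list,\n",
"    outputCols=[\"Quantile_\" + c for c in feature_list],\n",
")\n",
"multi_results = multi_discretizer.fit(features_df).transform(features_df)\n",
"\n",
"# number of distinct buckets actually produced for a sparse feature\n",
"sparse_col = \"Quantile_UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_\"\n",
"multi_results.select(sparse_col).distinct().count()"
],
"execution_count": null,
"outputs": []
},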
{
"cell_type": "code",
"metadata": {
"id": "-jRW-GzIQfSI",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 222
},
"outputId": "68a5ba3f-fe62-4e95-913f-e56b576dfb75"
},
"source": [
"# create one QuantileDiscretizer per feature (10 buckets = deciles),\n",
"# then fit and apply all of them in a single Pipeline\n",
"discretizer = [QuantileDiscretizer(inputCol=x, outputCol=\"Quantile_\"+x, numBuckets=10) for x in feature_list]\n",
"discretizer_results = Pipeline(stages=discretizer).fit(features_df).transform(features_df)\n",
"\n",
"# select transformed columns\n",
"discrete_cols = list_cols_starting_with(discretizer_results, \"Quantile_\")\n",
"discrete_df = discretizer_results.select(discrete_cols)\n",
"\n",
"display_pandas(discrete_df)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Quantile_trump_pct</th>\n",
" <th>Quantile_AGE_18_29</th>\n",
" <th>Quantile_AGE_30_44</th>\n",
" <th>Quantile_AGE_45_59</th>\n",
" <th>Quantile_AGE_60_Plus</th>\n",
" <th>Quantile_RACE_Total__Asian_alone</th>\n",
" <th>Quantile_RACE_Total__Black_or_African_American_alone</th>\n",
" <th>Quantile_RACE_Total__Hispanic_or_Latino</th>\n",
" <th>Quantile_RACE_Total__White_alone</th>\n",
" <th>Quantile_GINI_Gini_Index</th>\n",
" <th>Quantile_INCOME_PER_CAPITA_INCOME_IN_THE_PAST_12_MONTHS__IN_2018_INFLATION_ADJUSTED_DOLLARS_</th>\n",
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_</th>\n",
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__BLACK_OR_AFRICAN_AMERICAN_ALONE_</th>\n",
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__HISPANIC_OR_LATINO_</th>\n",
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__WHITE_ALONE_</th>\n",
" <th>Quantile_UNEMPLOY_Total_16_YEARS_AND_OVER</th>\n",
" <th>Quantile_EDU_ATTAIN_Total__Bachelor_s_degree_or_higher</th>\n",
" <th>Quantile_EDU_ATTAIN_Total__High_school_graduate__includes_equivalency_</th>\n",
" <th>Quantile_EDU_ATTAIN_Total__Less_than_high_school_diploma</th>\n",
" <th>Quantile_EDU_ATTAIN_Total__Some_college_or_associate_s_degree</th>\n",
" <th>Quantile_INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting</th>\n",
" <th>Quantile_INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction</th>\n",
" <th>Quantile_INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Accommodation_and_food_services</th>\n",
" <th>Quantile_INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Arts__entertainment__and_recreation</th>\n",
" <th>Quantile_INDUSTRY_Total__Construction</th>\n",
" <th>Quantile_INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Educational_services</th>\n",
" <th>Quantile_INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Health_care_and_social_assistance</th>\n",
" <th>Quantile_INDUSTRY_Total__Information</th>\n",
" <th>Quantile_INDUSTRY_Total__Manufacturing</th>\n",
" <th>Quantile_INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Administrative_and_support_and_waste_management_services</th>\n",
" <th>Quantile_INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Management_of_companies_and_enterprises</th>\n",
" <th>Quantile_INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services</th>\n",
" <th>Quantile_INDUSTRY_Total__Public_administration</th>\n",
" <th>Quantile_INDUSTRY_Total__Retail_trade</th>\n",
" <th>Quantile_CITIZEN_Estimate__Total__Not_a_U_S__citizen</th>\n",
" <th>Quantile_CITIZEN_Estimate__Total__U_S__citizen_by_naturalization</th>\n",
" <th>Quantile_CITIZEN_Estimate__Total__U_S__citizen__born_in_the_United_States</th>\n",
" <th>Quantile_HEALTH_INSURANCE_No_health_insurance_coverage</th>\n",
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_direct_purchase_health_insurance_only</th>\n",
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_employer_based_health_insurance_only</th>\n",
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only</th>\n",
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicare_coverage_only</th>\n",
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_TRICARE_military_health_coverage_only</th>\n",
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_VA_Health_Care_only</th>\n",
" <th>Quantile_VETERANS_Estimate__Total__Veteran</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7.0</td>\n",
" <td>6.0</td>\n",
" <td>9.0</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>4.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>2.0</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>5.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>7.0</td>\n",
" <td>9.0</td>\n",
" <td>7.0</td>\n",
" <td>2.0</td>\n",
" <td>9.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>6.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>7.0</td>\n",
" <td>2.0</td>\n",
" <td>6.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>3.0</td>\n",
" <td>6.0</td>\n",
" <td>8.0</td>\n",
" <td>4.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>9.0</td>\n",
" <td>9.0</td>\n",
" <td>9.0</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>4.0</td>\n",
" <td>9.0</td>\n",
" <td>5.0</td>\n",
" <td>5.0</td>\n",
" <td>7.0</td>\n",
" <td>1.0</td>\n",
" <td>6.0</td>\n",
" <td>8.0</td>\n",
" <td>2.0</td>\n",
" <td>8.0</td>\n",
" <td>0.0</td>\n",
" <td>9.0</td>\n",
" <td>9.0</td>\n",
" <td>8.0</td>\n",
" <td>3.0</td>\n",
" <td>9.0</td>\n",
" <td>4.0</td>\n",
" <td>1.0</td>\n",
" <td>8.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" <td>7.0</td>\n",
" <td>2.0</td>\n",
" <td>7.0</td>\n",
" <td>5.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>8.0</td>\n",
" <td>9.0</td>\n",
" <td>9.0</td>\n",
" <td>0.0</td>\n",
" <td>8.0</td>\n",
" <td>1.0</td>\n",
" <td>9.0</td>\n",
" <td>1.0</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>4.0</td>\n",
" <td>2.0</td>\n",
" <td>7.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>6.0</td>\n",
" <td>8.0</td>\n",
" <td>7.0</td>\n",
" <td>9.0</td>\n",
" <td>3.0</td>\n",
" <td>4.0</td>\n",
" <td>8.0</td>\n",
" <td>7.0</td>\n",
" <td>3.0</td>\n",
" <td>8.0</td>\n",
" <td>7.0</td>\n",
" <td>2.0</td>\n",
" <td>4.0</td>\n",
" <td>5.0</td>\n",
" <td>7.0</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>8.0</td>\n",
" <td>2.0</td>\n",
" <td>5.0</td>\n",
" <td>3.0</td>\n",
" <td>7.0</td>\n",
" <td>8.0</td>\n",
" <td>7.0</td>\n",
" <td>1.0</td>\n",
" <td>9.0</td>\n",
" <td>4.0</td>\n",
" <td>3.0</td>\n",
" <td>4.0</td>\n",
" <td>5.0</td>\n",
" <td>1.0</td>\n",
" <td>8.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>6.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>9.0</td>\n",
" <td>8.0</td>\n",
" <td>4.0</td>\n",
" <td>9.0</td>\n",
" <td>4.0</td>\n",
" <td>9.0</td>\n",
" <td>7.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>6.0</td>\n",
" <td>6.0</td>\n",
" <td>4.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>2.0</td>\n",
" <td>8.0</td>\n",
" <td>9.0</td>\n",
" <td>9.0</td>\n",
" <td>9.0</td>\n",
" <td>3.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>1.0</td>\n",
" <td>6.0</td>\n",
" <td>6.0</td>\n",
" <td>2.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>1.0</td>\n",
" <td>9.0</td>\n",
" <td>7.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>9.0</td>\n",
" <td>7.0</td>\n",
" <td>6.0</td>\n",
" <td>9.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>9.0</td>\n",
" <td>2.0</td>\n",
" <td>2.0</td>\n",
" <td>8.0</td>\n",
" <td>6.0</td>\n",
" <td>4.0</td>\n",
" <td>3.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>3.0</td>\n",
" <td>8.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>3.0</td>\n",
" <td>2.0</td>\n",
" <td>6.0</td>\n",
" <td>4.0</td>\n",
" <td>3.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>9.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>7.0</td>\n",
" <td>3.0</td>\n",
" <td>9.0</td>\n",
" <td>3.0</td>\n",
" <td>3.0</td>\n",
" <td>6.0</td>\n",
" <td>1.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>8.0</td>\n",
" <td>5.0</td>\n",
" <td>3.0</td>\n",
" <td>5.0</td>\n",
" <td>7.0</td>\n",
" <td>7.0</td>\n",
" <td>6.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>8.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Quantile_trump_pct ... Quantile_VETERANS_Estimate__Total__Veteran\n",
"0 7.0 ... 4.0\n",
"1 9.0 ... 0.0\n",
"2 6.0 ... 6.0\n",
"3 6.0 ... 9.0\n",
"4 9.0 ... 5.0\n",
"\n",
"[5 rows x 45 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 29
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EhDjLSVwWCdS"
},
"source": [
"# Assemble discrete features into itemsets\n",
"\n",
"1. Reshape features from wide to long format.\n",
"2. Map each decile to a descriptive label, then create individual items by concatenating the feature name with that label.\n",
"3. Assemble the itemsets for each row by collecting the individual items into an array."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lLgKqzB-nLtc"
},
"source": [
"## Reshape features from wide to long format\n",
"\n",
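"The function `to_explode` reshapes a DataFrame from wide to long format: every column except those listed in `by` becomes a (`feature`, `value`) pair on its own row. It behaves similarly to `tidyr::gather()` (superseded by `tidyr::pivot_longer()`) in R and `pandas.melt()` in Python."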
]
},
{
"cell_type": "code",
"metadata": {
"id": "3b_D0e_DIwMR"
},
"source": [
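"from pyspark.sql.functions import array, col, explode, lit, struct\n",
"\n",
"def to_explode(df, by):\n",
"    # Columns to reshape: every column not used as an identifier in `by`\n",
"    cols = [c for c, _ in df.dtypes if c not in by]\n",
"    # One (feature, value) struct per column, packed into an array and exploded\n",
"    # so that each struct becomes its own row (the value columns should share a type)\n",
"    kvs = explode(array([\n",
"        struct(lit(c).alias(\"feature\"), col(c).alias(\"value\")) for c in cols\n",
"    ])).alias(\"kvs\")\n",
"    return df.select(by + [kvs]).select(by + [\"kvs.feature\", \"kvs.value\"])"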
],
"execution_count": null,
"outputs": []
},
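{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough, Spark-free illustration of what `to_explode` does, the sketch below applies the same reshape to a list of plain Python dictionaries. The helper `explode_rows` and the two-row table are hypothetical. Every column not listed in `by` becomes its own `(feature, value)` record, mirroring the exploded array of structs above, so two wide rows with two features each yield four long records."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical helper (not part of PySpark): a plain-Python wide-to-long reshape\n",
"def explode_rows(rows, by):\n",
"    long_rows = []\n",
"    for row in rows:\n",
"        # Identifier columns are copied onto every output record\n",
"        keys = {k: row[k] for k in by}\n",
"        for column, value in row.items():\n",
"            if column in by:\n",
"                continue\n",
"            # One record per remaining column, like one exploded struct\n",
"            long_rows.append({**keys, 'feature': column, 'value': value})\n",
"    return long_rows\n",
"\n",
"\n",
"# Hypothetical two-row table of decile values\n",
"wide = [\n",
"    {'row_number': 1, 'Quantile_trump_pct': 7.0, 'Quantile_AGE_18_29': 6.0},\n",
"    {'row_number': 2, 'Quantile_trump_pct': 9.0, 'Quantile_AGE_18_29': 9.0},\n",
"]\n",
"\n",
"for record in explode_rows(wide, by=['row_number']):\n",
"    print(record)"
],
"execution_count": null,
"outputs": []
},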
{
"cell_type": "markdown",
"metadata": {
"id": "I5XZ28IaLljL"
},
"source": [
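"To preserve the row-wise relationships after the data is reshaped, the custom `append_row_number` function first appends a `row_number` column. Because the window is ordered by a constant literal and has no partitioning, Spark moves all rows into a single partition and the numbering follows whatever order the rows arrive in. This is acceptable here because the number only serves as a row identifier, but it can become a bottleneck on large datasets."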
]
},
{
"cell_type": "code",
"metadata": {
"id": "uBezhWy_k9Mr"
},
"source": [
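"from pyspark.sql import Window\n",
"from pyspark.sql.functions import lit, row_number\n",
"\n",
"def append_row_number(df):\n",
"    # Unpartitioned window ordered by a constant: Spark moves every row into a\n",
"    # single partition and numbers the rows 1..N in arbitrary order\n",
"    w = Window.orderBy(lit('A'))\n",
"    return df.withColumn(\"row_number\", row_number().over(w))"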
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ba4qbQq6oFkT"
},
"source": [
"The custom `wide_to_long` function puts the two steps together: it numbers the rows and then reshapes them from wide to long."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 202
},
"id": "ziTAUwrhmPfo",
"outputId": "4d898be8-5810-466a-f74b-2686054d937c"
},
"source": [
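"# Helper: preview the first `limit` rows of a Spark DataFrame as a pandas DataFrame\n",
"def display_pandas(df, limit=5):\n",
"    return df.limit(limit).toPandas()\n",
"\n",
"# Number the rows, then reshape from wide to long\n",
"def wide_to_long(df):\n",
"    df_numbered = append_row_number(df)\n",
"    df_long = to_explode(df_numbered, by=['row_number'])\n",
"    return df_long\n",
"\n",
"long_df = wide_to_long(df=discrete_df)\n",
"\n",
"display_pandas(long_df)"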
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>row_number</th>\n",
" <th>feature</th>\n",
" <th>value</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>Quantile_trump_pct</td>\n",
" <td>7.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_18_29</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_30_44</td>\n",
" <td>9.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_45_59</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_60_Plus</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" row_number feature value\n",
"0 1 Quantile_trump_pct 7.0\n",
"1 1 Quantile_AGE_18_29 6.0\n",
"2 1 Quantile_AGE_30_44 9.0\n",
"3 1 Quantile_AGE_45_59 2.0\n",
"4 1 Quantile_AGE_60_Plus 1.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 30
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cksmYdfDJnGc"
},
"source": [
"## Convert decile integers to intuitive labels\n",
"\n",
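"To improve the interpretability of the final solution, discrete decile values (e.g. `0.0`) are transformed into descriptive labels (e.g. `Extremely_Low`). The labels are coarser than the deciles: deciles 2 and 3, 4 and 5, and 6 and 7 each share a single label."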
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 202
},
"id": "RB6Ag9Uj3_n2",
"outputId": "c61a0def-91bf-43cf-bd22-de10c2263450"
},
"source": [
"from itertools import chain\n",
"from pyspark.sql.functions import col, create_map, lit\n",
"\n",
"# Ten deciles collapse into seven labels\n",
"mapping = {\n",
"    0.0: 'Extremely_Low',\n",
"    1.0: 'Very_Low',\n",
"    2.0: 'Moderately_Low',\n",
"    3.0: 'Moderately_Low',\n",
"    4.0: 'Moderate',\n",
"    5.0: 'Moderate',\n",
"    6.0: 'Moderately_High',\n",
"    7.0: 'Moderately_High',\n",
"    8.0: 'Very_High',\n",
"    9.0: 'Extremely_High'\n",
"}\n",
"\n",
"def int_to_labels(column, mapping):\n",
"    # Build a literal map column {0.0: 'Extremely_Low', ...} and look up each value in it\n",
"    mapping_expr = create_map([lit(x) for x in chain(*mapping.items())])\n",
"    return mapping_expr[col(column)]\n",
"\n",
"long_df_labeled = long_df.withColumn(\"label\", int_to_labels(\"value\", mapping)).drop(\"value\")\n",
"\n",
"def display_pandas(df, limit=5):\n",
"    return df.limit(limit).toPandas()\n",
"\n",
"display_pandas(long_df_labeled, limit=5)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>row_number</th>\n",
" <th>feature</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>Quantile_trump_pct</td>\n",
" <td>Moderately_High</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_18_29</td>\n",
" <td>Moderately_High</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_30_44</td>\n",
" <td>Extremely_High</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_45_59</td>\n",
" <td>Moderately_Low</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>Quantile_AGE_60_Plus</td>\n",
" <td>Very_Low</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" row_number feature label\n",
"0 1 Quantile_trump_pct Moderately_High\n",
"1 1 Quantile_AGE_18_29 Moderately_High\n",
"2 1 Quantile_AGE_30_44 Extremely_High\n",
"3 1 Quantile_AGE_45_59 Moderately_Low\n",
"4 1 Quantile_AGE_60_Plus Very_Low"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
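{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because ten deciles are collapsed into seven labels, the labels do not all cover the same share of rows. For each feature, `Moderate` spans deciles 4 and 5 (nominally about 20% of rows), while `Extremely_Low` spans only decile 0 (nominally about 10%). The short plain-Python check below reuses the same `mapping` dictionary to count deciles per label and to reproduce the labels shown above for the first five features of row 1."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from collections import Counter\n",
"\n",
"# Same decile-to-label mapping as in the cell above\n",
"mapping = {\n",
"    0.0: 'Extremely_Low',\n",
"    1.0: 'Very_Low',\n",
"    2.0: 'Moderately_Low',\n",
"    3.0: 'Moderately_Low',\n",
"    4.0: 'Moderate',\n",
"    5.0: 'Moderate',\n",
"    6.0: 'Moderately_High',\n",
"    7.0: 'Moderately_High',\n",
"    8.0: 'Very_High',\n",
"    9.0: 'Extremely_High',\n",
"}\n",
"\n",
"# Number of deciles that share each label\n",
"deciles_per_label = Counter(mapping.values())\n",
"print(deciles_per_label)\n",
"\n",
"# Decile values of the first five features in row 1 (taken from the output above)\n",
"row_1_values = [7.0, 6.0, 9.0, 2.0, 1.0]\n",
"print([mapping[v] for v in row_1_values])"
],
"execution_count": null,
"outputs": []
},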
{
"cell_type": "markdown",
"metadata": {
"id": "Ok7i5dJKC4_R"
},
"source": [
"## Assemble itemsets\n",
"\n",
"1. Create a single \"item\" by concatenating the `feature` and `label` columns with an underscore. The `Quantile_` prefix is also stripped from the feature name in this step.\n",
"\n",
"2. Use `collect_list` to gather every item that shares a `row_number` into one array, producing one itemset (transaction) per original row."
]
},
{
"cell_type": "code",
"metadata": {
"id": "YoGdk4bwC6u3"
},
"source": [
"# Build one item per (row, feature), e.g. 'trump_pct_Moderately_High',\n",
"# then collect the items of each row into a single array (transaction)\n",
"itemset_df = (\n",
"    long_df_labeled\n",
"    .withColumn('feature', regexp_replace('feature', '^Quantile_', ''))\n",
"    .withColumn('item', concat(col('feature'), lit('_'), col('label')))\n",
"    .drop('feature', 'label')\n",
"    .groupBy('row_number')\n",
"    .agg(collect_list('item').alias('items'))\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 202
},
"id": "85-NyGkPs0D8",
"outputId": "f30bf988-9149-4e53-a9e6-5624bd20641a"
},
"source": [
"from pandas import option_context\n",
"\n",
"with option_context('display.max_colwidth', 100):\n",
" display(display_pandas(itemset_df))"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>row_number</th>\n",
" <th>items</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>[trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Extremely_High, AGE_45_59_Moder...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>[trump_pct_Extremely_High, AGE_18_29_Extremely_High, AGE_30_44_Extremely_High, AGE_45_59_Moderat...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>[trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Moderately_High, AGE_45_59_Mode...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>[trump_pct_Moderately_High, AGE_18_29_Very_Low, AGE_30_44_Extremely_Low, AGE_45_59_Moderately_Lo...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>[trump_pct_Extremely_High, AGE_18_29_Moderately_Low, AGE_30_44_Moderately_Low, AGE_45_59_Very_Hi...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" row_number items\n",
"0 1 [trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Extremely_High, AGE_45_59_Moder...\n",
"1 2 [trump_pct_Extremely_High, AGE_18_29_Extremely_High, AGE_30_44_Extremely_High, AGE_45_59_Moderat...\n",
"2 3 [trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Moderately_High, AGE_45_59_Mode...\n",
"3 4 [trump_pct_Moderately_High, AGE_18_29_Very_Low, AGE_30_44_Extremely_Low, AGE_45_59_Moderately_Lo...\n",
"4 5 [trump_pct_Extremely_High, AGE_18_29_Moderately_Low, AGE_30_44_Moderately_Low, AGE_45_59_Very_Hi..."
]
},
"metadata": {
"tags": []
}
}
]
},
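{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grouping step can be sketched in plain Python to make the resulting transactions concrete. The `assemble_itemsets` helper and the four long-format records are hypothetical. They mirror the `groupBy('row_number').agg(collect_list('item'))` step. Because each feature contributes exactly one item per row, the items inside each transaction are unique, which Spark's FP-Growth requires. `collect_list` does not guarantee item order after a shuffle, but that does not matter here because FP-Growth treats each transaction as a set."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from collections import defaultdict\n",
"\n",
"# Hypothetical helper: build one item list (transaction) per row_number\n",
"def assemble_itemsets(long_rows, prefix='Quantile_'):\n",
"    itemsets = defaultdict(list)\n",
"    for rec in long_rows:\n",
"        feature = rec['feature']\n",
"        # Strip the prefix only at the start of the feature name\n",
"        if feature.startswith(prefix):\n",
"            feature = feature[len(prefix):]\n",
"        itemsets[rec['row_number']].append(feature + '_' + rec['label'])\n",
"    return dict(itemsets)\n",
"\n",
"\n",
"long_rows = [\n",
"    {'row_number': 1, 'feature': 'Quantile_trump_pct', 'label': 'Moderately_High'},\n",
"    {'row_number': 1, 'feature': 'Quantile_AGE_18_29', 'label': 'Moderately_High'},\n",
"    {'row_number': 2, 'feature': 'Quantile_trump_pct', 'label': 'Extremely_High'},\n",
"    {'row_number': 2, 'feature': 'Quantile_AGE_18_29', 'label': 'Extremely_High'},\n",
"]\n",
"\n",
"transactions = assemble_itemsets(long_rows)\n",
"for row_number, items in transactions.items():\n",
"    # Every item in a transaction must be distinct for FP-Growth\n",
"    assert len(items) == len(set(items))\n",
"    print(row_number, items)"
],
"execution_count": null,
"outputs": []
},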
{
"cell_type": "markdown",
"metadata": {
"id": "rPILwUmwVHnc"
},
"source": [
"# Association Rules using FP-Growth\n",
"\n",
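"1. Initialize FP-Growth and fit it to the itemsets\n",
"2. Retrieve association rules for selected consequents\n",
"3. Explore results\n",
"\n",
"## Initialize `FPGrowth` with parameters and fit it to the data\n",
"\n",
"`minSupport=0.02` keeps only itemsets that occur in at least 2% of rows. `minConfidence=0.01` is very permissive, so nearly every rule generated from those itemsets is retained; the rules are filtered by lift in the next step."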
]
},
{
"cell_type": "code",
"metadata": {
"id": "edmCBNqrc6qW"
},
"source": [
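"from pyspark.ml.fpm import FPGrowth\n",
"\n",
"# minSupport: minimum fraction of rows containing an itemset; minConfidence: minimum rule confidence\n",
"fpGrowth = FPGrowth(itemsCol=\"items\", minSupport=0.02, minConfidence=0.01)\n",
"\n",
"model = fpGrowth.fit(itemset_df)"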
],
"execution_count": null,
"outputs": []
},
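{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameters above refer to the standard association-rule measures. The support of an itemset is the fraction of transactions that contain it. The confidence of a rule `X => Y` is `support(X ∪ Y) / support(X)`, and its lift is `confidence / support(Y)`. The plain-Python sketch below computes these measures on four made-up transactions (the item names are hypothetical). A lift above 1 means that `Y` occurs more often alongside `X` than it does overall."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Plain-Python versions of the standard association-rule measures\n",
"def support(transactions, itemset):\n",
"    # Fraction of transactions that contain every item in `itemset`\n",
"    itemset = set(itemset)\n",
"    return sum(itemset <= set(t) for t in transactions) / len(transactions)\n",
"\n",
"\n",
"def rule_metrics(transactions, antecedent, consequent):\n",
"    joint = support(transactions, list(antecedent) + list(consequent))\n",
"    confidence = joint / support(transactions, antecedent)\n",
"    lift = confidence / support(transactions, consequent)\n",
"    return joint, confidence, lift\n",
"\n",
"\n",
"# Four hypothetical transactions\n",
"toy_transactions = [\n",
"    ['edu_High', 'trump_Low'],\n",
"    ['edu_High', 'trump_Low'],\n",
"    ['edu_High', 'trump_High'],\n",
"    ['edu_Low', 'trump_High'],\n",
"]\n",
"\n",
"# Rule edu_High => trump_Low: support 0.5, confidence 2/3, lift 4/3\n",
"print(rule_metrics(toy_transactions, ['edu_High'], ['trump_Low']))"
],
"execution_count": null,
"outputs": []
},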
{
"cell_type": "markdown",
"metadata": {
"id": "XzyK20_2xPkl"
},
"source": [
"## Retrieve and examine association rules for a user-defined consequent"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_rlN8q3YENL6"
},
"source": [
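"from pyspark.sql.functions import array_contains, col, size\n",
"\n",
"# Retrieve association rules from the fitted model\n",
"# -- max_items: total items per rule; Spark's rules have a single-item consequent,\n",
"#    so only antecedents with exactly max_items - 1 items are kept\n",
"# -- consequent: item that the rule's consequent must contain\n",
"# -- only rules with lift > 1 are kept, ordered by descending lift\n",
"def get_itemsets(max_items, consequent):\n",
"    out = (\n",
"        model.associationRules\n",
"        .where(col('lift') > 1)\n",
"        .where(size(col('antecedent')) == max_items - 1)\n",
"        .where(array_contains(col('consequent'), consequent))\n",
"        .orderBy('lift', ascending=False)\n",
"    )\n",
"    return out"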
],
"execution_count": null,
"outputs": []
},
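{
"cell_type": "markdown",
"metadata": {},
"source": [
"The filtering logic in `get_itemsets` can be mirrored on plain Python dictionaries, which makes it easy to check which rules survive. The `filter_rules` helper and the rule values below are invented for illustration. A rule is kept only if its lift exceeds 1, its antecedent has exactly `n_items - 1` items, and its consequent contains the requested item. The surviving rules are then sorted by descending lift."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical plain-Python mirror of get_itemsets\n",
"def filter_rules(rules, n_items, consequent):\n",
"    kept = [\n",
"        r for r in rules\n",
"        if r['lift'] > 1\n",
"        and len(r['antecedent']) == n_items - 1\n",
"        and consequent in r['consequent']\n",
"    ]\n",
"    # Highest lift first, like orderBy('lift', ascending=False)\n",
"    return sorted(kept, key=lambda r: r['lift'], reverse=True)\n",
"\n",
"\n",
"# Made-up rules in the shape of model.associationRules rows\n",
"rules = [\n",
"    {'antecedent': ['a'], 'consequent': ['z'], 'confidence': 0.6, 'lift': 1.5},\n",
"    {'antecedent': ['b'], 'consequent': ['z'], 'confidence': 0.4, 'lift': 0.9},\n",
"    {'antecedent': ['a', 'b'], 'consequent': ['z'], 'confidence': 0.8, 'lift': 2.0},\n",
"    {'antecedent': ['c'], 'consequent': ['z'], 'confidence': 0.7, 'lift': 1.8},\n",
"]\n",
"\n",
"print([r['antecedent'] for r in filter_rules(rules, n_items=2, consequent='z')])\n",
"# [['c'], ['a']]"
],
"execution_count": null,
"outputs": []
},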
{
"cell_type": "markdown",
"metadata": {
"id": "-f0uZRoCx_JY"
},
"source": [
"In this example, the consequents `trump_pct_Extremely_Low` and `trump_pct_Extremely_High` (the lowest and highest deciles of `trump_pct`) are of interest. Retrieving them reveals items that are positively associated (lift > 1) with extremely low and extremely high support for Trump. For simplicity, only rules containing exactly two items (a single antecedent item and the consequent) are retrieved."
]
},
{
"cell_type": "code",
"metadata": {
"id": "7EQ3xfqpxWFf"
},
"source": [
"consequents = [\"trump_pct_Extremely_Low\", \"trump_pct_Extremely_High\"]\n",
"\n",
"itemsets = [get_itemsets(max_items=2, consequent=i) for i in consequents]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "w_6wYyLZxX0G"
},
"source": [
"The rules are then converted to lists of Python dictionaries (one dictionary per rule) to make them more legible."
]
},
{
"cell_type": "code",
"metadata": {
"id": "O-kvDbf1xVfK"
},
"source": [
"def itemsets_to_json(df):\n",
" itemsets_dict = df.toPandas().to_dict(orient='records')\n",
" return(itemsets_dict)\n",
"\n",
"itemsets_json = [itemsets_to_json(i) for i in itemsets]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "kCNo4JNC5edO"
},
"source": [
"## Exploratory analysis"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ttCbABPt09h5",
"outputId": "d907d240-58fb-45ba-aa6d-e76f12b56c4d"
},
"source": [
"print(\"Sample of antecedents correlated with extremely LOW Trump support:\\n\")\n",
"[i[\"antecedent\"] for i in itemsets_json[0]][0:5]"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Sample of antecedents correlated with extremely LOW Trump support:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[['EDU_ATTAIN_Total__Bachelor_s_degree_or_higher_Extremely_High'],\n",
" ['RACE_Total__Asian_alone_Extremely_High'],\n",
" ['EDU_ATTAIN_Total__High_school_graduate__includes_equivalency__Extremely_Low'],\n",
" ['RACE_Total__White_alone_Extremely_Low'],\n",
" ['INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services_Extremely_High']]"
]
},
"metadata": {
"tags": []
},
"execution_count": 78
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B4HmBNNt4nkJ",
"outputId": "de4d0df4-400b-48be-81ee-664245b3dc13"
},
"source": [
"print(\"Sample of antecedents correlated with extremely HIGH Trump support:\\n\")\n",
"[i[\"antecedent\"] for i in itemsets_json[1]][0:5]"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Sample of antecedents correlated with extremely HIGH Trump support:\n",
"\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[['INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting_Extremely_High'],\n",
" ['INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction_Extremely_High'],\n",
" ['RACE_Total__Asian_alone_Extremely_Low'],\n",
" ['HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only_Extremely_Low'],\n",
" ['UNEMPLOY_Total_16_YEARS_AND_OVER_Extremely_Low']]"
]
},
"metadata": {
"tags": []
},
"execution_count": 79
}
]
}
]
}