Created May 20, 2021 21:24
sparkml-arules.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "sparkml-arules.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyN56cncOgVhCAuH2dCCmQki", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/dannymorris/65bfd1e920b5c3673d7358cdf9d9753f/sparkml-arules.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Lhme3LFAYEj_" | |
}, | |
"source": [ | |
"# Applying Association Rules to continuous features using PySpark and Spark MLlib\n", | |
"\n", | |
"## Overview\n", | |
"\n", | |
"This notebook shows how to apply Association Rules to continuous features using PySpark and Spark MLlib's [FP-Growth](https://spark.apache.org/docs/latest/ml-frequent-pattern-mining.html#fp-growth) algorithm. \n", | |
"\n", | |
"Association rule mining expects \"itemsets\" as input (e.g., [milk, eggs]), so continuous features stored in columnar format cannot be supplied directly. Transforming the continuous features into categorical itemsets is therefore a necessary pre-processing step. In this notebook, each continuous feature value is mapped to a discrete decile (via [QuantileDiscretizer](https://spark.apache.org/docs/latest/ml-features.html#quantilediscretizer)), and the discrete features are then assembled into an itemset for each entity (row) in the original dataset.\n",
"\n", | |
"**Visual explanation of the transformation from continuous features to itemsets**\n", | |
"\n", | |
"![Transformation from continuous features to itemsets](https://i.imgur.com/wWl7RIvm.png)\n",
"\n",
"## Workflow\n",
"\n", | |
"1. Load the CSV file into Spark as a DataFrame\n",
"2. Transform continuous features into binned, categorical features (deciles in this example)\n",
"3. Assemble categorical features into itemsets\n",
"4. Fit the FP-Growth algorithm to generate association rules\n",
"5. Investigate the rules"
] | |
}, | |
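{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Toy example of the transformation in plain Python**\n",
"\n",
"The next cell is a minimal sketch of the continuous-to-itemset idea on made-up numbers, without Spark. Each value is replaced by its decile index (0 to 9), and each row becomes an itemset of `feature=decile` strings. The feature names, the values, the `feature=decile` item format, and the helpers `decile_cuts` and `to_decile` are illustrative assumptions. They are not taken from `county_data.csv` or from the Spark pipeline below. The sketch uses `statistics.quantiles` (Python 3.8+), which computes exact cut points. Spark's `QuantileDiscretizer` uses an approximate algorithm, so its bucket boundaries can differ slightly."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import bisect\n",
"import statistics\n",
"\n",
"# 12 hypothetical rows with two made-up continuous features\n",
"toy_rows = [{\"income\": 16868 + 1500 * i, \"gini\": 0.40 + 0.01 * i} for i in range(12)]\n",
"toy_features = [\"income\", \"gini\"]\n",
"\n",
"def decile_cuts(values):\n",
"    # 9 cut points that split the values into 10 equal-frequency buckets\n",
"    return statistics.quantiles(values, n=10)\n",
"\n",
"def to_decile(value, cuts):\n",
"    # bucket index 0..9 = number of cut points <= value\n",
"    return bisect.bisect_right(cuts, value)\n",
"\n",
"toy_cuts = {f: decile_cuts([r[f] for r in toy_rows]) for f in toy_features}\n",
"\n",
"# one itemset per row, e.g. ['income=0', 'gini=0'] for the first row\n",
"toy_itemsets = [[f\"{f}={to_decile(r[f], toy_cuts[f])}\" for f in toy_features] for r in toy_rows]\n",
"\n",
"for items in toy_itemsets[:3]:\n",
"    print(items)"
],
"execution_count": null,
"outputs": []
},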
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Uaiz0fDEL1x-" | |
}, | |
"source": [ | |
"# Install PySpark" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "-iOTzB1gd7LE", | |
"outputId": "a9d4f628-86b7-4ceb-c366-5e356246f563" | |
}, | |
"source": [ | |
"!pip install pyspark" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting pyspark\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/b0/9d6860891ab14a39d4bddf80ba26ce51c2f9dc4805e5c6978ac0472c120a/pyspark-3.1.1.tar.gz (212.3MB)\n", | |
"\u001b[K |████████████████████████████████| 212.3MB 67kB/s \n", | |
"\u001b[?25hCollecting py4j==0.10.9\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)\n", | |
"\u001b[K |████████████████████████████████| 204kB 19.4MB/s \n", | |
"\u001b[?25hBuilding wheels for collected packages: pyspark\n", | |
" Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pyspark: filename=pyspark-3.1.1-py2.py3-none-any.whl size=212767604 sha256=1ccc2374712de1891cf4e16cf2f3b2379518309bdd3e00217fabc9506f6b4dc9\n", | |
" Stored in directory: /root/.cache/pip/wheels/0b/90/c0/01de724414ef122bd05f056541fb6a0ecf47c7ca655f8b3c0f\n", | |
"Successfully built pyspark\n", | |
"Installing collected packages: py4j, pyspark\n", | |
"Successfully installed py4j-0.10.9 pyspark-3.1.1\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "oeBvkC9UMyok" | |
}, | |
"source": [ | |
"# Load modules and functions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mLUKdc60jkBz" | |
}, | |
"source": [ | |
"from pyspark.sql import SparkSession\n", | |
"from pyspark import SparkFiles\n", | |
"from pyspark.ml import Pipeline\n", | |
"from pyspark.ml.fpm import FPGrowth\n", | |
"from pyspark.ml.feature import QuantileDiscretizer\n", | |
"\n", | |
"from pyspark.sql.functions import row_number, array, col, explode, lit, concat, \\\n", | |
" collect_list, struct, array_contains, create_map, size, regexp_replace\n", | |
" \n", | |
"from pyspark.sql.window import Window\n", | |
"\n", | |
"import pandas as pd\n", | |
"import json\n", | |
"from itertools import chain" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "GAX4QcY6iuNh" | |
}, | |
"source": [ | |
"# general-purpose function to obtain a list of the column names in a\n",
"# DataFrame that begin with `pattern` (a string or a tuple of strings)\n",
"def list_cols_starting_with(df, pattern):\n",
"    out = list(filter(lambda x: x.startswith(pattern), df.columns))\n",
"    return out\n",
"\n", | |
"# general purpose function to display Spark DF as a pandas DF\n", | |
"# for improved printing\n", | |
"def display_pandas(df, limit=5):\n", | |
" return df.limit(limit).toPandas()" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "FMS5JD-kPQsI" | |
}, | |
"source": [ | |
"# Create SparkSession\n", | |
"\n", | |
"`SparkSession` is the single entry point for interacting with Spark's underlying functionality and for programming Spark with the DataFrame API. (The typed Dataset API is available only in Scala and Java, so PySpark works with DataFrames.)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9Gz1Z6ici7KB" | |
}, | |
"source": [ | |
"spark = SparkSession \\\n",
"    .builder \\\n",
"    .appName(\"MLlib Association Rules\") \\\n",
"    .getOrCreate()"
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "lg_Hnxw8Pgu1" | |
}, | |
"source": [ | |
"# Read data\n", | |
"\n", | |
"In this example, the data is a CSV file hosted online. `SparkContext.addFile` downloads the file so it is available on every node. `SparkFiles.get` returns the local path of the downloaded copy, and `spark.read.csv` reads that copy into a DataFrame."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 222 | |
}, | |
"id": "M3uGR6qaieqF", | |
"outputId": "eb3237ec-a339-4b0b-84cc-e5aa064c588a" | |
}, | |
"source": [ | |
"url = \"https://gist.githubusercontent.com/dannymorris/1bd95ddda1cfe7fd518e5cda01f4ac03/raw/295636f511f0afbd544aece7cbfe771edc01182c/county_data.csv\"\n", | |
"\n", | |
"# download the file so it is available to every node of the Spark job\n",
"spark.sparkContext.addFile(url)\n",
"\n",
"# read the local copy, using the first row as the header and inferring column types\n",
"df = spark.read.csv(\"file://\"+SparkFiles.get(\"county_data.csv\"), header=True, inferSchema=True)\n",
"\n", | |
"display_pandas(df)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>fips</th>\n", | |
" <th>state</th>\n", | |
" <th>name</th>\n", | |
" <th>party_winner</th>\n", | |
" <th>trump_pct</th>\n", | |
" <th>margin</th>\n", | |
" <th>POPULATION_Total</th>\n", | |
" <th>AGE_18_29</th>\n", | |
" <th>AGE_30_44</th>\n", | |
" <th>AGE_45_59</th>\n", | |
" <th>AGE_60_Plus</th>\n", | |
" <th>AGE_18_Plus</th>\n", | |
" <th>RACE_Total__Asian_alone</th>\n", | |
" <th>RACE_Total__Black_or_African_American_alone</th>\n", | |
" <th>RACE_Total__Hispanic_or_Latino</th>\n", | |
" <th>RACE_Total__White_alone</th>\n", | |
" <th>GINI_Gini_Index</th>\n", | |
" <th>INCOME_PER_CAPITA_INCOME_IN_THE_PAST_12_MONTHS__IN_2018_INFLATION_ADJUSTED_DOLLARS_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__BLACK_OR_AFRICAN_AMERICAN_ALONE_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__HISPANIC_OR_LATINO_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__WHITE_ALONE_</th>\n", | |
" <th>UNEMPLOY_Total_16_YEARS_AND_OVER</th>\n", | |
" <th>EDU_ATTAIN_Total__Bachelor_s_degree_or_higher</th>\n", | |
" <th>EDU_ATTAIN_Total__High_school_graduate__includes_equivalency_</th>\n", | |
" <th>EDU_ATTAIN_Total__Less_than_high_school_diploma</th>\n", | |
" <th>EDU_ATTAIN_Total__Some_college_or_associate_s_degree</th>\n", | |
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting</th>\n", | |
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction</th>\n", | |
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Accommodation_and_food_services</th>\n", | |
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Arts__entertainment__and_recreation</th>\n", | |
" <th>INDUSTRY_Total__Construction</th>\n", | |
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Educational_services</th>\n", | |
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Health_care_and_social_assistance</th>\n", | |
" <th>INDUSTRY_Total__Information</th>\n", | |
" <th>INDUSTRY_Total__Manufacturing</th>\n", | |
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Administrative_and_support_and_waste_management_services</th>\n", | |
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Management_of_companies_and_enterprises</th>\n", | |
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services</th>\n", | |
" <th>INDUSTRY_Total__Public_administration</th>\n", | |
" <th>INDUSTRY_Total__Retail_trade</th>\n", | |
" <th>CITIZEN_Estimate__Total__Not_a_U_S__citizen</th>\n", | |
" <th>CITIZEN_Estimate__Total__U_S__citizen_by_naturalization</th>\n", | |
" <th>CITIZEN_Estimate__Total__U_S__citizen__born_in_the_United_States</th>\n", | |
" <th>HEALTH_INSURANCE_No_health_insurance_coverage</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_direct_purchase_health_insurance_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_employer_based_health_insurance_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicare_coverage_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_TRICARE_military_health_coverage_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_VA_Health_Care_only</th>\n", | |
" <th>VETERANS_Estimate__Total__Veteran</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>48001</td>\n", | |
" <td>texas</td>\n", | |
" <td>Anderson</td>\n", | |
" <td>republican</td>\n", | |
" <td>0.786119</td>\n", | |
" <td>0.580355</td>\n", | |
" <td>57863</td>\n", | |
" <td>0.192527</td>\n", | |
" <td>0.308352</td>\n", | |
" <td>0.246227</td>\n", | |
" <td>0.252894</td>\n", | |
" <td>46648</td>\n", | |
" <td>0.005513</td>\n", | |
" <td>0.209823</td>\n", | |
" <td>0.175276</td>\n", | |
" <td>0.735340</td>\n", | |
" <td>0.4225</td>\n", | |
" <td>16868</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.004287</td>\n", | |
" <td>0.002272</td>\n", | |
" <td>0.007760</td>\n", | |
" <td>0.014320</td>\n", | |
" <td>0.105299</td>\n", | |
" <td>0.359608</td>\n", | |
" <td>0.232636</td>\n", | |
" <td>0.308481</td>\n", | |
" <td>0.007846</td>\n", | |
" <td>0.028490</td>\n", | |
" <td>0.022359</td>\n", | |
" <td>0.002165</td>\n", | |
" <td>0.022380</td>\n", | |
" <td>0.029905</td>\n", | |
" <td>0.057473</td>\n", | |
" <td>0.002037</td>\n", | |
" <td>0.025424</td>\n", | |
" <td>0.021373</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.010912</td>\n", | |
" <td>0.038608</td>\n", | |
" <td>0.077860</td>\n", | |
" <td>0.041546</td>\n", | |
" <td>0.020946</td>\n", | |
" <td>0.928400</td>\n", | |
" <td>0.114324</td>\n", | |
" <td>0.028040</td>\n", | |
" <td>0.306080</td>\n", | |
" <td>0.028040</td>\n", | |
" <td>0.065833</td>\n", | |
" <td>0.005638</td>\n", | |
" <td>0.006088</td>\n", | |
" <td>0.086006</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>48003</td>\n", | |
" <td>texas</td>\n", | |
" <td>Andrews</td>\n", | |
" <td>republican</td>\n", | |
" <td>0.843084</td>\n", | |
" <td>0.698107</td>\n", | |
" <td>17818</td>\n", | |
" <td>0.245792</td>\n", | |
" <td>0.287747</td>\n", | |
" <td>0.246605</td>\n", | |
" <td>0.219855</td>\n", | |
" <td>12299</td>\n", | |
" <td>0.003536</td>\n", | |
" <td>0.019811</td>\n", | |
" <td>0.560052</td>\n", | |
" <td>0.924009</td>\n", | |
" <td>0.4506</td>\n", | |
" <td>31190</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.012196</td>\n", | |
" <td>0.015611</td>\n", | |
" <td>0.018863</td>\n", | |
" <td>0.046670</td>\n", | |
" <td>0.102529</td>\n", | |
" <td>0.461338</td>\n", | |
" <td>0.368810</td>\n", | |
" <td>0.319457</td>\n", | |
" <td>0.008537</td>\n", | |
" <td>0.167900</td>\n", | |
" <td>0.035125</td>\n", | |
" <td>0.003659</td>\n", | |
" <td>0.051955</td>\n", | |
" <td>0.036751</td>\n", | |
" <td>0.065696</td>\n", | |
" <td>0.009513</td>\n", | |
" <td>0.036100</td>\n", | |
" <td>0.022766</td>\n", | |
" <td>0.003252</td>\n", | |
" <td>0.014798</td>\n", | |
" <td>0.013172</td>\n", | |
" <td>0.077323</td>\n", | |
" <td>0.098047</td>\n", | |
" <td>0.047368</td>\n", | |
" <td>0.850208</td>\n", | |
" <td>0.175055</td>\n", | |
" <td>0.043256</td>\n", | |
" <td>0.511342</td>\n", | |
" <td>0.030084</td>\n", | |
" <td>0.055940</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.003334</td>\n", | |
" <td>0.055696</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>48005</td>\n", | |
" <td>texas</td>\n", | |
" <td>Angelina</td>\n", | |
" <td>republican</td>\n", | |
" <td>0.723981</td>\n", | |
" <td>0.460148</td>\n", | |
" <td>87607</td>\n", | |
" <td>0.208537</td>\n", | |
" <td>0.245155</td>\n", | |
" <td>0.258604</td>\n", | |
" <td>0.287704</td>\n", | |
" <td>64914</td>\n", | |
" <td>0.011495</td>\n", | |
" <td>0.147956</td>\n", | |
" <td>0.218864</td>\n", | |
" <td>0.791855</td>\n", | |
" <td>0.4495</td>\n", | |
" <td>22322</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.012416</td>\n", | |
" <td>0.009582</td>\n", | |
" <td>0.032027</td>\n", | |
" <td>0.054025</td>\n", | |
" <td>0.156638</td>\n", | |
" <td>0.314755</td>\n", | |
" <td>0.215177</td>\n", | |
" <td>0.306251</td>\n", | |
" <td>0.011092</td>\n", | |
" <td>0.010768</td>\n", | |
" <td>0.044012</td>\n", | |
" <td>0.005099</td>\n", | |
" <td>0.037681</td>\n", | |
" <td>0.049589</td>\n", | |
" <td>0.092831</td>\n", | |
" <td>0.002927</td>\n", | |
" <td>0.064624</td>\n", | |
" <td>0.024525</td>\n", | |
" <td>0.000308</td>\n", | |
" <td>0.018578</td>\n", | |
" <td>0.023323</td>\n", | |
" <td>0.073836</td>\n", | |
" <td>0.058420</td>\n", | |
" <td>0.026197</td>\n", | |
" <td>0.909745</td>\n", | |
" <td>0.204409</td>\n", | |
" <td>0.058092</td>\n", | |
" <td>0.351111</td>\n", | |
" <td>0.057784</td>\n", | |
" <td>0.073559</td>\n", | |
" <td>0.001833</td>\n", | |
" <td>0.006085</td>\n", | |
" <td>0.091336</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>48007</td>\n", | |
" <td>texas</td>\n", | |
" <td>Aransas</td>\n", | |
" <td>republican</td>\n", | |
" <td>0.751811</td>\n", | |
" <td>0.514525</td>\n", | |
" <td>24763</td>\n", | |
" <td>0.147476</td>\n", | |
" <td>0.171872</td>\n", | |
" <td>0.252445</td>\n", | |
" <td>0.428208</td>\n", | |
" <td>20044</td>\n", | |
" <td>0.019707</td>\n", | |
" <td>0.015386</td>\n", | |
" <td>0.272826</td>\n", | |
" <td>0.892622</td>\n", | |
" <td>0.5351</td>\n", | |
" <td>30939</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.006186</td>\n", | |
" <td>0.031581</td>\n", | |
" <td>0.037767</td>\n", | |
" <td>0.211684</td>\n", | |
" <td>0.299242</td>\n", | |
" <td>0.185242</td>\n", | |
" <td>0.339453</td>\n", | |
" <td>0.005887</td>\n", | |
" <td>0.018010</td>\n", | |
" <td>0.057025</td>\n", | |
" <td>0.023249</td>\n", | |
" <td>0.062862</td>\n", | |
" <td>0.043205</td>\n", | |
" <td>0.052784</td>\n", | |
" <td>0.002644</td>\n", | |
" <td>0.025244</td>\n", | |
" <td>0.019407</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.022051</td>\n", | |
" <td>0.030682</td>\n", | |
" <td>0.054131</td>\n", | |
" <td>0.039979</td>\n", | |
" <td>0.031579</td>\n", | |
" <td>0.914671</td>\n", | |
" <td>0.206945</td>\n", | |
" <td>0.079824</td>\n", | |
" <td>0.252744</td>\n", | |
" <td>0.026542</td>\n", | |
" <td>0.108162</td>\n", | |
" <td>0.007334</td>\n", | |
" <td>0.004690</td>\n", | |
" <td>0.128916</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>48009</td>\n", | |
" <td>texas</td>\n", | |
" <td>Archer</td>\n", | |
" <td>republican</td>\n", | |
" <td>0.896580</td>\n", | |
" <td>0.803586</td>\n", | |
" <td>8789</td>\n", | |
" <td>0.163153</td>\n", | |
" <td>0.208085</td>\n", | |
" <td>0.281809</td>\n", | |
" <td>0.346954</td>\n", | |
" <td>6877</td>\n", | |
" <td>0.005348</td>\n", | |
" <td>0.009216</td>\n", | |
" <td>0.082717</td>\n", | |
" <td>0.948231</td>\n", | |
" <td>0.4316</td>\n", | |
" <td>31806</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.002472</td>\n", | |
" <td>0.020503</td>\n", | |
" <td>0.022975</td>\n", | |
" <td>0.210121</td>\n", | |
" <td>0.305511</td>\n", | |
" <td>0.096990</td>\n", | |
" <td>0.310310</td>\n", | |
" <td>0.034026</td>\n", | |
" <td>0.028210</td>\n", | |
" <td>0.014977</td>\n", | |
" <td>0.008579</td>\n", | |
" <td>0.046968</td>\n", | |
" <td>0.042751</td>\n", | |
" <td>0.123164</td>\n", | |
" <td>0.006107</td>\n", | |
" <td>0.049149</td>\n", | |
" <td>0.021085</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.027774</td>\n", | |
" <td>0.033881</td>\n", | |
" <td>0.080413</td>\n", | |
" <td>0.019229</td>\n", | |
" <td>0.006713</td>\n", | |
" <td>0.971100</td>\n", | |
" <td>0.127963</td>\n", | |
" <td>0.077359</td>\n", | |
" <td>0.422132</td>\n", | |
" <td>0.025447</td>\n", | |
" <td>0.073724</td>\n", | |
" <td>0.011778</td>\n", | |
" <td>0.000145</td>\n", | |
" <td>0.090883</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" fips ... VETERANS_Estimate__Total__Veteran\n", | |
"0 48001 ... 0.086006\n", | |
"1 48003 ... 0.055696\n", | |
"2 48005 ... 0.091336\n", | |
"3 48007 ... 0.128916\n", | |
"4 48009 ... 0.090883\n", | |
"\n", | |
"[5 rows x 52 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "1bcZqbGVQJYa" | |
}, | |
"source": [ | |
"# Select features for Association Rules mining\n", | |
"\n", | |
"Rather than listing every feature by hand, select the columns whose names start with one of a set of prefixes (tags)."
] | |
}, | |
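{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Why a tuple of prefixes works**\n",
"\n",
"Python's `str.startswith` accepts either a single prefix or a tuple of prefixes. That is what lets `list_cols_starting_with` take `feature_tags` directly. The next cell checks this in plain Python on a hypothetical stand-in object that, like a Spark DataFrame, has a `.columns` attribute. The column names are a small illustrative subset. `cols_starting_with_demo` is a hypothetical helper that mirrors the helper defined earlier without overwriting it."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from types import SimpleNamespace\n",
"\n",
"def cols_starting_with_demo(df, pattern):\n",
"    # same filtering logic as list_cols_starting_with\n",
"    return [c for c in df.columns if c.startswith(pattern)]\n",
"\n",
"# hypothetical stand-in exposing only `.columns`, like a Spark DataFrame\n",
"demo_df = SimpleNamespace(columns=[\"fips\", \"state\", \"trump_pct\", \"margin\", \"AGE_18_29\", \"AGE_18_Plus\"])\n",
"\n",
"demo_selected = [c for c in cols_starting_with_demo(demo_df, (\"trump\", \"AGE\")) if c != \"AGE_18_Plus\"]\n",
"print(demo_selected)  # ['trump_pct', 'AGE_18_29']"
],
"execution_count": null,
"outputs": []
},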
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 222 | |
}, | |
"id": "UDMlKUGrkLjU", | |
"outputId": "8e181157-414d-47a9-9e81-3a9dd45aa67e" | |
}, | |
"source": [ | |
"# locate columns whose names begin with one of these prefixes (tags)\n", | |
"feature_tags = ('trump', 'AGE', 'RACE', 'GINI', 'INCOME', 'UNEMPLOY', 'EDU', 'INDUSTRY', 'CITIZEN', 'HEALTH', 'VETERANS')\n", | |
"feature_list = list_cols_starting_with(df=df, pattern=feature_tags)\n", | |
"\n", | |
"# drop the unused column AGE_18_Plus\n",
"feature_list = [i for i in feature_list if i != \"AGE_18_Plus\"]\n", | |
"\n", | |
"# select features\n", | |
"features_df = df.select(feature_list)\n", | |
"\n", | |
"display_pandas(features_df)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>trump_pct</th>\n", | |
" <th>AGE_18_29</th>\n", | |
" <th>AGE_30_44</th>\n", | |
" <th>AGE_45_59</th>\n", | |
" <th>AGE_60_Plus</th>\n", | |
" <th>RACE_Total__Asian_alone</th>\n", | |
" <th>RACE_Total__Black_or_African_American_alone</th>\n", | |
" <th>RACE_Total__Hispanic_or_Latino</th>\n", | |
" <th>RACE_Total__White_alone</th>\n", | |
" <th>GINI_Gini_Index</th>\n", | |
" <th>INCOME_PER_CAPITA_INCOME_IN_THE_PAST_12_MONTHS__IN_2018_INFLATION_ADJUSTED_DOLLARS_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__BLACK_OR_AFRICAN_AMERICAN_ALONE_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__HISPANIC_OR_LATINO_</th>\n", | |
" <th>UNEMPLOY__16_YEARS_AND_OVER__WHITE_ALONE_</th>\n", | |
" <th>UNEMPLOY_Total_16_YEARS_AND_OVER</th>\n", | |
" <th>EDU_ATTAIN_Total__Bachelor_s_degree_or_higher</th>\n", | |
" <th>EDU_ATTAIN_Total__High_school_graduate__includes_equivalency_</th>\n", | |
" <th>EDU_ATTAIN_Total__Less_than_high_school_diploma</th>\n", | |
" <th>EDU_ATTAIN_Total__Some_college_or_associate_s_degree</th>\n", | |
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting</th>\n", | |
" <th>INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction</th>\n", | |
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Accommodation_and_food_services</th>\n", | |
" <th>INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Arts__entertainment__and_recreation</th>\n", | |
" <th>INDUSTRY_Total__Construction</th>\n", | |
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Educational_services</th>\n", | |
" <th>INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Health_care_and_social_assistance</th>\n", | |
" <th>INDUSTRY_Total__Information</th>\n", | |
" <th>INDUSTRY_Total__Manufacturing</th>\n", | |
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Administrative_and_support_and_waste_management_services</th>\n", | |
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Management_of_companies_and_enterprises</th>\n", | |
" <th>INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services</th>\n", | |
" <th>INDUSTRY_Total__Public_administration</th>\n", | |
" <th>INDUSTRY_Total__Retail_trade</th>\n", | |
" <th>CITIZEN_Estimate__Total__Not_a_U_S__citizen</th>\n", | |
" <th>CITIZEN_Estimate__Total__U_S__citizen_by_naturalization</th>\n", | |
" <th>CITIZEN_Estimate__Total__U_S__citizen__born_in_the_United_States</th>\n", | |
" <th>HEALTH_INSURANCE_No_health_insurance_coverage</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_direct_purchase_health_insurance_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_employer_based_health_insurance_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicare_coverage_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_TRICARE_military_health_coverage_only</th>\n", | |
" <th>HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_VA_Health_Care_only</th>\n", | |
" <th>VETERANS_Estimate__Total__Veteran</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.786119</td>\n", | |
" <td>0.192527</td>\n", | |
" <td>0.308352</td>\n", | |
" <td>0.246227</td>\n", | |
" <td>0.252894</td>\n", | |
" <td>0.005513</td>\n", | |
" <td>0.209823</td>\n", | |
" <td>0.175276</td>\n", | |
" <td>0.735340</td>\n", | |
" <td>0.4225</td>\n", | |
" <td>16868</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.004287</td>\n", | |
" <td>0.002272</td>\n", | |
" <td>0.007760</td>\n", | |
" <td>0.014320</td>\n", | |
" <td>0.105299</td>\n", | |
" <td>0.359608</td>\n", | |
" <td>0.232636</td>\n", | |
" <td>0.308481</td>\n", | |
" <td>0.007846</td>\n", | |
" <td>0.028490</td>\n", | |
" <td>0.022359</td>\n", | |
" <td>0.002165</td>\n", | |
" <td>0.022380</td>\n", | |
" <td>0.029905</td>\n", | |
" <td>0.057473</td>\n", | |
" <td>0.002037</td>\n", | |
" <td>0.025424</td>\n", | |
" <td>0.021373</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.010912</td>\n", | |
" <td>0.038608</td>\n", | |
" <td>0.077860</td>\n", | |
" <td>0.041546</td>\n", | |
" <td>0.020946</td>\n", | |
" <td>0.928400</td>\n", | |
" <td>0.114324</td>\n", | |
" <td>0.028040</td>\n", | |
" <td>0.306080</td>\n", | |
" <td>0.028040</td>\n", | |
" <td>0.065833</td>\n", | |
" <td>0.005638</td>\n", | |
" <td>0.006088</td>\n", | |
" <td>0.086006</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0.843084</td>\n", | |
" <td>0.245792</td>\n", | |
" <td>0.287747</td>\n", | |
" <td>0.246605</td>\n", | |
" <td>0.219855</td>\n", | |
" <td>0.003536</td>\n", | |
" <td>0.019811</td>\n", | |
" <td>0.560052</td>\n", | |
" <td>0.924009</td>\n", | |
" <td>0.4506</td>\n", | |
" <td>31190</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.012196</td>\n", | |
" <td>0.015611</td>\n", | |
" <td>0.018863</td>\n", | |
" <td>0.046670</td>\n", | |
" <td>0.102529</td>\n", | |
" <td>0.461338</td>\n", | |
" <td>0.368810</td>\n", | |
" <td>0.319457</td>\n", | |
" <td>0.008537</td>\n", | |
" <td>0.167900</td>\n", | |
" <td>0.035125</td>\n", | |
" <td>0.003659</td>\n", | |
" <td>0.051955</td>\n", | |
" <td>0.036751</td>\n", | |
" <td>0.065696</td>\n", | |
" <td>0.009513</td>\n", | |
" <td>0.036100</td>\n", | |
" <td>0.022766</td>\n", | |
" <td>0.003252</td>\n", | |
" <td>0.014798</td>\n", | |
" <td>0.013172</td>\n", | |
" <td>0.077323</td>\n", | |
" <td>0.098047</td>\n", | |
" <td>0.047368</td>\n", | |
" <td>0.850208</td>\n", | |
" <td>0.175055</td>\n", | |
" <td>0.043256</td>\n", | |
" <td>0.511342</td>\n", | |
" <td>0.030084</td>\n", | |
" <td>0.055940</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.003334</td>\n", | |
" <td>0.055696</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0.723981</td>\n", | |
" <td>0.208537</td>\n", | |
" <td>0.245155</td>\n", | |
" <td>0.258604</td>\n", | |
" <td>0.287704</td>\n", | |
" <td>0.011495</td>\n", | |
" <td>0.147956</td>\n", | |
" <td>0.218864</td>\n", | |
" <td>0.791855</td>\n", | |
" <td>0.4495</td>\n", | |
" <td>22322</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.012416</td>\n", | |
" <td>0.009582</td>\n", | |
" <td>0.032027</td>\n", | |
" <td>0.054025</td>\n", | |
" <td>0.156638</td>\n", | |
" <td>0.314755</td>\n", | |
" <td>0.215177</td>\n", | |
" <td>0.306251</td>\n", | |
" <td>0.011092</td>\n", | |
" <td>0.010768</td>\n", | |
" <td>0.044012</td>\n", | |
" <td>0.005099</td>\n", | |
" <td>0.037681</td>\n", | |
" <td>0.049589</td>\n", | |
" <td>0.092831</td>\n", | |
" <td>0.002927</td>\n", | |
" <td>0.064624</td>\n", | |
" <td>0.024525</td>\n", | |
" <td>0.000308</td>\n", | |
" <td>0.018578</td>\n", | |
" <td>0.023323</td>\n", | |
" <td>0.073836</td>\n", | |
" <td>0.058420</td>\n", | |
" <td>0.026197</td>\n", | |
" <td>0.909745</td>\n", | |
" <td>0.204409</td>\n", | |
" <td>0.058092</td>\n", | |
" <td>0.351111</td>\n", | |
" <td>0.057784</td>\n", | |
" <td>0.073559</td>\n", | |
" <td>0.001833</td>\n", | |
" <td>0.006085</td>\n", | |
" <td>0.091336</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0.751811</td>\n", | |
" <td>0.147476</td>\n", | |
" <td>0.171872</td>\n", | |
" <td>0.252445</td>\n", | |
" <td>0.428208</td>\n", | |
" <td>0.019707</td>\n", | |
" <td>0.015386</td>\n", | |
" <td>0.272826</td>\n", | |
" <td>0.892622</td>\n", | |
" <td>0.5351</td>\n", | |
" <td>30939</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.006186</td>\n", | |
" <td>0.031581</td>\n", | |
" <td>0.037767</td>\n", | |
" <td>0.211684</td>\n", | |
" <td>0.299242</td>\n", | |
" <td>0.185242</td>\n", | |
" <td>0.339453</td>\n", | |
" <td>0.005887</td>\n", | |
" <td>0.018010</td>\n", | |
" <td>0.057025</td>\n", | |
" <td>0.023249</td>\n", | |
" <td>0.062862</td>\n", | |
" <td>0.043205</td>\n", | |
" <td>0.052784</td>\n", | |
" <td>0.002644</td>\n", | |
" <td>0.025244</td>\n", | |
" <td>0.019407</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.022051</td>\n", | |
" <td>0.030682</td>\n", | |
" <td>0.054131</td>\n", | |
" <td>0.039979</td>\n", | |
" <td>0.031579</td>\n", | |
" <td>0.914671</td>\n", | |
" <td>0.206945</td>\n", | |
" <td>0.079824</td>\n", | |
" <td>0.252744</td>\n", | |
" <td>0.026542</td>\n", | |
" <td>0.108162</td>\n", | |
" <td>0.007334</td>\n", | |
" <td>0.004690</td>\n", | |
" <td>0.128916</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0.896580</td>\n", | |
" <td>0.163153</td>\n", | |
" <td>0.208085</td>\n", | |
" <td>0.281809</td>\n", | |
" <td>0.346954</td>\n", | |
" <td>0.005348</td>\n", | |
" <td>0.009216</td>\n", | |
" <td>0.082717</td>\n", | |
" <td>0.948231</td>\n", | |
" <td>0.4316</td>\n", | |
" <td>31806</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.002472</td>\n", | |
" <td>0.020503</td>\n", | |
" <td>0.022975</td>\n", | |
" <td>0.210121</td>\n", | |
" <td>0.305511</td>\n", | |
" <td>0.096990</td>\n", | |
" <td>0.310310</td>\n", | |
" <td>0.034026</td>\n", | |
" <td>0.028210</td>\n", | |
" <td>0.014977</td>\n", | |
" <td>0.008579</td>\n", | |
" <td>0.046968</td>\n", | |
" <td>0.042751</td>\n", | |
" <td>0.123164</td>\n", | |
" <td>0.006107</td>\n", | |
" <td>0.049149</td>\n", | |
" <td>0.021085</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.027774</td>\n", | |
" <td>0.033881</td>\n", | |
" <td>0.080413</td>\n", | |
" <td>0.019229</td>\n", | |
" <td>0.006713</td>\n", | |
" <td>0.971100</td>\n", | |
" <td>0.127963</td>\n", | |
" <td>0.077359</td>\n", | |
" <td>0.422132</td>\n", | |
" <td>0.025447</td>\n", | |
" <td>0.073724</td>\n", | |
" <td>0.011778</td>\n", | |
" <td>0.000145</td>\n", | |
" <td>0.090883</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" trump_pct ... VETERANS_Estimate__Total__Veteran\n", | |
"0 0.786119 ... 0.086006\n", | |
"1 0.843084 ... 0.055696\n", | |
"2 0.723981 ... 0.091336\n", | |
"3 0.751811 ... 0.128916\n", | |
"4 0.896580 ... 0.090883\n", | |
"\n", | |
"[5 rows x 45 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rD1yc70LQcsf" | |
}, | |
"source": [ | |
"# Transform continuous features to categorical\n", | |
"\n", | |
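        "[QuantileDiscretizer](https://spark.apache.org/docs/latest/ml-features.html#quantilediscretizer) maps continuous feature values to discrete (categorical) buckets. With `numBuckets=10`, each feature is split into ten buckets of roughly equal frequency (deciles), indexed from `0.0` (lowest) to `9.0` (highest). The split points are computed from approximate quantiles, so bucket sizes are close to, but not exactly, equal.\n" |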
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-jRW-GzIQfSI", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 222 | |
}, | |
"outputId": "68a5ba3f-fe62-4e95-913f-e56b576dfb75" | |
}, | |
"source": [ | |
        "# fit one QuantileDiscretizer per feature (10 buckets = deciles) and apply them all in a single Pipeline\n", |
"discretizer = [QuantileDiscretizer(inputCol=x, outputCol=\"Quantile_\"+x, numBuckets=10) for x in feature_list]\n", | |
"discretizer_results = Pipeline(stages=discretizer).fit(features_df).transform(features_df)\n", | |
"\n", | |
"# select transformed columns\n", | |
"discrete_cols = list_cols_starting_with(discretizer_results, \"Quantile_\")\n", | |
"discrete_df = discretizer_results.select(discrete_cols)\n", | |
"\n", | |
"display_pandas(discrete_df)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Quantile_trump_pct</th>\n", | |
" <th>Quantile_AGE_18_29</th>\n", | |
" <th>Quantile_AGE_30_44</th>\n", | |
" <th>Quantile_AGE_45_59</th>\n", | |
" <th>Quantile_AGE_60_Plus</th>\n", | |
" <th>Quantile_RACE_Total__Asian_alone</th>\n", | |
" <th>Quantile_RACE_Total__Black_or_African_American_alone</th>\n", | |
" <th>Quantile_RACE_Total__Hispanic_or_Latino</th>\n", | |
" <th>Quantile_RACE_Total__White_alone</th>\n", | |
" <th>Quantile_GINI_Gini_Index</th>\n", | |
" <th>Quantile_INCOME_PER_CAPITA_INCOME_IN_THE_PAST_12_MONTHS__IN_2018_INFLATION_ADJUSTED_DOLLARS_</th>\n", | |
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__ASIAN_ALONE_</th>\n", | |
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__BLACK_OR_AFRICAN_AMERICAN_ALONE_</th>\n", | |
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__HISPANIC_OR_LATINO_</th>\n", | |
" <th>Quantile_UNEMPLOY__16_YEARS_AND_OVER__WHITE_ALONE_</th>\n", | |
" <th>Quantile_UNEMPLOY_Total_16_YEARS_AND_OVER</th>\n", | |
" <th>Quantile_EDU_ATTAIN_Total__Bachelor_s_degree_or_higher</th>\n", | |
" <th>Quantile_EDU_ATTAIN_Total__High_school_graduate__includes_equivalency_</th>\n", | |
" <th>Quantile_EDU_ATTAIN_Total__Less_than_high_school_diploma</th>\n", | |
" <th>Quantile_EDU_ATTAIN_Total__Some_college_or_associate_s_degree</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Accommodation_and_food_services</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Arts__entertainment__and_recreation__and_accommodation_and_food_services__Arts__entertainment__and_recreation</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Construction</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Educational_services</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Educational_services__and_health_care_and_social_assistance__Health_care_and_social_assistance</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Information</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Manufacturing</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Administrative_and_support_and_waste_management_services</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Management_of_companies_and_enterprises</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Public_administration</th>\n", | |
" <th>Quantile_INDUSTRY_Total__Retail_trade</th>\n", | |
" <th>Quantile_CITIZEN_Estimate__Total__Not_a_U_S__citizen</th>\n", | |
" <th>Quantile_CITIZEN_Estimate__Total__U_S__citizen_by_naturalization</th>\n", | |
" <th>Quantile_CITIZEN_Estimate__Total__U_S__citizen__born_in_the_United_States</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_No_health_insurance_coverage</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_direct_purchase_health_insurance_only</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_employer_based_health_insurance_only</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicare_coverage_only</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_TRICARE_military_health_coverage_only</th>\n", | |
" <th>Quantile_HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_VA_Health_Care_only</th>\n", | |
" <th>Quantile_VETERANS_Estimate__Total__Veteran</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>7.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>9.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>0.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>6.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>6.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>6.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>9.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>9.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>2.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>4.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>9.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>6.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>5.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>5.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Quantile_trump_pct ... Quantile_VETERANS_Estimate__Total__Veteran\n", | |
"0 7.0 ... 4.0\n", | |
"1 9.0 ... 0.0\n", | |
"2 6.0 ... 6.0\n", | |
"3 6.0 ... 9.0\n", | |
"4 9.0 ... 5.0\n", | |
"\n", | |
"[5 rows x 45 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 29 | |
} | |
] | |
}, | |
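    { |
      "cell_type": "markdown", |
      "metadata": {}, |
      "source": [ |
        "As a rough illustration of what `QuantileDiscretizer` does, the sketch below splits a plain Python list into equal-frequency buckets using exact nearest-rank quantiles. It is a simplified stand-in, not Spark's implementation. Spark estimates the split points with an approximate quantile algorithm (controlled by the `relativeError` parameter), so its boundaries can differ slightly. The helper `quantile_buckets` is hypothetical." |
      ] |
    }, |
    { |
      "cell_type": "code", |
      "metadata": {}, |
      "source": [ |
        "from bisect import bisect_right\n", |
        "\n", |
        "def quantile_buckets(values, num_buckets=10):\n", |
        "    # hypothetical helper: exact equal-frequency bucketing (Spark uses approximate quantiles)\n", |
        "    ordered = sorted(values)\n", |
        "    n = len(ordered)\n", |
        "    # interior split points at the k/num_buckets quantiles (nearest rank)\n", |
        "    splits = [ordered[(k * n) // num_buckets] for k in range(1, num_buckets)]\n", |
        "    # bucket index = number of split points <= value, as a float like QuantileDiscretizer's output\n", |
        "    return [float(bisect_right(splits, v)) for v in values]\n", |
        "\n", |
        "# 20 evenly spaced values: two values land in each bucket, 0.0 through 9.0\n", |
        "print(quantile_buckets(list(range(20))))" |
      ], |
      "execution_count": null, |
      "outputs": [] |
    }, |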
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "EhDjLSVwWCdS" | |
}, | |
"source": [ | |
"# Assemble discrete features into itemsets\n", | |
"\n", | |
"1. Reshape features from wide to long format.\n", | |
        "2. Create individual items by concatenating each feature name with a readable label for its decile.\n", |
"3. Assemble the itemsets for each row by collecting the individual items into an array." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "lLgKqzB-nLtc" | |
}, | |
"source": [ | |
"## Reshape features from wide to long format\n", | |
"\n", | |
        "The function `to_explode` reshapes (unpivots) a DataFrame from wide to long format: every column not listed in `by` becomes a (`feature`, `value`) pair on its own row. It behaves similarly to `tidyr::gather()` in R and `pandas.melt()` in Python." |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3b_D0e_DIwMR" | |
}, | |
"source": [ | |
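        "def to_explode(df, by):\n", |
        "    # columns to unpivot: everything not listed in `by`\n", |
        "    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))\n", |
        "    # array() needs a single element type, so all unpivoted columns must share one dtype\n", |
        "    assert len(set(dtypes)) == 1, 'all unpivoted columns must share one type'\n", |
        "\n", |
        "    # one (feature, value) struct per column, exploded into one row per struct\n", |
        "    kvs = explode(array([\n", |
        "        struct(lit(c).alias(\"feature\"), col(c).alias(\"value\")) for c in cols\n", |
        "    ])).alias(\"kvs\")\n", |
        "    return df.select(by + [kvs]).select(by + [\"kvs.feature\", \"kvs.value\"])" |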
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "I5XZ28IaLljL" | |
}, | |
"source": [ | |
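        "To retain the row-wise relationships after the data is reshaped, a column containing the row number is appended using the custom `append_row_number` function. The window is ordered by a constant literal, so rows are simply numbered in the order Spark reads them. Because the window has no partitioning, Spark moves all rows into a single partition to compute it. This is acceptable for a dataset of this size but would not scale to very large tables." |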
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uBezhWy_k9Mr" | |
}, | |
"source": [ | |
        "def append_row_number(df):\n", |
        "    # unpartitioned window ordered by a constant: numbers rows 1..N in their current order\n", |
        "    w = Window().orderBy(lit('A'))\n", |
        "    out = df.withColumn(\"row_number\", row_number().over(w))\n", |
        "    return(out)" |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ba4qbQq6oFkT" | |
}, | |
"source": [ | |
"Putting it all together with a custom `wide_to_long` function." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 202 | |
}, | |
"id": "ziTAUwrhmPfo", | |
"outputId": "4d898be8-5810-466a-f74b-2686054d937c" | |
}, | |
"source": [ | |
"# cast from wide to long\n", | |
"def wide_to_long(df):\n", | |
" df_numbered = append_row_number(df)\n", | |
" df_long = to_explode(df_numbered, by=['row_number'])\n", | |
" return(df_long)\n", | |
"\n", | |
"long_df = wide_to_long(df = discrete_df)\n", | |
"\n", | |
"display_pandas(long_df)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>row_number</th>\n", | |
" <th>feature</th>\n", | |
" <th>value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_trump_pct</td>\n", | |
" <td>7.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_18_29</td>\n", | |
" <td>6.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_30_44</td>\n", | |
" <td>9.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_45_59</td>\n", | |
" <td>2.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_60_Plus</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" row_number feature value\n", | |
"0 1 Quantile_trump_pct 7.0\n", | |
"1 1 Quantile_AGE_18_29 6.0\n", | |
"2 1 Quantile_AGE_30_44 9.0\n", | |
"3 1 Quantile_AGE_45_59 2.0\n", | |
"4 1 Quantile_AGE_60_Plus 1.0" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 30 | |
} | |
] | |
}, | |
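    { |
      "cell_type": "markdown", |
      "metadata": {}, |
      "source": [ |
        "For readers more familiar with pandas, the sketch below performs the same wide-to-long reshape on a tiny pandas DataFrame with `DataFrame.melt`. The values are the first two discretized features of the first two rows in the table above, used purely as example data. The names `wide_pd` and `long_pd` are illustrative." |
      ] |
    }, |
    { |
      "cell_type": "code", |
      "metadata": {}, |
      "source": [ |
        "import pandas as pd\n", |
        "\n", |
        "# example data: two rows, two discretized features (values from the first two rows above)\n", |
        "wide_pd = pd.DataFrame({\n", |
        "    'Quantile_trump_pct': [7.0, 9.0],\n", |
        "    'Quantile_AGE_18_29': [6.0, 9.0],\n", |
        "})\n", |
        "# number rows from 1, mirroring append_row_number\n", |
        "wide_pd.insert(0, 'row_number', list(range(1, len(wide_pd) + 1)))\n", |
        "\n", |
        "# unpivot every non-id column into (feature, value) pairs, like to_explode\n", |
        "long_pd = wide_pd.melt(id_vars=['row_number'], var_name='feature', value_name='value')\n", |
        "print(long_pd.sort_values(['row_number', 'feature']).reset_index(drop=True))" |
      ], |
      "execution_count": null, |
      "outputs": [] |
    }, |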
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "cksmYdfDJnGc" | |
}, | |
"source": [ | |
"## Convert decile integers to intuitive labels\n", | |
"\n", | |
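        "To improve the interpretability of the final solution, discrete decile values are transformed to intuitive labels (e.g. `0.0` becomes `Extremely_Low` and `9.0` becomes `Extremely_High`). Deciles 0, 1, 8 and 9 each get their own label, while the middle deciles are paired (2-3, 4-5, 6-7), so the ten deciles map to seven labels." |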
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 202 | |
}, | |
"id": "RB6Ag9Uj3_n2", | |
"outputId": "c61a0def-91bf-43cf-bd22-de10c2263450" | |
}, | |
"source": [ | |
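        "from itertools import chain\n", |
        "from pyspark.sql.functions import col, create_map, lit\n", |
        "\n", |
        "# decile index -> label; the middle deciles are paired, so 10 deciles map to 7 labels\n", |
        "mapping = {\n", |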
" 0.0: 'Extremely_Low',\n", | |
" 1.0: 'Very_Low',\n", | |
" 2.0: 'Moderately_Low',\n", | |
" 3.0: 'Moderately_Low',\n", | |
" 4.0: 'Moderate',\n", | |
" 5.0: 'Moderate',\n", | |
" 6.0: 'Moderately_High',\n", | |
" 7.0: 'Moderately_High',\n", | |
" 8.0: 'Very_High',\n", | |
" 9.0: 'Extremely_High'\n", | |
"}\n", | |
"\n", | |
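        "def int_to_labels(column, mapping):\n", |
        "    # build a literal map column {0.0: 'Extremely_Low', 1.0: 'Very_Low', ...}\n", |
        "    mapping_expr = create_map([lit(x) for x in chain(*mapping.items())])\n", |
        "    # look up each row's decile in the map (map_col[key_col]; getItem with a Column key is deprecated in Spark 3)\n", |
        "    labels = mapping_expr[col(column)]\n", |
        "    return(labels)\n", |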
"\n", | |
"long_df_labeled = long_df.withColumn(\"label\", int_to_labels(\"value\", mapping)).drop(\"value\")\n", | |
"\n", | |
        "# helper: show the first `limit` rows of a Spark DataFrame as a pandas DataFrame\n", |
        "def display_pandas(df, limit=5):\n", |
" return df.limit(limit).toPandas()\n", | |
"\n", | |
"display_pandas(long_df_labeled, limit=5)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>row_number</th>\n", | |
" <th>feature</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_trump_pct</td>\n", | |
" <td>Moderately_High</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_18_29</td>\n", | |
" <td>Moderately_High</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_30_44</td>\n", | |
" <td>Extremely_High</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_45_59</td>\n", | |
" <td>Moderately_Low</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>Quantile_AGE_60_Plus</td>\n", | |
" <td>Very_Low</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" row_number feature label\n", | |
"0 1 Quantile_trump_pct Moderately_High\n", | |
"1 1 Quantile_AGE_18_29 Moderately_High\n", | |
"2 1 Quantile_AGE_30_44 Extremely_High\n", | |
"3 1 Quantile_AGE_45_59 Moderately_Low\n", | |
"4 1 Quantile_AGE_60_Plus Very_Low" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 25 | |
} | |
] | |
}, | |
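    { |
      "cell_type": "markdown", |
      "metadata": {}, |
      "source": [ |
        "Because several deciles share a label, it can help to invert the `mapping` dictionary and check which deciles each label covers. This optional plain-Python check repeats the same `mapping` so the cell runs on its own. The name `deciles_by_label` is illustrative." |
      ] |
    }, |
    { |
      "cell_type": "code", |
      "metadata": {}, |
      "source": [ |
        "from collections import defaultdict\n", |
        "\n", |
        "# same decile -> label mapping as above\n", |
        "mapping = {\n", |
        "    0.0: 'Extremely_Low', 1.0: 'Very_Low',\n", |
        "    2.0: 'Moderately_Low', 3.0: 'Moderately_Low',\n", |
        "    4.0: 'Moderate', 5.0: 'Moderate',\n", |
        "    6.0: 'Moderately_High', 7.0: 'Moderately_High',\n", |
        "    8.0: 'Very_High', 9.0: 'Extremely_High',\n", |
        "}\n", |
        "\n", |
        "# invert the mapping: label -> list of deciles it covers\n", |
        "deciles_by_label = defaultdict(list)\n", |
        "for decile, label in mapping.items():\n", |
        "    deciles_by_label[label].append(int(decile))\n", |
        "\n", |
        "for label, deciles in deciles_by_label.items():\n", |
        "    print(f'{label}: deciles {deciles}')" |
      ], |
      "execution_count": null, |
      "outputs": [] |
    }, |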
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Ok7i5dJKC4_R" | |
}, | |
"source": [ | |
"## Assemble itemsets\n", | |
"\n", | |
"1. Create a single \"item\" by concatenating the `feature` and `label` columns. The prefix \"Quantile_\" is also stripped from the feature name in this step.\n", | |
"\n", | |
        "2. Use the `collect_list` function to gather all items for each row number into a single array. Because each feature contributes exactly one item per row, the items within each array are unique, which FP-Growth requires." |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YoGdk4bwC6u3" | |
}, | |
"source": [ | |
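        "# strip the 'Quantile_' prefix, build one 'feature_label' item per row, then collect the items for each row_number\n", |
        "itemset_df = long_df_labeled.withColumn('feature', regexp_replace('feature', '^Quantile_', '')) \\\n", |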
" .withColumn(\"item\", concat(col(\"feature\"), lit(\"_\"), col(\"label\"))) \\\n", | |
" .drop(\"feature\", \"label\") \\\n", | |
" .groupBy(\"row_number\") \\\n", | |
" .agg(collect_list(\"item\").alias(\"items\"))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 202 | |
}, | |
"id": "85-NyGkPs0D8", | |
"outputId": "f30bf988-9149-4e53-a9e6-5624bd20641a" | |
}, | |
"source": [ | |
"from pandas import option_context\n", | |
"\n", | |
"with option_context('display.max_colwidth', 100):\n", | |
" display(display_pandas(itemset_df))" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>row_number</th>\n", | |
" <th>items</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>[trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Extremely_High, AGE_45_59_Moder...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>2</td>\n", | |
" <td>[trump_pct_Extremely_High, AGE_18_29_Extremely_High, AGE_30_44_Extremely_High, AGE_45_59_Moderat...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>3</td>\n", | |
" <td>[trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Moderately_High, AGE_45_59_Mode...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4</td>\n", | |
" <td>[trump_pct_Moderately_High, AGE_18_29_Very_Low, AGE_30_44_Extremely_Low, AGE_45_59_Moderately_Lo...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5</td>\n", | |
" <td>[trump_pct_Extremely_High, AGE_18_29_Moderately_Low, AGE_30_44_Moderately_Low, AGE_45_59_Very_Hi...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" row_number items\n", | |
"0 1 [trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Extremely_High, AGE_45_59_Moder...\n", | |
"1 2 [trump_pct_Extremely_High, AGE_18_29_Extremely_High, AGE_30_44_Extremely_High, AGE_45_59_Moderat...\n", | |
"2 3 [trump_pct_Moderately_High, AGE_18_29_Moderately_High, AGE_30_44_Moderately_High, AGE_45_59_Mode...\n", | |
"3 4 [trump_pct_Moderately_High, AGE_18_29_Very_Low, AGE_30_44_Extremely_Low, AGE_45_59_Moderately_Lo...\n", | |
"4 5 [trump_pct_Extremely_High, AGE_18_29_Moderately_Low, AGE_30_44_Moderately_Low, AGE_45_59_Very_Hi..." | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
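    { |
      "cell_type": "markdown", |
      "metadata": {}, |
      "source": [ |
        "The same item-assembly logic can be traced in plain Python on a few of the (`row_number`, `feature`, `label`) rows shown earlier. This only illustrates the steps: strip the prefix, concatenate, then collect per row. In the notebook itself the work is done by Spark's `regexp_replace`, `concat` and `collect_list`. The names `long_rows` and `toy_itemsets` are illustrative." |
      ] |
    }, |
    { |
      "cell_type": "code", |
      "metadata": {}, |
      "source": [ |
        "import re\n", |
        "\n", |
        "# a few (row_number, feature, label) rows, as in long_df_labeled\n", |
        "long_rows = [\n", |
        "    (1, 'Quantile_trump_pct', 'Moderately_High'),\n", |
        "    (1, 'Quantile_AGE_18_29', 'Moderately_High'),\n", |
        "    (2, 'Quantile_trump_pct', 'Extremely_High'),\n", |
        "    (2, 'Quantile_AGE_18_29', 'Extremely_High'),\n", |
        "]\n", |
        "\n", |
        "toy_itemsets = {}\n", |
        "for row_number, feature, label in long_rows:\n", |
        "    # strip the prefix and build a 'feature_label' item\n", |
        "    item = re.sub('^Quantile_', '', feature) + '_' + label\n", |
        "    # collect the items for each row_number into one list (like collect_list)\n", |
        "    toy_itemsets.setdefault(row_number, []).append(item)\n", |
        "\n", |
        "# Spark's FP-Growth rejects transactions that contain duplicate items\n", |
        "assert all(len(items) == len(set(items)) for items in toy_itemsets.values())\n", |
        "print(toy_itemsets)" |
      ], |
      "execution_count": null, |
      "outputs": [] |
    }, |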
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rPILwUmwVHnc" | |
}, | |
"source": [ | |
"# Association Rules using FP-Growth\n", | |
"\n", | |
"1. Initialize FP-Growth and fit to training data\n", | |
        "2. Retrieve association rules for a chosen consequent\n", |
"3. Explore results\n", | |
"\n", | |
"## Initialize `FPGrowth` with parameters and fit to training data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "edmCBNqrc6qW" | |
}, | |
"source": [ | |
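        "# minSupport: minimum fraction of rows (transactions) in which an itemset must appear\n", |
        "# minConfidence: minimum rule confidence; kept low here because rules are filtered by lift afterwards\n", |
        "fpGrowth = FPGrowth(itemsCol=\"items\", minSupport=0.02, minConfidence=0.01)\n", |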
"\n", | |
"model = fpGrowth.fit(itemset_df)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
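    { |
      "cell_type": "markdown", |
      "metadata": {}, |
      "source": [ |
        "To see the shape of a fitted model's outputs without the full dataset, the sketch below fits `FPGrowth` to three toy transactions, adapted from the example in the Spark MLlib documentation. The `toy_` names are hypothetical and chosen so they do not overwrite the notebook's `model`. `freqItemsets` lists each frequent itemset with its count, and `associationRules` lists rules with their antecedent, consequent, confidence and lift." |
      ] |
    }, |
    { |
      "cell_type": "code", |
      "metadata": {}, |
      "source": [ |
        "from pyspark.ml.fpm import FPGrowth\n", |
        "from pyspark.sql import SparkSession\n", |
        "\n", |
        "spark = SparkSession.builder.getOrCreate()\n", |
        "\n", |
        "# three toy transactions; items within a transaction must be unique\n", |
        "toy_df = spark.createDataFrame([\n", |
        "    (0, ['a', 'b', 'e']),\n", |
        "    (1, ['a', 'b', 'c', 'e']),\n", |
        "    (2, ['a', 'b']),\n", |
        "], ['id', 'items'])\n", |
        "\n", |
        "toy_model = FPGrowth(itemsCol='items', minSupport=0.5, minConfidence=0.6).fit(toy_df)\n", |
        "\n", |
        "# frequent itemsets and their counts\n", |
        "toy_model.freqItemsets.show()\n", |
        "# association rules with confidence and lift\n", |
        "toy_model.associationRules.show()" |
      ], |
      "execution_count": null, |
      "outputs": [] |
    }, |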
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XzyK20_2xPkl" | |
}, | |
"source": [ | |
        "## Retrieve and examine association rules given a user-defined consequent" |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_rlN8q3YENL6" | |
}, | |
"source": [ | |
        "# Retrieve association rules\n", |
        "# -- max_items: items per rule; keeps rules with exactly max_items-1 antecedent items (Spark rules have a single-item consequent)\n", |
        "# -- consequent: the item the rule must predict\n", |
        "# -- only rules with lift > 1 are kept; results are ordered (descending) by lift\n", |
        "def get_itemsets(max_items, consequent):\n", |
        "    out = model.associationRules \\\n", |
        "        .where(col('lift') > 1) \\\n", |
        "        .where(size(col('antecedent')) == max_items-1) \\\n", |
        "        .where(array_contains(col(\"consequent\"), consequent)) \\\n", |
        "        .orderBy('lift', ascending=False)\n", |
        "    return(out)" |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
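    { |
      "cell_type": "markdown", |
      "metadata": {}, |
      "source": [ |
        "Since `get_itemsets` keeps only rules with lift above 1, it may help to recall how the rule metrics are defined. Support is the fraction of transactions containing an itemset, confidence(A → C) = support(A ∪ C) / support(A), and lift(A → C) = confidence(A → C) / support(C). A lift above 1 means the consequent appears more often alongside the antecedent than it does overall. The plain-Python sketch below computes all three for one rule on four hypothetical transactions." |
      ] |
    }, |
    { |
      "cell_type": "code", |
      "metadata": {}, |
      "source": [ |
        "# four hypothetical transactions (sets of items)\n", |
        "transactions = [\n", |
        "    {'a_High', 'b_Low'},\n", |
        "    {'a_High', 'b_Low'},\n", |
        "    {'a_High', 'b_High'},\n", |
        "    {'a_Low', 'b_High'},\n", |
        "]\n", |
        "\n", |
        "def support(itemset):\n", |
        "    # fraction of transactions that contain every item in `itemset`\n", |
        "    return sum(itemset <= t for t in transactions) / len(transactions)\n", |
        "\n", |
        "def confidence(antecedent, consequent):\n", |
        "    return support(antecedent | consequent) / support(antecedent)\n", |
        "\n", |
        "def lift(antecedent, consequent):\n", |
        "    return confidence(antecedent, consequent) / support(consequent)\n", |
        "\n", |
        "A, C = {'b_Low'}, {'a_High'}\n", |
        "# support 0.5, confidence 1.0, lift ~1.33: a_High is more likely when b_Low is present\n", |
        "print(support(A | C), confidence(A, C), lift(A, C))" |
      ], |
      "execution_count": null, |
      "outputs": [] |
    }, |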
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-f0uZRoCx_JY" | |
}, | |
"source": [ | |
        "In this example, the consequents \"trump_pct_Extremely_Low\" and \"trump_pct_Extremely_High\" are of interest. Rules with these consequents reveal items that co-occur with extremely low or extremely high Trump support (`trump_pct`) more often than would be expected by chance, i.e. with lift > 1. For simplicity, only rules containing 2 items (one antecedent item and one consequent item) are retrieved." |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7EQ3xfqpxWFf" | |
}, | |
"source": [ | |
"consequents = [\"trump_pct_Extremely_Low\", \"trump_pct_Extremely_High\"]\n", | |
"\n", | |
"itemsets = [get_itemsets(max_items=2, consequent=i) for i in consequents]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "w_6wYyLZxX0G" | |
}, | |
"source": [ | |
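        "The rules are then converted to lists of Python dictionaries (one dictionary per rule, keyed by column name such as `antecedent`, `consequent`, `confidence` and `lift`) to make them more legible." |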
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O-kvDbf1xVfK" | |
}, | |
"source": [ | |
"def itemsets_to_json(df):\n", | |
" itemsets_dict = df.toPandas().to_dict(orient='records')\n", | |
" return(itemsets_dict)\n", | |
"\n", | |
"itemsets_json = [itemsets_to_json(i) for i in itemsets]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kCNo4JNC5edO" | |
}, | |
"source": [ | |
"## Exploratory analysis" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ttCbABPt09h5", | |
"outputId": "d907d240-58fb-45ba-aa6d-e76f12b56c4d" | |
}, | |
"source": [ | |
"print(\"Sample of antecedents correlated with extremely LOW Trump support:\\n\")\n", | |
"[i[\"antecedent\"] for i in itemsets_json[0]][0:5]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
        "Sample of antecedents correlated with extremely LOW Trump support:\n", |
"\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[['EDU_ATTAIN_Total__Bachelor_s_degree_or_higher_Extremely_High'],\n", | |
" ['RACE_Total__Asian_alone_Extremely_High'],\n", | |
" ['EDU_ATTAIN_Total__High_school_graduate__includes_equivalency__Extremely_Low'],\n", | |
" ['RACE_Total__White_alone_Extremely_Low'],\n", | |
" ['INDUSTRY_Total__Professional__scientific__and_management__and_administrative__and_waste_management_services__Professional__scientific__and_technical_services_Extremely_High']]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 78 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "B4HmBNNt4nkJ", | |
"outputId": "de4d0df4-400b-48be-81ee-664245b3dc13" | |
}, | |
"source": [ | |
"print(\"Sample of antecedents correlated with extremely HIGH Trump support:\\n\")\n", | |
"[i[\"antecedent\"] for i in itemsets_json[1]][0:5]" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Sample of antecedents correlated with extremely HIGH Trump support:\n", | |
"\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[['INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Agriculture__forestry__fishing_and_hunting_Extremely_High'],\n", | |
" ['INDUSTRY_Total__Agriculture__forestry__fishing_and_hunting__and_mining__Mining__quarrying__and_oil_and_gas_extraction_Extremely_High'],\n", | |
" ['RACE_Total__Asian_alone_Extremely_Low'],\n", | |
" ['HEALTH_INSURANCE_With_one_type_of_health_insurance_coverageWith_Medicaid_means_tested_public_coverage_only_Extremely_Low'],\n", | |
" ['UNEMPLOY_Total_16_YEARS_AND_OVER_Extremely_Low']]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 79 | |
} | |
] | |
} | |
] | |
} |