Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shravankumar147/bebd7b387aa6e87dd60ec9e1dafa77df to your computer and use it in GitHub Desktop.
Save shravankumar147/bebd7b387aa6e87dd60ec9e1dafa77df to your computer and use it in GitHub Desktop.
pyspark_on_colab_with Decision Tree Example.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "pyspark_on_colab_with Decision Tree Example.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shravankumar147/bebd7b387aa6e87dd60ec9e1dafa77df/pyspark_on_colab_with-decision-tree-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LoXaoJQa_2Jf"
},
"source": [
"**PySpark on Google Colab**\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kzzdN24a_9t5"
},
"source": [
"Check which Python interpreter is used"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4dLWEvSK7IS_",
"outputId": "bd28bf4c-b39b-476c-8af9-e90fb9667f79",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!which python"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/bin/python\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GvIpI_yTAFEy"
},
"source": [
"Check list of available packages"
]
},
{
"cell_type": "code",
"metadata": {
"id": "liT3Yrvi7Pwe",
"outputId": "1e098b61-cc71-4aea-d976-5518d96bfaed",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!pip freeze"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"absl-py==0.10.0\n",
"alabaster==0.7.12\n",
"albumentations==0.1.12\n",
"altair==4.1.0\n",
"argon2-cffi==20.1.0\n",
"asgiref==3.3.0\n",
"astor==0.8.1\n",
"astropy==4.1\n",
"astunparse==1.6.3\n",
"async-generator==1.10\n",
"atari-py==0.2.6\n",
"atomicwrites==1.4.0\n",
"attrs==20.2.0\n",
"audioread==2.1.9\n",
"autograd==1.3\n",
"Babel==2.8.0\n",
"backcall==0.2.0\n",
"beautifulsoup4==4.6.3\n",
"bleach==3.2.1\n",
"blis==0.4.1\n",
"bokeh==2.1.1\n",
"Bottleneck==1.3.2\n",
"branca==0.4.1\n",
"bs4==0.0.1\n",
"CacheControl==0.12.6\n",
"cachetools==4.1.1\n",
"catalogue==1.0.0\n",
"certifi==2020.6.20\n",
"cffi==1.14.3\n",
"chainer==7.4.0\n",
"chardet==3.0.4\n",
"click==7.1.2\n",
"cloudpickle==1.3.0\n",
"cmake==3.12.0\n",
"cmdstanpy==0.9.5\n",
"colorlover==0.3.0\n",
"community==1.0.0b1\n",
"contextlib2==0.5.5\n",
"convertdate==2.2.2\n",
"coverage==3.7.1\n",
"coveralls==0.5\n",
"crcmod==1.7\n",
"cufflinks==0.17.3\n",
"cupy-cuda101==7.4.0\n",
"cvxopt==1.2.5\n",
"cvxpy==1.0.31\n",
"cycler==0.10.0\n",
"cymem==2.0.4\n",
"Cython==0.29.21\n",
"daft==0.0.4\n",
"dask==2.12.0\n",
"dataclasses==0.7\n",
"datascience==0.10.6\n",
"debugpy==1.0.0\n",
"decorator==4.4.2\n",
"defusedxml==0.6.0\n",
"descartes==1.1.0\n",
"dill==0.3.3\n",
"distributed==1.25.3\n",
"Django==3.1.3\n",
"dlib==19.18.0\n",
"dm-tree==0.1.5\n",
"docopt==0.6.2\n",
"docutils==0.16\n",
"dopamine-rl==1.0.5\n",
"earthengine-api==0.1.238\n",
"easydict==1.9\n",
"ecos==2.0.7.post1\n",
"editdistance==0.5.3\n",
"en-core-web-sm==2.2.5\n",
"entrypoints==0.3\n",
"ephem==3.7.7.1\n",
"et-xmlfile==1.0.1\n",
"fa2==0.3.5\n",
"fancyimpute==0.4.3\n",
"fastai==1.0.61\n",
"fastdtw==0.3.4\n",
"fastprogress==1.0.0\n",
"fastrlock==0.5\n",
"fbprophet==0.7.1\n",
"feather-format==0.4.1\n",
"filelock==3.0.12\n",
"firebase-admin==4.4.0\n",
"fix-yahoo-finance==0.0.22\n",
"Flask==1.1.2\n",
"folium==0.8.3\n",
"future==0.16.0\n",
"gast==0.3.3\n",
"GDAL==2.2.2\n",
"gdown==3.6.4\n",
"gensim==3.6.0\n",
"geographiclib==1.50\n",
"geopy==1.17.0\n",
"gin-config==0.3.0\n",
"glob2==0.7\n",
"google==2.0.3\n",
"google-api-core==1.16.0\n",
"google-api-python-client==1.7.12\n",
"google-auth==1.17.2\n",
"google-auth-httplib2==0.0.4\n",
"google-auth-oauthlib==0.4.2\n",
"google-cloud-bigquery==1.21.0\n",
"google-cloud-bigquery-storage==1.1.0\n",
"google-cloud-core==1.0.3\n",
"google-cloud-datastore==1.8.0\n",
"google-cloud-firestore==1.7.0\n",
"google-cloud-language==1.2.0\n",
"google-cloud-storage==1.18.1\n",
"google-cloud-translate==1.5.0\n",
"google-colab==1.0.0\n",
"google-pasta==0.2.0\n",
"google-resumable-media==0.4.1\n",
"googleapis-common-protos==1.52.0\n",
"googledrivedownloader==0.4\n",
"graphviz==0.10.1\n",
"grpcio==1.33.2\n",
"gspread==3.0.1\n",
"gspread-dataframe==3.0.8\n",
"gym==0.17.3\n",
"h5py==2.10.0\n",
"HeapDict==1.0.1\n",
"holidays==0.10.3\n",
"holoviews==1.13.5\n",
"html5lib==1.0.1\n",
"httpimport==0.5.18\n",
"httplib2==0.17.4\n",
"httplib2shim==0.0.3\n",
"humanize==0.5.1\n",
"hyperopt==0.1.2\n",
"ideep4py==2.0.0.post3\n",
"idna==2.10\n",
"image==1.5.33\n",
"imageio==2.4.1\n",
"imagesize==1.2.0\n",
"imbalanced-learn==0.4.3\n",
"imblearn==0.0\n",
"imgaug==0.2.9\n",
"importlib-metadata==2.0.0\n",
"importlib-resources==3.3.0\n",
"imutils==0.5.3\n",
"inflect==2.1.0\n",
"iniconfig==1.1.1\n",
"intel-openmp==2020.0.133\n",
"intervaltree==2.1.0\n",
"ipykernel==4.10.1\n",
"ipython==5.5.0\n",
"ipython-genutils==0.2.0\n",
"ipython-sql==0.3.9\n",
"ipywidgets==7.5.1\n",
"itsdangerous==1.1.0\n",
"jax==0.2.4\n",
"jaxlib==0.1.56+cuda101\n",
"jdcal==1.4.1\n",
"jedi==0.17.2\n",
"jieba==0.42.1\n",
"Jinja2==2.11.2\n",
"joblib==0.17.0\n",
"jpeg4py==0.1.4\n",
"jsonschema==2.6.0\n",
"jupyter==1.0.0\n",
"jupyter-client==5.3.5\n",
"jupyter-console==5.2.0\n",
"jupyter-core==4.6.3\n",
"jupyterlab-pygments==0.1.2\n",
"kaggle==1.5.9\n",
"kapre==0.1.3.1\n",
"Keras==2.4.3\n",
"Keras-Preprocessing==1.1.2\n",
"keras-vis==0.4.1\n",
"kiwisolver==1.3.1\n",
"knnimpute==0.1.0\n",
"korean-lunar-calendar==0.2.1\n",
"librosa==0.6.3\n",
"lightgbm==2.2.3\n",
"llvmlite==0.31.0\n",
"lmdb==0.99\n",
"lucid==0.3.8\n",
"LunarCalendar==0.0.9\n",
"lxml==4.2.6\n",
"Markdown==3.3.3\n",
"MarkupSafe==1.1.1\n",
"matplotlib==3.2.2\n",
"matplotlib-venn==0.11.6\n",
"missingno==0.4.2\n",
"mistune==0.8.4\n",
"mizani==0.6.0\n",
"mkl==2019.0\n",
"mlxtend==0.14.0\n",
"more-itertools==8.6.0\n",
"moviepy==0.2.3.5\n",
"mpmath==1.1.0\n",
"msgpack==1.0.0\n",
"multiprocess==0.70.10\n",
"multitasking==0.0.9\n",
"murmurhash==1.0.3\n",
"music21==5.5.0\n",
"natsort==5.5.0\n",
"nbclient==0.5.1\n",
"nbconvert==5.6.1\n",
"nbformat==5.0.8\n",
"nest-asyncio==1.4.2\n",
"networkx==2.5\n",
"nibabel==3.0.2\n",
"nltk==3.2.5\n",
"notebook==5.3.1\n",
"np-utils==0.5.12.1\n",
"numba==0.48.0\n",
"numexpr==2.7.1\n",
"numpy==1.18.5\n",
"nvidia-ml-py3==7.352.0\n",
"oauth2client==4.1.3\n",
"oauthlib==3.1.0\n",
"okgrade==0.4.3\n",
"opencv-contrib-python==4.1.2.30\n",
"opencv-python==4.1.2.30\n",
"openpyxl==2.5.9\n",
"opt-einsum==3.3.0\n",
"osqp==0.6.1\n",
"packaging==20.4\n",
"palettable==3.3.0\n",
"pandas==1.1.4\n",
"pandas-datareader==0.9.0\n",
"pandas-gbq==0.13.3\n",
"pandas-profiling==1.4.1\n",
"pandocfilters==1.4.3\n",
"panel==0.9.7\n",
"param==1.10.0\n",
"parso==0.7.1\n",
"pathlib==1.0.1\n",
"patsy==0.5.1\n",
"pexpect==4.8.0\n",
"pickleshare==0.7.5\n",
"Pillow==7.0.0\n",
"pip-tools==4.5.1\n",
"plac==1.1.3\n",
"plotly==4.4.1\n",
"plotnine==0.6.0\n",
"pluggy==0.7.1\n",
"portpicker==1.3.1\n",
"prefetch-generator==1.0.1\n",
"preshed==3.0.2\n",
"prettytable==1.0.1\n",
"progressbar2==3.38.0\n",
"prometheus-client==0.8.0\n",
"promise==2.3\n",
"prompt-toolkit==1.0.18\n",
"protobuf==3.12.4\n",
"psutil==5.4.8\n",
"psycopg2==2.7.6.1\n",
"ptyprocess==0.6.0\n",
"py==1.9.0\n",
"pyarrow==0.14.1\n",
"pyasn1==0.4.8\n",
"pyasn1-modules==0.2.8\n",
"pycocotools==2.0.2\n",
"pycparser==2.20\n",
"pyct==0.4.8\n",
"pydata-google-auth==1.1.0\n",
"pydot==1.3.0\n",
"pydot-ng==2.0.0\n",
"pydotplus==2.0.2\n",
"PyDrive==1.3.1\n",
"pyemd==0.5.1\n",
"pyglet==1.5.0\n",
"Pygments==2.6.1\n",
"pygobject==3.26.1\n",
"pymc3==3.7\n",
"PyMeeus==0.3.7\n",
"pymongo==3.11.0\n",
"pymystem3==0.2.0\n",
"PyOpenGL==3.1.5\n",
"pyparsing==2.4.7\n",
"pyrsistent==0.17.3\n",
"pysndfile==1.3.8\n",
"PySocks==1.7.1\n",
"pystan==2.19.1.1\n",
"pytest==3.6.4\n",
"python-apt==1.6.5+ubuntu0.3\n",
"python-chess==0.23.11\n",
"python-dateutil==2.8.1\n",
"python-louvain==0.14\n",
"python-slugify==4.0.1\n",
"python-utils==2.4.0\n",
"pytz==2018.9\n",
"pyviz-comms==0.7.6\n",
"PyWavelets==1.1.1\n",
"PyYAML==3.13\n",
"pyzmq==19.0.2\n",
"qtconsole==4.7.7\n",
"QtPy==1.9.0\n",
"regex==2019.12.20\n",
"requests==2.23.0\n",
"requests-oauthlib==1.3.0\n",
"resampy==0.2.2\n",
"retrying==1.3.3\n",
"rpy2==3.2.7\n",
"rsa==4.6\n",
"scikit-image==0.16.2\n",
"scikit-learn==0.22.2.post1\n",
"scipy==1.4.1\n",
"screen-resolution-extra==0.0.0\n",
"scs==2.1.2\n",
"seaborn==0.11.0\n",
"Send2Trash==1.5.0\n",
"setuptools-git==1.2\n",
"Shapely==1.7.1\n",
"simplegeneric==0.8.1\n",
"six==1.15.0\n",
"sklearn==0.0\n",
"sklearn-pandas==1.8.0\n",
"slugify==0.0.1\n",
"smart-open==3.0.0\n",
"snowballstemmer==2.0.0\n",
"sortedcontainers==2.2.2\n",
"spacy==2.2.4\n",
"Sphinx==1.8.5\n",
"sphinxcontrib-serializinghtml==1.1.4\n",
"sphinxcontrib-websupport==1.2.4\n",
"SQLAlchemy==1.3.20\n",
"sqlparse==0.4.1\n",
"srsly==1.0.2\n",
"statsmodels==0.10.2\n",
"sympy==1.1.1\n",
"tables==3.4.4\n",
"tabulate==0.8.7\n",
"tblib==1.7.0\n",
"tensorboard==2.3.0\n",
"tensorboard-plugin-wit==1.7.0\n",
"tensorboardcolab==0.0.22\n",
"tensorflow==2.3.0\n",
"tensorflow-addons==0.8.3\n",
"tensorflow-datasets==4.0.1\n",
"tensorflow-estimator==2.3.0\n",
"tensorflow-gcs-config==2.3.0\n",
"tensorflow-hub==0.10.0\n",
"tensorflow-metadata==0.24.0\n",
"tensorflow-privacy==0.2.2\n",
"tensorflow-probability==0.11.0\n",
"termcolor==1.1.0\n",
"terminado==0.9.1\n",
"testpath==0.4.4\n",
"text-unidecode==1.3\n",
"textblob==0.15.3\n",
"textgenrnn==1.4.1\n",
"Theano==1.0.5\n",
"thinc==7.4.0\n",
"tifffile==2020.9.3\n",
"toml==0.10.2\n",
"toolz==0.11.1\n",
"torch==1.7.0+cu101\n",
"torchsummary==1.5.1\n",
"torchtext==0.3.1\n",
"torchvision==0.8.1+cu101\n",
"tornado==5.1.1\n",
"tqdm==4.41.1\n",
"traitlets==4.3.3\n",
"tweepy==3.6.0\n",
"typeguard==2.7.1\n",
"typing-extensions==3.7.4.3\n",
"tzlocal==1.5.1\n",
"umap-learn==0.4.6\n",
"uritemplate==3.0.1\n",
"urllib3==1.24.3\n",
"vega-datasets==0.8.0\n",
"wasabi==0.8.0\n",
"wcwidth==0.2.5\n",
"webencodings==0.5.1\n",
"Werkzeug==1.0.1\n",
"widgetsnbextension==3.5.1\n",
"wordcloud==1.5.0\n",
"wrapt==1.12.1\n",
"xarray==0.15.1\n",
"xgboost==0.90\n",
"xkit==0.0.0\n",
"xlrd==1.1.0\n",
"xlwt==1.3.0\n",
"yellowbrick==0.9.1\n",
"zict==2.0.0\n",
"zipp==3.4.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rix8GB6MASIt"
},
"source": [
"Install PySpark and findspark"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y3NwtwtX-OJy",
"outputId": "207e49f7-3486-4c10-dee1-0c3b0f45c3a3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!pip install pyspark"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting pyspark\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/f0/26/198fc8c0b98580f617cb03cb298c6056587b8f0447e20fa40c5b634ced77/pyspark-3.0.1.tar.gz (204.2MB)\n",
"\u001b[K |████████████████████████████████| 204.2MB 65kB/s \n",
"\u001b[?25hCollecting py4j==0.10.9\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)\n",
"\u001b[K |████████████████████████████████| 204kB 48.8MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: pyspark\n",
" Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pyspark: filename=pyspark-3.0.1-py2.py3-none-any.whl size=204612243 sha256=90e988a1b00429fba4fd379da2c69403ada6fd70e1e5e4ff06c2a4151d326e2d\n",
" Stored in directory: /root/.cache/pip/wheels/5e/bd/07/031766ca628adec8435bb40f0bd83bb676ce65ff4007f8e73f\n",
"Successfully built pyspark\n",
"Installing collected packages: py4j, pyspark\n",
"Successfully installed py4j-0.10.9 pyspark-3.0.1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "oHooHpDl_p79",
"outputId": "18ea5873-5da6-45de-a47c-b5ef7a805c1f",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!pip install findspark"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting findspark\n",
" Downloading https://files.pythonhosted.org/packages/fc/2d/2e39f9a023479ea798eed4351cd66f163ce61e00c717e03c37109f00c0f2/findspark-1.4.2-py2.py3-none-any.whl\n",
"Installing collected packages: findspark\n",
"Successfully installed findspark-1.4.2\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VRXYK081_9Ur",
"outputId": "d96c157f-0cd6-4c32-e532-677e5fb1e5a0",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!ls /content/"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"sample_data\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y7B36uTgBwyp",
"outputId": "8656de04-02ce-4485-92a0-d45ae89ac898",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!pip show pyspark"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Name: pyspark\n",
"Version: 3.0.1\n",
"Summary: Apache Spark Python API\n",
"Home-page: https://github.com/apache/spark/tree/master/python\n",
"Author: Spark Developers\n",
"Author-email: dev@spark.apache.org\n",
"License: http://www.apache.org/licenses/LICENSE-2.0\n",
"Location: /usr/local/lib/python3.6/dist-packages\n",
"Requires: py4j\n",
"Required-by: \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SoBBj8KaCQbu",
"outputId": "f2df4cc5-c150-4143-b716-70797cd61a56",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!ls /usr/local/lib/python3.6/dist-packages/pyspark/python"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"lib pyspark\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Havaj-cPDQe8",
"outputId": "7f1fc12d-b660-4075-8cc7-d7f6becb30d7",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!ls /usr/lib/jvm/"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"default-java java-1.11.0-openjdk-amd64 java-11-openjdk-amd64\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gpCZE0jv_Tx6",
"outputId": "87994e2d-2dfa-469d-acc3-1bbe3bd1d1b6",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"import os\n",
"jdk_list = os.listdir(\"/usr/lib/jvm/\")\n",
"jdk_list"
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['default-java',\n",
" '.java-1.11.0-openjdk-amd64.jinfo',\n",
" 'java-1.11.0-openjdk-amd64',\n",
" 'java-11-openjdk-amd64']"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SL6bqs1tAeMz"
},
"source": [
"Set up the `JAVA_HOME` and `SPARK_HOME` environment variables"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Am5u2mo5AKQ9"
},
"source": [
"import os\n",
"import pyspark\n",
"\n",
"# os.listdir() returns entries in arbitrary order and /usr/lib/jvm also holds\n",
"# non-JDK files (e.g. '.java-1.11.0-openjdk-amd64.jinfo'), so jdk_list[-1] may\n",
"# not be a JDK at all. Pick the first entry (sorted, so the choice is stable)\n",
"# that actually contains a java binary.\n",
"JVM_DIR = \"/usr/lib/jvm\"\n",
"jdk_homes = [\n",
"    os.path.join(JVM_DIR, name)\n",
"    for name in sorted(os.listdir(JVM_DIR))\n",
"    if os.path.isfile(os.path.join(JVM_DIR, name, \"bin\", \"java\"))\n",
"]\n",
"if not jdk_homes:\n",
"    raise RuntimeError(f\"No JDK with bin/java found under {JVM_DIR}\")\n",
"os.environ[\"JAVA_HOME\"] = jdk_homes[0]\n",
"\n",
"# Derive SPARK_HOME from the installed pyspark package instead of hard-coding\n",
"# a Python-version-specific site-packages path.\n",
"os.environ[\"SPARK_HOME\"] = os.path.dirname(pyspark.__file__)"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "w8TrXIEDAgtN"
},
"source": [
"Initialise findspark so that `pyspark` can be imported"
]
},
{
"cell_type": "code",
"metadata": {
"id": "G0PzfTvXBkGU"
},
"source": [
"import os\n",
"import findspark\n",
"\n",
"# Reuse the SPARK_HOME environment variable set in the path-variables cell\n",
"# rather than repeating a hard-coded, Python-version-specific path.\n",
"findspark.init(os.environ[\"SPARK_HOME\"])"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "iElSPkwKAmV_"
},
"source": [
"Create a Spark session"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8HttnN8-CWtV"
},
"source": [
"from pyspark.sql import SparkSession\n",
"spark = SparkSession.builder.master(\"local[*]\").getOrCreate()"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJG62l5gAoNE"
},
"source": [
"read data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ixB6XiJrDYI_",
"outputId": "bdd9e555-c6f8-40c0-db0f-4407314e2d01",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"#print(os.listdir('./sample_data'))\n",
"file_loc = \"./sample_data/california_housing_train.csv\"\n",
"df_spark = spark.read.csv(file_loc, inferSchema=True, header=True)\n",
"print(type(df_spark))"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": [
"<class 'pyspark.sql.dataframe.DataFrame'>\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "SYojmKy6FZdR",
"outputId": "83435b7c-9b38-4a8e-80af-d2cc45a55637",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"df_spark.printSchema() # print detail schema of data\n",
"df_spark.show()# show top 20 rows"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": [
"root\n",
" |-- longitude: double (nullable = true)\n",
" |-- latitude: double (nullable = true)\n",
" |-- housing_median_age: double (nullable = true)\n",
" |-- total_rooms: double (nullable = true)\n",
" |-- total_bedrooms: double (nullable = true)\n",
" |-- population: double (nullable = true)\n",
" |-- households: double (nullable = true)\n",
" |-- median_income: double (nullable = true)\n",
" |-- median_house_value: double (nullable = true)\n",
"\n",
"+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+\n",
"|longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|\n",
"+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+\n",
"| -114.31| 34.19| 15.0| 5612.0| 1283.0| 1015.0| 472.0| 1.4936| 66900.0|\n",
"| -114.47| 34.4| 19.0| 7650.0| 1901.0| 1129.0| 463.0| 1.82| 80100.0|\n",
"| -114.56| 33.69| 17.0| 720.0| 174.0| 333.0| 117.0| 1.6509| 85700.0|\n",
"| -114.57| 33.64| 14.0| 1501.0| 337.0| 515.0| 226.0| 3.1917| 73400.0|\n",
"| -114.57| 33.57| 20.0| 1454.0| 326.0| 624.0| 262.0| 1.925| 65500.0|\n",
"| -114.58| 33.63| 29.0| 1387.0| 236.0| 671.0| 239.0| 3.3438| 74000.0|\n",
"| -114.58| 33.61| 25.0| 2907.0| 680.0| 1841.0| 633.0| 2.6768| 82400.0|\n",
"| -114.59| 34.83| 41.0| 812.0| 168.0| 375.0| 158.0| 1.7083| 48500.0|\n",
"| -114.59| 33.61| 34.0| 4789.0| 1175.0| 3134.0| 1056.0| 2.1782| 58400.0|\n",
"| -114.6| 34.83| 46.0| 1497.0| 309.0| 787.0| 271.0| 2.1908| 48100.0|\n",
"| -114.6| 33.62| 16.0| 3741.0| 801.0| 2434.0| 824.0| 2.6797| 86500.0|\n",
"| -114.6| 33.6| 21.0| 1988.0| 483.0| 1182.0| 437.0| 1.625| 62000.0|\n",
"| -114.61| 34.84| 48.0| 1291.0| 248.0| 580.0| 211.0| 2.1571| 48600.0|\n",
"| -114.61| 34.83| 31.0| 2478.0| 464.0| 1346.0| 479.0| 3.212| 70400.0|\n",
"| -114.63| 32.76| 15.0| 1448.0| 378.0| 949.0| 300.0| 0.8585| 45000.0|\n",
"| -114.65| 34.89| 17.0| 2556.0| 587.0| 1005.0| 401.0| 1.6991| 69100.0|\n",
"| -114.65| 33.6| 28.0| 1678.0| 322.0| 666.0| 256.0| 2.9653| 94900.0|\n",
"| -114.65| 32.79| 21.0| 44.0| 33.0| 64.0| 27.0| 0.8571| 25000.0|\n",
"| -114.66| 32.74| 17.0| 1388.0| 386.0| 775.0| 320.0| 1.2049| 44000.0|\n",
"| -114.67| 33.92| 17.0| 97.0| 24.0| 29.0| 15.0| 1.2656| 27500.0|\n",
"+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+\n",
"only showing top 20 rows\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kfVdSIUTAvrk"
},
"source": [
"Download the Boston Housing dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "25sjp3HmFi4W",
"outputId": "f0c26ece-b031-4c2b-c4dc-31876f3bc146",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# -f: fail on HTTP errors instead of saving the error page as the CSV;\n",
"# -sS: hide the progress meter but still report errors; -L: follow redirects.\n",
"!curl -fsSL https://raw.githubusercontent.com/h2oai/h2o-2/master/smalldata/BostonHousing.csv -o BostonHousing.csv"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 34695 100 34695 0 0 149k 0 --:--:-- --:--:-- --:--:-- 149k\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sc9j6IxaA06m"
},
"source": [
"Load the Boston Housing data and set up the ML pipeline for decision tree classification"
]
},
{
"cell_type": "code",
"metadata": {
"id": "balmJ50WGLvC"
},
"source": [
"from pyspark.ml.feature import VectorAssembler\n",
"from pyspark.ml.regression import LinearRegression\n",
"dataset = spark.read.csv('BostonHousing.csv',inferSchema=True, header =True)"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "eyv78pxzGdjY",
"outputId": "48325505-0dac-4450-cc0c-d40e87ac673c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"dataset.printSchema()"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"text": [
"root\n",
" |-- crim: double (nullable = true)\n",
" |-- zn: double (nullable = true)\n",
" |-- indus: double (nullable = true)\n",
" |-- chas: integer (nullable = true)\n",
" |-- nox: double (nullable = true)\n",
" |-- rm: double (nullable = true)\n",
" |-- age: double (nullable = true)\n",
" |-- dis: double (nullable = true)\n",
" |-- rad: integer (nullable = true)\n",
" |-- tax: integer (nullable = true)\n",
" |-- ptratio: double (nullable = true)\n",
" |-- b: double (nullable = true)\n",
" |-- lstat: double (nullable = true)\n",
" |-- medv: double (nullable = true)\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "tj-neS2fGg7G",
"outputId": "9abab1e6-b0b7-4850-8832-110f20f7958d",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"dataset.show()"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+\n",
"| crim| zn|indus|chas| nox| rm| age| dis|rad|tax|ptratio| b|lstat|medv|\n",
"+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+\n",
"|0.00632|18.0| 2.31| 0|0.538|6.575| 65.2| 4.09| 1|296| 15.3| 396.9| 4.98|24.0|\n",
"|0.02731| 0.0| 7.07| 0|0.469|6.421| 78.9|4.9671| 2|242| 17.8| 396.9| 9.14|21.6|\n",
"|0.02729| 0.0| 7.07| 0|0.469|7.185| 61.1|4.9671| 2|242| 17.8|392.83| 4.03|34.7|\n",
"|0.03237| 0.0| 2.18| 0|0.458|6.998| 45.8|6.0622| 3|222| 18.7|394.63| 2.94|33.4|\n",
"|0.06905| 0.0| 2.18| 0|0.458|7.147| 54.2|6.0622| 3|222| 18.7| 396.9| 5.33|36.2|\n",
"|0.02985| 0.0| 2.18| 0|0.458| 6.43| 58.7|6.0622| 3|222| 18.7|394.12| 5.21|28.7|\n",
"|0.08829|12.5| 7.87| 0|0.524|6.012| 66.6|5.5605| 5|311| 15.2| 395.6|12.43|22.9|\n",
"|0.14455|12.5| 7.87| 0|0.524|6.172| 96.1|5.9505| 5|311| 15.2| 396.9|19.15|27.1|\n",
"|0.21124|12.5| 7.87| 0|0.524|5.631|100.0|6.0821| 5|311| 15.2|386.63|29.93|16.5|\n",
"|0.17004|12.5| 7.87| 0|0.524|6.004| 85.9|6.5921| 5|311| 15.2|386.71| 17.1|18.9|\n",
"|0.22489|12.5| 7.87| 0|0.524|6.377| 94.3|6.3467| 5|311| 15.2|392.52|20.45|15.0|\n",
"|0.11747|12.5| 7.87| 0|0.524|6.009| 82.9|6.2267| 5|311| 15.2| 396.9|13.27|18.9|\n",
"|0.09378|12.5| 7.87| 0|0.524|5.889| 39.0|5.4509| 5|311| 15.2| 390.5|15.71|21.7|\n",
"|0.62976| 0.0| 8.14| 0|0.538|5.949| 61.8|4.7075| 4|307| 21.0| 396.9| 8.26|20.4|\n",
"|0.63796| 0.0| 8.14| 0|0.538|6.096| 84.5|4.4619| 4|307| 21.0|380.02|10.26|18.2|\n",
"|0.62739| 0.0| 8.14| 0|0.538|5.834| 56.5|4.4986| 4|307| 21.0|395.62| 8.47|19.9|\n",
"|1.05393| 0.0| 8.14| 0|0.538|5.935| 29.3|4.4986| 4|307| 21.0|386.85| 6.58|23.1|\n",
"| 0.7842| 0.0| 8.14| 0|0.538| 5.99| 81.7|4.2579| 4|307| 21.0|386.75|14.67|17.5|\n",
"|0.80271| 0.0| 8.14| 0|0.538|5.456| 36.6|3.7965| 4|307| 21.0|288.99|11.69|20.2|\n",
"| 0.7258| 0.0| 8.14| 0|0.538|5.727| 69.5|3.7965| 4|307| 21.0|390.95|11.28|18.2|\n",
"+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+\n",
"only showing top 20 rows\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T7utgfOtMP-b"
},
"source": [
"from pyspark.sql import functions as F"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "RqBkmYm3CLEH"
},
"source": [
"Create a binary label for the classification task: 1 if `medv` > 20, otherwise 0\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4ra2vfQyLdI4"
},
"source": [
"MEDV_THRESHOLD = 20  # label 1 = medv strictly greater than this value\n",
"\n",
"# Compare numerically and emit 0/1 integer labels directly instead of comparing\n",
"# a boolean column with the string 'false'. Rows with a null medv get label 0.\n",
"# The dtype guard keeps this cell idempotent: re-running it on the already\n",
"# binarised (int) column would otherwise map every label to 0.\n",
"if dict(dataset.dtypes)[\"medv\"] != \"int\":\n",
"    dataset = dataset.withColumn(\n",
"        \"medv\", F.when(F.col(\"medv\") > MEDV_THRESHOLD, 1).otherwise(0)\n",
"    )"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WVdT3051RqEP"
},
"source": [
"inputCols=['crim', 'zn', 'indus', 'chas', 'nox', 'rm', 'age', 'dis', 'rad', 'tax', 'ptratio', 'b', 'lstat']\n",
"outputCol = 'Attributes'"
],
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0S2nhV2BGn4d",
"outputId": "89bcb016-48ce-4c0a-f5e5-908a85c2acdc",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Assemble all feature columns into a single vector column, reusing the\n",
"# inputCols/outputCol defined in the previous cell instead of repeating them.\n",
"assembler = VectorAssembler(inputCols=inputCols, outputCol=outputCol)\n",
"output = assembler.transform(dataset)\n",
"# Keep only the feature vector and the label\n",
"finalized_data = output.select(outputCol, \"medv\")\n",
"finalized_data.show()"
],
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"text": [
"+--------------------+----+\n",
"| Attributes|medv|\n",
"+--------------------+----+\n",
"|[0.00632,18.0,2.3...| 1|\n",
"|[0.02731,0.0,7.07...| 1|\n",
"|[0.02729,0.0,7.07...| 1|\n",
"|[0.03237,0.0,2.18...| 1|\n",
"|[0.06905,0.0,2.18...| 1|\n",
"|[0.02985,0.0,2.18...| 1|\n",
"|[0.08829,12.5,7.8...| 1|\n",
"|[0.14455,12.5,7.8...| 1|\n",
"|[0.21124,12.5,7.8...| 0|\n",
"|[0.17004,12.5,7.8...| 0|\n",
"|[0.22489,12.5,7.8...| 0|\n",
"|[0.11747,12.5,7.8...| 0|\n",
"|[0.09378,12.5,7.8...| 1|\n",
"|[0.62976,0.0,8.14...| 1|\n",
"|[0.63796,0.0,8.14...| 0|\n",
"|[0.62739,0.0,8.14...| 0|\n",
"|[1.05393,0.0,8.14...| 1|\n",
"|[0.7842,0.0,8.14,...| 0|\n",
"|[0.80271,0.0,8.14...| 1|\n",
"|[0.7258,0.0,8.14,...| 0|\n",
"+--------------------+----+\n",
"only showing top 20 rows\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4sQgPDDuCYcg"
},
"source": [
"Instantiate spark decision tree classifier"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5k_BJarMKF5E"
},
"source": [
"from pyspark.ml.classification import DecisionTreeClassifier\n",
"\n",
"dtree = DecisionTreeClassifier(featuresCol=\"Attributes\",\n",
" labelCol=\"medv\",\n",
" maxDepth=3,\n",
" maxBins=50)"
],
"execution_count": 27,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2KohXBxbCa0m"
},
"source": [
"Train the model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7OEnDekAKywa"
},
"source": [
"# output = assembler.transform(dataset)\n",
"dtree_model = dtree.fit(finalized_data)"
],
"execution_count": 28,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "oSCsjge5Cdld"
},
"source": [
"Save the model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lUSoxjkFNxwM"
},
"source": [
"model_path = \"dtree_model_01\"\n",
"# MLWriter.save() raises if the path already exists, which breaks re-running\n",
"# the notebook; overwrite() replaces any previously saved model.\n",
"dtree_model.write().overwrite().save(model_path)"
],
"execution_count": 29,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "GozGTUW_Cf3w"
},
"source": [
"Reload the saved model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "9ppFOqciKe-Z"
},
"source": [
"from pyspark.ml.classification import DecisionTreeClassificationModel"
],
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2KavPGOrKfD5"
},
"source": [
"dtree_model = DecisionTreeClassificationModel.load(model_path)"
],
"execution_count": 31,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "eDKnk7clCihh"
},
"source": [
"Inspect the learned decision tree rules"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-rRioaA1Ie9z",
"outputId": "1a803e0b-392c-4efc-fedb-24b38072ebca",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"print(dtree_model.toDebugString)"
],
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": [
"DecisionTreeClassificationModel: uid=DecisionTreeClassifier_b4b15a225fc1, depth=3, numNodes=13, numClasses=2, numFeatures=13\n",
" If (feature 12 <= 14.23)\n",
" If (feature 5 <= 6.0440000000000005)\n",
" If (feature 7 <= 4.7833000000000006)\n",
" Predict: 1.0\n",
" Else (feature 7 > 4.7833000000000006)\n",
" Predict: 0.0\n",
" Else (feature 5 > 6.0440000000000005)\n",
" Predict: 1.0\n",
" Else (feature 12 > 14.23)\n",
" If (feature 9 <= 298.0)\n",
" If (feature 6 <= 94.45)\n",
" Predict: 1.0\n",
" Else (feature 6 > 94.45)\n",
" Predict: 0.0\n",
" Else (feature 9 > 298.0)\n",
" If (feature 6 <= 57.3)\n",
" Predict: 1.0\n",
" Else (feature 6 > 57.3)\n",
" Predict: 0.0\n",
"\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment