Forked from robrich/linear-regression-pushdown-disabled.ipynb
Created
October 12, 2021 16:33
-
-
Save VeryFatBoy/ac405b9ea2ada27b1fab0ae20f969855 to your computer and use it in GitHub Desktop.
Training and testing a Machine Learning model with Spark and MemSQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Training and testing a model with Spark and MemSQL - Pushdown Disabled" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Setup Spark Context" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages \"com.memsql:memsql-spark-connector_2.11:3.0.0-rc1-spark-2.4.4\" pyspark-shell'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pyspark\n", | |
"spark = pyspark.sql.SparkSession.builder.master(\"local[*]\").getOrCreate()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Connect to MemSQL" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"spark.conf.set(\"spark.datasource.memsql.ddlEndpoint\", \"localhost\")\n", | |
"spark.conf.set(\"spark.datasource.memsql.user\", \"root\")\n", | |
"spark.conf.set(\"spark.datasource.memsql.password\", \"\")\n", | |
"spark.conf.set(\"spark.datasource.memsql.disablePushdown\", \"true\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Query data from MemSQL Table" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = spark.read.format(\"memsql\") \\\n", | |
" .load(\"tpch.lineitem\") \\\n", | |
" .select('l_partkey','l_suppkey','l_quantity','l_discount','l_tax','l_extendedprice') \\\n", | |
" .limit(1000000)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Assemble features vector (columns used as predictors in model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"feature_columns = ['l_partkey','l_suppkey','l_quantity','l_discount','l_tax']\n", | |
"from pyspark.ml.feature import VectorAssembler\n", | |
"assembler = VectorAssembler(inputCols=feature_columns,outputCol=\"features\")\n", | |
"data_2 = assembler.transform(data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Separate the data used for training vs. testing model, and run linear regression. \n", | |
"Measure timespan for performance benchmarking." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train, test = data_2.randomSplit([0.7, 0.3])\n", | |
"from pyspark.ml.regression import LinearRegression\n", | |
"algo = LinearRegression(featuresCol=\"features\", labelCol=\"l_extendedprice\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Train the model.\n", | |
"Note: Check out MemSQL Studio's Resource Usage to see the queries run." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 25.9 ms, sys: 9.27 ms, total: 35.1 ms\n", | |
"Wall time: 1min 54s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"model = algo.fit(train)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Collect model metrics and make predictions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 13.2 ms, sys: 7.31 ms, total: 20.5 ms\n", | |
"Wall time: 1min 40s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# evaluation\n", | |
"evaluation_summary = model.evaluate(test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 5.59 ms, sys: 2.81 ms, total: 8.4 ms\n", | |
"Wall time: 52 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# predicting values\n", | |
"predictions = model.transform(test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8626332754492319" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"r_squared = evaluation_summary.r2\n", | |
"r_squared" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Show predictions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"+---------+---------+----------+----------+-----+------------------+\n", | |
"|l_partkey|l_suppkey|l_quantity|l_discount|l_tax| prediction|\n", | |
"+---------+---------+----------+----------+-----+------------------+\n", | |
"| 18| 500019| 44.00| 0.08| 0.00| 65970.22715696078|\n", | |
"| 38| 750039| 46.00| 0.02| 0.01| 68952.20722547553|\n", | |
"| 74| 75| 1.00| 0.08| 0.05|1471.1811359388616|\n", | |
"| 97| 500098| 8.00| 0.01| 0.03|11960.307478002207|\n", | |
"| 158| 250159| 32.00| 0.00| 0.00| 47914.35790036007|\n", | |
"| 297| 500298| 6.00| 0.07| 0.06| 8998.83993524212|\n", | |
"| 356| 500357| 32.00| 0.02| 0.05|47947.151749741795|\n", | |
"| 380| 750381| 39.00| 0.00| 0.05|58448.125668747765|\n", | |
"| 389| 750390| 7.00| 0.03| 0.08|10493.942591755862|\n", | |
"| 395| 750396| 46.00| 0.10| 0.02| 68999.96446936413|\n", | |
"| 500| 750501| 22.00| 0.07| 0.03|33002.734381948365|\n", | |
"| 523| 524| 48.00| 0.09| 0.06| 71938.38862559927|\n", | |
"| 539| 500540| 49.00| 0.04| 0.07| 73445.73663812922|\n", | |
"| 567| 750568| 5.00| 0.06| 0.02| 7510.543763732859|\n", | |
"| 629| 750630| 44.00| 0.04| 0.00| 65965.26884615852|\n", | |
"| 691| 250692| 15.00| 0.00| 0.02|22429.512743377418|\n", | |
"| 758| 759| 45.00| 0.02| 0.02| 67397.68464167765|\n", | |
"| 925| 250926| 12.00| 0.06| 0.02| 17967.47984263261|\n", | |
"| 1006| 1007| 29.00| 0.03| 0.01| 43416.48942095604|\n", | |
"| 1023| 1024| 41.00| 0.05| 0.04| 61419.69940025381|\n", | |
"+---------+---------+----------+----------+-----+------------------+\n", | |
"only showing top 20 rows\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"predictions \\\n", | |
" .select(predictions['l_partkey'],predictions['l_suppkey'] \\\n", | |
" ,predictions['l_quantity'],predictions['l_discount'],predictions['l_tax'],predictions['prediction']) \\\n", | |
" .show() " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment