Created
June 4, 2020 00:49
-
-
Save robrich/b1841a1c0a76070e22353c3d7f586d76 to your computer and use it in GitHub Desktop.
Training and testing a Machine Learning model with Spark and MemSQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Training and testing a model with Spark and MemSQL - Pushdown Enabled" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Set up the Spark context, loading the MemSQL Spark Connector package" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages \"com.memsql:memsql-spark-connector_2.11:3.0.0-rc1-spark-2.4.4\" pyspark-shell'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pyspark\n", | |
"spark = pyspark.sql.SparkSession.builder.master(\"local[*]\").getOrCreate()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Configure the connection to MemSQL and leave SQL pushdown enabled" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"spark.conf.set(\"spark.datasource.memsql.ddlEndpoint\", \"localhost\")\n", | |
"spark.conf.set(\"spark.datasource.memsql.user\", \"root\")\n", | |
"spark.conf.set(\"spark.datasource.memsql.password\", \"\")\n", | |
"\n", | |
"spark.conf.set(\"spark.datasource.memsql.disablePushdown\", \"false\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Query data from the MemSQL table `tpch.lineitem_bu` (selected columns only, capped at 1,000,000 rows)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = spark.read.format(\"memsql\") \\\n", | |
" .load(\"tpch.lineitem_bu\") \\\n", | |
" .select('l_partkey','l_suppkey','l_quantity','l_discount','l_tax','l_extendedprice') \\\n", | |
" .limit(1000000)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Assemble the feature vector (the columns used as predictors in the model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"feature_columns = ['l_partkey','l_suppkey','l_quantity','l_discount','l_tax']\n", | |
"from pyspark.ml.feature import VectorAssembler\n", | |
"assembler = VectorAssembler(inputCols=feature_columns,outputCol=\"features\")\n", | |
"data_2 = assembler.transform(data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Split the data into training (70%) and test (30%) sets, and configure the linear regression estimator. \n", | |
"The cells that follow use `%%time` to measure elapsed time for performance benchmarking." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train, test = data_2.randomSplit([0.7, 0.3])\n", | |
"from pyspark.ml.regression import LinearRegression\n", | |
"algo = LinearRegression(featuresCol=\"features\", labelCol=\"l_extendedprice\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Train the model. Note: Check out MemSQL Studio's Resource Usage to see the queries run." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 13.7 ms, sys: 6.57 ms, total: 20.2 ms\n", | |
"Wall time: 29.9 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"model = algo.fit(train)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Collect model metrics and make predictions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 3.33 ms, sys: 1.93 ms, total: 5.26 ms\n", | |
"Wall time: 13.4 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# evaluation\n", | |
"evaluation_summary = model.evaluate(test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 7.42 ms, sys: 558 µs, total: 7.98 ms\n", | |
"Wall time: 57.2 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# predicting values\n", | |
"predictions = model.transform(test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8629420794044855" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"r_squared = evaluation_summary.r2\n", | |
"r_squared" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Show predictions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"+---------+---------+----------+----------+-----+------------------+\n", | |
"|l_partkey|l_suppkey|l_quantity|l_discount|l_tax| prediction|\n", | |
"+---------+---------+----------+----------+-----+------------------+\n", | |
"| 38| 750039| 46.00| 0.02| 0.01| 68977.74461490566|\n", | |
"| 74| 75| 1.00| 0.08| 0.05|1447.3229085589162|\n", | |
"| 158| 250159| 32.00| 0.00| 0.00| 47934.72253363739|\n", | |
"| 251| 252| 11.00| 0.09| 0.05| 16448.97259354734|\n", | |
"| 297| 500298| 6.00| 0.07| 0.06| 8988.435249517062|\n", | |
"| 318| 319| 29.00| 0.08| 0.05| 43451.31132863032|\n", | |
"| 518| 519| 6.00| 0.02| 0.03| 8933.819999290781|\n", | |
"| 521| 500522| 22.00| 0.08| 0.06| 32990.93922878267|\n", | |
"| 523| 524| 48.00| 0.09| 0.06| 71960.70395181864|\n", | |
"| 541| 500542| 38.00| 0.06| 0.07|56999.236464156704|\n", | |
"| 903| 500904| 5.00| 0.07| 0.05| 7481.871057433555|\n", | |
"| 1005| 751006| 9.00| 0.00| 0.06|13504.448001676637|\n", | |
"| 1010| 501011| 12.00| 0.10| 0.00|17951.206135025164|\n", | |
"| 1023| 1024| 41.00| 0.05| 0.04|61445.940888421224|\n", | |
"| 1049| 751050| 30.00| 0.05| 0.04| 44995.59580030811|\n", | |
"| 1070| 501071| 34.00| 0.03| 0.05|50985.119937364674|\n", | |
"| 1179| 751180| 19.00| 0.05| 0.00|28468.186813916644|\n", | |
"| 1189| 251190| 22.00| 0.06| 0.01| 32941.15426202936|\n", | |
"| 1198| 1199| 24.00| 0.10| 0.01|35925.243484515195|\n", | |
"| 1292| 751293| 5.00| 0.06| 0.02| 7479.356833333768|\n", | |
"+---------+---------+----------+----------+-----+------------------+\n", | |
"only showing top 20 rows\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"predictions \\\n", | |
" .select(predictions['l_partkey'],predictions['l_suppkey'] \\\n", | |
" ,predictions['l_quantity'],predictions['l_discount'],predictions['l_tax'],predictions['prediction']) \\\n", | |
" .show() " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment