Training and testing a Machine Learning model with Spark and MemSQL
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training and testing a model with Spark and MemSQL - Pushdown Enabled"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up the Spark session"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages \"com.memsql:memsql-spark-connector_2.11:3.0.0-rc1-spark-2.4.4\" pyspark-shell'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pyspark\n",
"spark = pyspark.sql.SparkSession.builder.master(\"local[*]\").getOrCreate()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Connect to MemSQL"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"spark.conf.set(\"spark.datasource.memsql.ddlEndpoint\", \"localhost\")\n",
"spark.conf.set(\"spark.datasource.memsql.user\", \"root\")\n",
"spark.conf.set(\"spark.datasource.memsql.password\", \"\")\n",
"\n",
"spark.conf.set(\"spark.datasource.memsql.disablePushdown\", \"false\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Query data from a MemSQL table"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"data = spark.read.format(\"memsql\") \\\n",
" .load(\"tpch.lineitem_bu\") \\\n",
" .select('l_partkey','l_suppkey','l_quantity','l_discount','l_tax','l_extendedprice') \\\n",
" .limit(1000000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assemble the feature vector (the columns used as predictors in the model)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.feature import VectorAssembler\n",
"\n",
"# Predictor columns; l_extendedprice is kept aside as the label\n",
"feature_columns = ['l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax']\n",
"assembler = VectorAssembler(inputCols=feature_columns, outputCol=\"features\")\n",
"data_2 = assembler.transform(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Split the data into training (70%) and test (30%) sets and configure the linear regression.\n",
"The `%%time` cells that follow measure wall time for performance benchmarking."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.regression import LinearRegression\n",
"\n",
"# 70% of the rows for training, 30% for testing\n",
"train, test = data_2.randomSplit([0.7, 0.3])\n",
"# Predict l_extendedprice from the assembled feature vector\n",
"algo = LinearRegression(featuresCol=\"features\", labelCol=\"l_extendedprice\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model. While this cell runs, check Resource Usage in MemSQL Studio to see the queries being run."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 13.7 ms, sys: 6.57 ms, total: 20.2 ms\n",
"Wall time: 29.9 s\n"
]
}
],
"source": [
"%%time\n",
"model = algo.fit(train)"
]
},
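{
"cell_type": "markdown",
"metadata": {},
"source": [
"Collect model metrics and make predictions. `model.transform` is lazy, so its short wall time covers only building the query plan; the predictions are computed when `.show()` runs."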
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.33 ms, sys: 1.93 ms, total: 5.26 ms\n",
"Wall time: 13.4 s\n"
]
}
],
"source": [
"%%time\n",
"# evaluation\n",
"evaluation_summary = model.evaluate(test)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.42 ms, sys: 558 µs, total: 7.98 ms\n",
"Wall time: 57.2 ms\n"
]
}
],
"source": [
"%%time\n",
"# predicting values\n",
"predictions = model.transform(test)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8629420794044855"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r_squared = evaluation_summary.r2\n",
"r_squared"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Show predictions."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+---------+---------+----------+----------+-----+------------------+\n",
"|l_partkey|l_suppkey|l_quantity|l_discount|l_tax| prediction|\n",
"+---------+---------+----------+----------+-----+------------------+\n",
"| 38| 750039| 46.00| 0.02| 0.01| 68977.74461490566|\n",
"| 74| 75| 1.00| 0.08| 0.05|1447.3229085589162|\n",
"| 158| 250159| 32.00| 0.00| 0.00| 47934.72253363739|\n",
"| 251| 252| 11.00| 0.09| 0.05| 16448.97259354734|\n",
"| 297| 500298| 6.00| 0.07| 0.06| 8988.435249517062|\n",
"| 318| 319| 29.00| 0.08| 0.05| 43451.31132863032|\n",
"| 518| 519| 6.00| 0.02| 0.03| 8933.819999290781|\n",
"| 521| 500522| 22.00| 0.08| 0.06| 32990.93922878267|\n",
"| 523| 524| 48.00| 0.09| 0.06| 71960.70395181864|\n",
"| 541| 500542| 38.00| 0.06| 0.07|56999.236464156704|\n",
"| 903| 500904| 5.00| 0.07| 0.05| 7481.871057433555|\n",
"| 1005| 751006| 9.00| 0.00| 0.06|13504.448001676637|\n",
"| 1010| 501011| 12.00| 0.10| 0.00|17951.206135025164|\n",
"| 1023| 1024| 41.00| 0.05| 0.04|61445.940888421224|\n",
"| 1049| 751050| 30.00| 0.05| 0.04| 44995.59580030811|\n",
"| 1070| 501071| 34.00| 0.03| 0.05|50985.119937364674|\n",
"| 1179| 751180| 19.00| 0.05| 0.00|28468.186813916644|\n",
"| 1189| 251190| 22.00| 0.06| 0.01| 32941.15426202936|\n",
"| 1198| 1199| 24.00| 0.10| 0.01|35925.243484515195|\n",
"| 1292| 751293| 5.00| 0.06| 0.02| 7479.356833333768|\n",
"+---------+---------+----------+----------+-----+------------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
"predictions \\\n",
"    .select('l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax', 'prediction') \\\n",
"    .show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
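
The notebook reports `evaluation_summary.r2` without showing what the number means. As a rough illustration, the sketch below computes the coefficient of determination in plain Python as 1 - SS_res / SS_tot, which is the usual definition for a regression fitted with an intercept. The helper name `r_squared` and the sample values are made up for this example and are not part of the notebook or of the Spark API.

```python
def r_squared(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot.

    Hypothetical helper for illustration only; the notebook itself reads
    this value from evaluation_summary.r2.
    """
    if not labels or len(labels) != len(predictions):
        raise ValueError("labels and predictions must be non-empty and equal in length")
    mean_label = sum(labels) / len(labels)
    # Residual sum of squares: error left over after the model's predictions
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    # Total sum of squares: error of always predicting the mean label
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    if ss_tot == 0:
        raise ValueError("R^2 is undefined when all labels are identical")
    return 1.0 - ss_res / ss_tot


# Made-up sample values, not taken from tpch.lineitem_bu
labels = [10.0, 20.0, 30.0, 40.0]
perfect = r_squared(labels, labels)        # predictions match the labels exactly
baseline = r_squared(labels, [25.0] * 4)   # always predicting the mean label
partial = r_squared([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
print(perfect, baseline, partial)
```

Read this way, the 0.86 reported above means the model's squared error on the test split is about 14% of the squared error of always predicting the mean `l_extendedprice`.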