@pschatzmann
Created November 1, 2018 04:44
# RandomForest with MLlib

In this document we show how to use Spark's MLlib machine-learning functionality in Scala, using Jupyter with the BeakerX kernel (http://beakerx.com/).

## Setup ##

```scala
%%classpath add mvn
org.apache.spark:spark-sql_2.11:2.3.2
org.apache.spark:spark-mllib_2.11:2.3.2
```

Spark is a little too chatty, so we restrict logging to error messages:

```scala
org.apache.log4j.Logger.getRootLogger().setLevel(org.apache.log4j.Level.ERROR)
```

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("Simple Application")
  .master("local[4]")
  .config("spark.ui.enabled", "false")
  .getOrCreate()
```

```
org.apache.spark.sql.SparkSession@3807c582
```

## Data Preparation ##

To use the data in MLlib we need the features in a Vector, so we first need to fetch and pre-process the data into the required format. Spark cannot read directly from a URL, so we download the data from the internet and create a local file first.

```scala
import java.io.File
import java.io.PrintWriter
import scala.io.Source

val file = new File("iris.csv")
if (!file.exists()) {
  val url = "https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv"
  val csv = Source.fromURL(url).mkString
  val writer = new PrintWriter(file)
  writer.write(csv)
  writer.close()
}

file.exists
```

```
true
```

Now we can read the file into a spark.sql.Dataset:

```scala
val txt = spark.read
  .format("csv")
  .option("header", "true")
  .option("mode", "DROPMALFORMED")
  .load(file.toString())

txt.printSchema()
txt.getClass()
```

```
root
 |-- sepal.length: string (nullable = true)
 |-- sepal.width: string (nullable = true)
 |-- petal.length: string (nullable = true)
 |-- petal.width: string (nullable = true)
 |-- variety: string (nullable = true)

class org.apache.spark.sql.Dataset
```

The fields are all Strings. We rename the columns and cast the numeric fields to doubles:

```scala
val data = txt
  .withColumn("sepalLength", txt.col("`sepal.length`").cast("double"))
  .withColumn("sepalWidth", txt.col("`sepal.width`").cast("double"))
  .withColumn("petalLength", txt.col("`petal.length`").cast("double"))
  .withColumn("petalWidth", txt.col("`petal.width`").cast("double"))
  .withColumn("label", txt.col("variety"))
  .drop("sepal.length", "sepal.width", "petal.length", "petal.width", "variety")
```

```scala
data.printSchema()
```

```
root
 |-- sepalLength: double (nullable = true)
 |-- sepalWidth: double (nullable = true)
 |-- petalLength: double (nullable = true)
 |-- petalWidth: double (nullable = true)
 |-- label: string (nullable = true)
```

```scala
data.show
```

```
+-----------+----------+-----------+----------+------+
|sepalLength|sepalWidth|petalLength|petalWidth| label|
+-----------+----------+-----------+----------+------+
|        5.1|       3.5|        1.4|       0.2|Setosa|
|        4.9|       3.0|        1.4|       0.2|Setosa|
|        4.7|       3.2|        1.3|       0.2|Setosa|
|        4.6|       3.1|        1.5|       0.2|Setosa|
|        5.0|       3.6|        1.4|       0.2|Setosa|
|        5.4|       3.9|        1.7|       0.4|Setosa|
|        4.6|       3.4|        1.4|       0.3|Setosa|
|        5.0|       3.4|        1.5|       0.2|Setosa|
|        4.4|       2.9|        1.4|       0.2|Setosa|
|        4.9|       3.1|        1.5|       0.1|Setosa|
|        5.4|       3.7|        1.5|       0.2|Setosa|
|        4.8|       3.4|        1.6|       0.2|Setosa|
|        4.8|       3.0|        1.4|       0.1|Setosa|
|        4.3|       3.0|        1.1|       0.1|Setosa|
|        5.8|       4.0|        1.2|       0.2|Setosa|
|        5.7|       4.4|        1.5|       0.4|Setosa|
|        5.4|       3.9|        1.3|       0.4|Setosa|
|        5.1|       3.5|        1.4|       0.3|Setosa|
|        5.7|       3.8|        1.7|       0.3|Setosa|
|        5.1|       3.8|        1.5|       0.3|Setosa|
+-----------+----------+-----------+----------+------+
only showing top 20 rows
```

The features need to be made available as a vector. We use a VectorAssembler to add the combined features column and drop the individual feature fields.

```scala
import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("sepalLength", "sepalWidth", "petalLength", "petalWidth"))
  .setOutputCol("features")

val dataWithVector = assembler.transform(data)
  .drop("sepalLength", "sepalWidth", "petalLength", "petalWidth")
```

We now have the data in the format needed by MLlib, with a label and vectorized features:

```scala
dataWithVector.show
```

```
+------+-----------------+
| label|         features|
+------+-----------------+
|Setosa|[5.1,3.5,1.4,0.2]|
|Setosa|[4.9,3.0,1.4,0.2]|
|Setosa|[4.7,3.2,1.3,0.2]|
|Setosa|[4.6,3.1,1.5,0.2]|
|Setosa|[5.0,3.6,1.4,0.2]|
|Setosa|[5.4,3.9,1.7,0.4]|
|Setosa|[4.6,3.4,1.4,0.3]|
|Setosa|[5.0,3.4,1.5,0.2]|
|Setosa|[4.4,2.9,1.4,0.2]|
|Setosa|[4.9,3.1,1.5,0.1]|
|Setosa|[5.4,3.7,1.5,0.2]|
|Setosa|[4.8,3.4,1.6,0.2]|
|Setosa|[4.8,3.0,1.4,0.1]|
|Setosa|[4.3,3.0,1.1,0.1]|
|Setosa|[5.8,4.0,1.2,0.2]|
|Setosa|[5.7,4.4,1.5,0.4]|
|Setosa|[5.4,3.9,1.3,0.4]|
|Setosa|[5.1,3.5,1.4,0.3]|
|Setosa|[5.7,3.8,1.7,0.3]|
|Setosa|[5.1,3.8,1.5,0.3]|
+------+-----------------+
only showing top 20 rows
```

We split the data into training and test sets (30% held out for testing).

```scala
val Array(trainingData, testData) = dataWithVector.randomSplit(Array(0.7, 0.3))
```

### Train and Predict the Data ###

The labels are still Strings, so we convert them to numeric values with a StringIndexer. We classify the data with a RandomForestClassifier and convert the predicted indices back to strings with an IndexToString. All these steps are collected in a Pipeline, which we use to fit the model and to predict (transform):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Index labels, adding metadata to the label column.
// Fit on the whole dataset to include all labels in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setNumTrees(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain the indexer, forest and converter in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, rf, labelConverter))

// Train the model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

predictions.select("predictedLabel", "label", "features").show(5)
```

```
+--------------+------+-----------------+
|predictedLabel| label|         features|
+--------------+------+-----------------+
|        Setosa|Setosa|[4.4,3.2,1.3,0.2]|
|        Setosa|Setosa|[4.6,3.6,1.0,0.2]|
|        Setosa|Setosa|[4.7,3.2,1.3,0.2]|
|        Setosa|Setosa|[4.8,3.0,1.4,0.1]|
|        Setosa|Setosa|[4.8,3.0,1.4,0.3]|
+--------------+------+-----------------+
only showing top 5 rows
```

### Evaluation ###

Finally, we determine the accuracy:

```scala
// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
```

```
0.9285714285714286
```

And we can print the forest model:

```scala
val rfModel = model.stages(1).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
```

```
Learned classification forest model:
 RandomForestClassificationModel (uid=rfc_e4770982908d) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.55)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       Predict: 0.0
     Else (feature 3 > 1.55)
      If (feature 0 <= 6.05)
       If (feature 1 <= 3.05)
        Predict: 0.0
       Else (feature 1 > 3.05)
        Predict: 1.0
      Else (feature 0 > 6.05)
       Predict: 0.0
  Tree 1 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.65)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       Predict: 0.0
     Else (feature 3 > 1.65)
      Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 2 <= 4.85)
      Predict: 1.0
     Else (feature 2 > 4.85)
      If (feature 3 <= 1.75)
       If (feature 3 <= 1.55)
        Predict: 0.0
       Else (feature 3 > 1.55)
        Predict: 1.0
      Else (feature 3 > 1.75)
       Predict: 0.0
  Tree 3 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.55)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       Predict: 0.0
     Else (feature 3 > 1.55)
      Predict: 0.0
  Tree 4 (weight 1.0):
    If (feature 2 <= 2.45)
     Predict: 2.0
    Else (feature 2 > 2.45)
     If (feature 2 <= 4.95)
      Predict: 1.0
     Else (feature 2 > 4.95)
      Predict: 0.0
  Tree 5 (weight 1.0):
    If (feature 2 <= 2.45)
     Predict: 2.0
    Else (feature 2 > 2.45)
     If (feature 2 <= 4.55)
      Predict: 1.0
     Else (feature 2 > 4.55)
      If (feature 3 <= 1.75)
       If (feature 2 <= 5.05)
        If (feature 1 <= 2.25)
         Predict: 0.0
        Else (feature 1 > 2.25)
         Predict: 1.0
       Else (feature 2 > 5.05)
        Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 0.0
  Tree 6 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.55)
      If (feature 0 <= 5.75)
       Predict: 1.0
      Else (feature 0 > 5.75)
       If (feature 1 <= 2.6500000000000004)
        If (feature 2 <= 4.95)
         Predict: 1.0
        Else (feature 2 > 4.95)
         Predict: 0.0
       Else (feature 1 > 2.6500000000000004)
        Predict: 1.0
     Else (feature 3 > 1.55)
      If (feature 0 <= 5.95)
       If (feature 0 <= 5.85)
        Predict: 0.0
       Else (feature 0 > 5.85)
        Predict: 1.0
      Else (feature 0 > 5.95)
       Predict: 0.0
  Tree 7 (weight 1.0):
    If (feature 0 <= 5.45)
     If (feature 2 <= 2.45)
      Predict: 2.0
     Else (feature 2 > 2.45)
      Predict: 1.0
    Else (feature 0 > 5.45)
     If (feature 2 <= 4.75)
      If (feature 3 <= 0.5)
       Predict: 2.0
      Else (feature 3 > 0.5)
       Predict: 1.0
     Else (feature 2 > 4.75)
      If (feature 3 <= 1.75)
       If (feature 2 <= 4.95)
        Predict: 1.0
       Else (feature 2 > 4.95)
        If (feature 3 <= 1.55)
         Predict: 0.0
        Else (feature 3 > 1.55)
         Predict: 1.0
      Else (feature 3 > 1.75)
       If (feature 0 <= 5.95)
        If (feature 0 <= 5.85)
         Predict: 0.0
        Else (feature 0 > 5.85)
         Predict: 0.0
       Else (feature 0 > 5.95)
        Predict: 0.0
  Tree 8 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 2.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.75)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       If (feature 0 <= 6.15)
        Predict: 0.0
       Else (feature 0 > 6.15)
        If (feature 0 <= 6.75)
         Predict: 1.0
        Else (feature 0 > 6.75)
         Predict: 0.0
     Else (feature 3 > 1.75)
      Predict: 0.0
  Tree 9 (weight 1.0):
    If (feature 2 <= 2.45)
     Predict: 2.0
    Else (feature 2 > 2.45)
     If (feature 0 <= 5.75)
      Predict: 1.0
     Else (feature 0 > 5.75)
      If (feature 1 <= 3.1500000000000004)
       If (feature 2 <= 4.75)
        Predict: 1.0
       Else (feature 2 > 4.75)
        If (feature 0 <= 6.25)
         Predict: 0.0
        Else (feature 0 > 6.25)
         Predict: 0.0
      Else (feature 1 > 3.1500000000000004)
       Predict: 0.0
```
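A note on label indexing: by default the StringIndexer used in the pipeline above orders labels by descending frequency, so the most common label receives index 0.0. The helper `indexLabels` below is our own minimal plain-Scala sketch of that behaviour, not part of MLlib, and the alphabetical tie-break is an assumption for illustration:

```scala
// Hypothetical helper sketching StringIndexer's default "frequencyDesc"
// ordering: the most frequent label maps to index 0.0.
def indexLabels(labels: Seq[String]): Map[String, Double] = {
  labels
    .groupBy(identity)
    .toSeq
    .map { case (label, group) => (label, group.size) }
    .sortBy { case (label, count) => (-count, label) } // assumed alphabetical tie-break
    .zipWithIndex
    .map { case ((label, _), idx) => (label, idx.toDouble) }
    .toMap
}

// "Virginica" is the most frequent label here, so it receives index 0.0.
val mapping = indexLabels(Seq("Setosa", "Setosa", "Versicolor",
                              "Virginica", "Virginica", "Virginica"))
// mapping: Map(Virginica -> 0.0, Setosa -> 1.0, Versicolor -> 2.0)
```

On the real Iris data all three classes occur 50 times each, so the resulting indices depend on the tie-break rather than on frequency.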
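The forest printed above consists of ten independent trees, and the classifier combines their individual predictions into one. The `majorityVote` helper below is our own simplified plain-Scala illustration of that idea; it is not MLlib's actual implementation (which averages the trees' class-probability estimates rather than counting raw votes), and the lowest-label tie-break is an assumption:

```scala
// Hypothetical helper illustrating how per-tree predictions can be combined;
// MLlib's RandomForestClassifier actually averages class probabilities.
def majorityVote(treePredictions: Seq[Double]): Double = {
  treePredictions
    .groupBy(identity)
    .toSeq
    .map { case (label, votes) => (label, votes.size) }
    .sortBy { case (label, count) => (-count, label) } // most votes wins; lowest label on ties
    .head
    ._1
}

// Three of the four trees predict class 1.0:
val combined = majorityVote(Seq(1.0, 1.0, 0.0, 1.0))
// combined: 1.0
```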