Created
November 1, 2018 04:44
# RandomForest with MLlib

In this document we show how to use Spark's MLlib machine learning functionality in Scala, using Jupyter with the BeakerX kernel (http://beakerx.com/).

## Setup ##

```scala
%%classpath add mvn
org.apache.spark:spark-sql_2.11:2.3.2
org.apache.spark:spark-mllib_2.11:2.3.2
```

Spark is a bit too chatty, so we only want to see error messages:

```scala
org.apache.log4j.Logger.getRootLogger().setLevel(org.apache.log4j.Level.ERROR)
```

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("Simple Application")
  .master("local[4]")
  .config("spark.ui.enabled", "false")
  .getOrCreate()
```

## Data Preparation ##

To use the data with MLlib we need the features collected in a Vector, so we first have to fetch and pre-process the data into the required format. Spark cannot read directly from a URL, so we download the data from the internet and write it to a local file first.

```scala
import java.io.File
import java.io.PrintWriter
import scala.io.Source

val file = new File("iris.csv")
if (!file.exists()) {
  val url = "https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv"
  val csv = Source.fromURL(url).mkString
  val writer = new PrintWriter(file)
  writer.write(csv)
  writer.close()
}

file.exists
```

Now we can read the file into a spark.sql.Dataset:

```scala
val txt = spark.read
  .format("csv")
  .option("header", "true")
  .option("mode", "DROPMALFORMED")
  .load(file.toString())

txt.printSchema()
txt.getClass()
```

```
root
 |-- sepal.length: string (nullable = true)
 |-- sepal.width: string (nullable = true)
 |-- petal.length: string (nullable = true)
 |-- petal.width: string (nullable = true)
 |-- variety: string (nullable = true)
```

The fields are all strings. We rename the fields and convert the numeric fields to doubles:

```scala
val data = txt
  .withColumn("sepalLength", txt.col("`sepal.length`").cast("double"))
  .withColumn("sepalWidth", txt.col("`sepal.width`").cast("double"))
  .withColumn("petalLength", txt.col("`petal.length`").cast("double"))
  .withColumn("petalWidth", txt.col("`petal.width`").cast("double"))
  .withColumn("label", txt.col("variety"))
  .drop("sepal.length", "sepal.width", "petal.length", "petal.width", "variety")

data.printSchema()
```

```
root
 |-- sepalLength: double (nullable = true)
 |-- sepalWidth: double (nullable = true)
 |-- petalLength: double (nullable = true)
 |-- petalWidth: double (nullable = true)
 |-- label: string (nullable = true)
```

```scala
data.show
```

```
+-----------+----------+-----------+----------+------+
|sepalLength|sepalWidth|petalLength|petalWidth| label|
+-----------+----------+-----------+----------+------+
|        5.1|       3.5|        1.4|       0.2|Setosa|
|        4.9|       3.0|        1.4|       0.2|Setosa|
|        4.7|       3.2|        1.3|       0.2|Setosa|
|        4.6|       3.1|        1.5|       0.2|Setosa|
|        5.0|       3.6|        1.4|       0.2|Setosa|
|        5.4|       3.9|        1.7|       0.4|Setosa|
|        4.6|       3.4|        1.4|       0.3|Setosa|
|        5.0|       3.4|        1.5|       0.2|Setosa|
|        4.4|       2.9|        1.4|       0.2|Setosa|
|        4.9|       3.1|        1.5|       0.1|Setosa|
|        5.4|       3.7|        1.5|       0.2|Setosa|
|        4.8|       3.4|        1.6|       0.2|Setosa|
|        4.8|       3.0|        1.4|       0.1|Setosa|
|        4.3|       3.0|        1.1|       0.1|Setosa|
|        5.8|       4.0|        1.2|       0.2|Setosa|
|        5.7|       4.4|        1.5|       0.4|Setosa|
|        5.4|       3.9|        1.3|       0.4|Setosa|
|        5.1|       3.5|        1.4|       0.3|Setosa|
|        5.7|       3.8|        1.7|       0.3|Setosa|
|        5.1|       3.8|        1.5|       0.3|Setosa|
+-----------+----------+-----------+----------+------+
only showing top 20 rows
```

The features need to be made available as a vector. We use a VectorAssembler to add the combined features column and then remove the individual feature fields.

```scala
import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("sepalLength", "sepalWidth", "petalLength", "petalWidth"))
  .setOutputCol("features")

val dataWithVector = assembler.transform(data)
  .drop("sepalLength", "sepalWidth", "petalLength", "petalWidth")
```

We now have the data in the format needed by MLlib, with a label and vectorized features:

```scala
dataWithVector.show
```

```
+------+-----------------+
| label|         features|
+------+-----------------+
|Setosa|[5.1,3.5,1.4,0.2]|
|Setosa|[4.9,3.0,1.4,0.2]|
|Setosa|[4.7,3.2,1.3,0.2]|
|Setosa|[4.6,3.1,1.5,0.2]|
|Setosa|[5.0,3.6,1.4,0.2]|
|Setosa|[5.4,3.9,1.7,0.4]|
|Setosa|[4.6,3.4,1.4,0.3]|
|Setosa|[5.0,3.4,1.5,0.2]|
|Setosa|[4.4,2.9,1.4,0.2]|
|Setosa|[4.9,3.1,1.5,0.1]|
|Setosa|[5.4,3.7,1.5,0.2]|
|Setosa|[4.8,3.4,1.6,0.2]|
|Setosa|[4.8,3.0,1.4,0.1]|
|Setosa|[4.3,3.0,1.1,0.1]|
|Setosa|[5.8,4.0,1.2,0.2]|
|Setosa|[5.7,4.4,1.5,0.4]|
|Setosa|[5.4,3.9,1.3,0.4]|
|Setosa|[5.1,3.5,1.4,0.3]|
|Setosa|[5.7,3.8,1.7,0.3]|
|Setosa|[5.1,3.8,1.5,0.3]|
+------+-----------------+
only showing top 20 rows
```

We split the data into training and test sets (30% held out for testing):

```scala
val Array(trainingData, testData) = dataWithVector.randomSplit(Array(0.7, 0.3))
```

### Train and Predict the Data ###

The labels are still strings, so we need to convert them to numeric values. We do this with a StringIndexer. We classify the data with a RandomForestClassifier and convert the predicted indices back to strings with an IndexToString. All these steps are collected in a Pipeline, which we use to fit and to predict (transform):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Index labels, adding metadata to the label column.
// Fit on the whole dataset to include all labels in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setNumTrees(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, rf, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

predictions.select("predictedLabel", "label", "features").show(5)
```

```
+--------------+------+-----------------+
|predictedLabel| label|         features|
+--------------+------+-----------------+
|        Setosa|Setosa|[4.4,3.2,1.3,0.2]|
|        Setosa|Setosa|[4.6,3.6,1.0,0.2]|
|        Setosa|Setosa|[4.7,3.2,1.3,0.2]|
|        Setosa|Setosa|[4.8,3.0,1.4,0.1]|
|        Setosa|Setosa|[4.8,3.0,1.4,0.3]|
+--------------+------+-----------------+
only showing top 5 rows
```

### Evaluation ###

Finally, we determine the accuracy:

```scala
// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
```

The model classifies about 93% of the held-out test rows correctly (accuracy = 0.9285714285714286).

And we can print the forest model:

```scala
val rfModel = model.stages(1).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
```

```
Learned classification forest model:
 RandomForestClassificationModel (uid=rfc_e4770982908d) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.55)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       Predict: 0.0
     Else (feature 3 > 1.55)
      If (feature 0 <= 6.05)
       If (feature 1 <= 3.05)
        Predict: 0.0
       Else (feature 1 > 3.05)
        Predict: 1.0
      Else (feature 0 > 6.05)
       Predict: 0.0
  Tree 1 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.65)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       Predict: 0.0
     Else (feature 3 > 1.65)
      Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 2 <= 4.85)
      Predict: 1.0
     Else (feature 2 > 4.85)
      If (feature 3 <= 1.75)
       If (feature 3 <= 1.55)
        Predict: 0.0
       Else (feature 3 > 1.55)
        Predict: 1.0
      Else (feature 3 > 1.75)
       Predict: 0.0
  Tree 3 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.55)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       Predict: 0.0
     Else (feature 3 > 1.55)
      Predict: 0.0
  Tree 4 (weight 1.0):
    If (feature 2 <= 2.45)
     Predict: 2.0
    Else (feature 2 > 2.45)
     If (feature 2 <= 4.95)
      Predict: 1.0
     Else (feature 2 > 4.95)
      Predict: 0.0
  Tree 5 (weight 1.0):
    If (feature 2 <= 2.45)
     Predict: 2.0
    Else (feature 2 > 2.45)
     If (feature 2 <= 4.55)
      Predict: 1.0
     Else (feature 2 > 4.55)
      If (feature 3 <= 1.75)
       If (feature 2 <= 5.05)
        If (feature 1 <= 2.25)
         Predict: 0.0
        Else (feature 1 > 2.25)
         Predict: 1.0
       Else (feature 2 > 5.05)
        Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 0.0
  Tree 6 (weight 1.0):
    If (feature 3 <= 0.5)
     Predict: 2.0
    Else (feature 3 > 0.5)
     If (feature 3 <= 1.55)
      If (feature 0 <= 5.75)
       Predict: 1.0
      Else (feature 0 > 5.75)
       If (feature 1 <= 2.6500000000000004)
        If (feature 2 <= 4.95)
         Predict: 1.0
        Else (feature 2 > 4.95)
         Predict: 0.0
       Else (feature 1 > 2.6500000000000004)
        Predict: 1.0
     Else (feature 3 > 1.55)
      If (feature 0 <= 5.95)
       If (feature 0 <= 5.85)
        Predict: 0.0
       Else (feature 0 > 5.85)
        Predict: 1.0
      Else (feature 0 > 5.95)
       Predict: 0.0
  Tree 7 (weight 1.0):
    If (feature 0 <= 5.45)
     If (feature 2 <= 2.45)
      Predict: 2.0
     Else (feature 2 > 2.45)
      Predict: 1.0
    Else (feature 0 > 5.45)
     If (feature 2 <= 4.75)
      If (feature 3 <= 0.5)
       Predict: 2.0
      Else (feature 3 > 0.5)
       Predict: 1.0
     Else (feature 2 > 4.75)
      If (feature 3 <= 1.75)
       If (feature 2 <= 4.95)
        Predict: 1.0
       Else (feature 2 > 4.95)
        If (feature 3 <= 1.55)
         Predict: 0.0
        Else (feature 3 > 1.55)
         Predict: 1.0
      Else (feature 3 > 1.75)
       If (feature 0 <= 5.95)
        If (feature 0 <= 5.85)
         Predict: 0.0
        Else (feature 0 > 5.85)
         Predict: 0.0
       Else (feature 0 > 5.95)
        Predict: 0.0
  Tree 8 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 2.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.75)
      If (feature 2 <= 4.95)
       Predict: 1.0
      Else (feature 2 > 4.95)
       If (feature 0 <= 6.15)
        Predict: 0.0
       Else (feature 0 > 6.15)
        If (feature 0 <= 6.75)
         Predict: 1.0
        Else (feature 0 > 6.75)
         Predict: 0.0
     Else (feature 3 > 1.75)
      Predict: 0.0
  Tree 9 (weight 1.0):
    If (feature 2 <= 2.45)
     Predict: 2.0
    Else (feature 2 > 2.45)
     If (feature 0 <= 5.75)
      Predict: 1.0
     Else (feature 0 > 5.75)
      If (feature 1 <= 3.1500000000000004)
       If (feature 2 <= 4.75)
        Predict: 1.0
       Else (feature 2 > 4.75)
        If (feature 0 <= 6.25)
         Predict: 0.0
        Else (feature 0 > 6.25)
         Predict: 0.0
      Else (feature 1 > 3.1500000000000004)
       Predict: 0.0
```
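The StringIndexer/IndexToString round trip used in the pipeline can be illustrated with a small, self-contained Scala sketch. This is not the Spark API, only the idea behind it: labels are assigned indices by descending frequency, predictions come back as indices, and a reverse lookup restores the label string. The example labels are assumed iris varieties; Spark additionally represents indices as doubles rather than Ints.

```scala
object LabelIndexingSketch {
  // Assign each label an index by descending frequency (ties broken
  // alphabetically), mimicking what StringIndexer does when it is fit.
  def fitLabels(labels: Seq[String]): Array[String] =
    labels.groupBy(identity).toSeq
      .sortBy { case (label, occurrences) => (-occurrences.size, label) }
      .map(_._1)
      .toArray

  def main(args: Array[String]): Unit = {
    val labels  = Seq("Setosa", "Setosa", "Versicolor", "Virginica", "Setosa", "Versicolor")
    val index   = fitLabels(labels)            // ordered label list
    val toIndex = index.zipWithIndex.toMap     // label -> numeric index
    println(toIndex("Setosa"))                 // most frequent label gets index 0
    println(index(1))                          // reverse lookup, like IndexToString
  }
}
```

The reverse lookup is why the pipeline needs `labelIndexer.labels` when constructing the `IndexToString` stage: the converter has to know the fitted ordering.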
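The accuracy metric computed by the MulticlassClassificationEvaluator is simply the fraction of test rows whose predicted label matches the true label. A minimal pure-Scala sketch of that computation (the pairs below are made-up examples, not actual model output):

```scala
object AccuracySketch {
  // Multiclass accuracy: matching (prediction, truth) pairs over total pairs.
  def accuracy(pairs: Seq[(String, String)]): Double =
    pairs.count { case (predicted, actual) => predicted == actual }.toDouble / pairs.size

  def main(args: Array[String]): Unit = {
    val pairs = Seq(
      ("Setosa",     "Setosa"),
      ("Versicolor", "Versicolor"),
      ("Virginica",  "Versicolor"), // one misclassification
      ("Virginica",  "Virginica")
    )
    println(accuracy(pairs)) // 0.75
  }
}
```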
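The printed dump shows how the forest decides: each tree is an independent set of threshold tests, and the classifier combines them by majority vote. The sketch below hand-codes two of the simpler printed trees (Tree 4 and Tree 1) as plain Scala functions and votes over them; this is an illustration of the voting idea, not Spark's implementation, and the tie-breaking rule chosen here is an assumption.

```scala
object ForestVoteSketch {
  // Feature order matches the VectorAssembler input:
  // (sepalLength, sepalWidth, petalLength, petalWidth).
  type Flower = Array[Double]

  // Tree 4 from the dump: splits only on petal length (feature 2).
  def tree4(x: Flower): Double =
    if (x(2) <= 2.45) 2.0 else if (x(2) <= 4.95) 1.0 else 0.0

  // Tree 1 from the dump: petal width (feature 3), then petal length.
  def tree1(x: Flower): Double =
    if (x(3) <= 0.5) 2.0
    else if (x(3) <= 1.65) { if (x(2) <= 4.95) 1.0 else 0.0 }
    else 0.0

  // Majority vote over the trees (ties broken toward the larger class index
  // here; an arbitrary choice for this sketch).
  def predict(trees: Seq[Flower => Double], x: Flower): Double =
    trees.map(_(x)).groupBy(identity)
      .maxBy { case (cls, votes) => (votes.size, cls) }._1

  def main(args: Array[String]): Unit = {
    val setosaLike = Array(5.1, 3.5, 1.4, 0.2) // first row of the dataset
    println(predict(Seq(tree4 _, tree1 _), setosaLike)) // 2.0
  }
}
```

In this particular run, index 2.0 corresponds to Setosa (as the sample prediction table suggests: rows with small petal measurements are predicted Setosa).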