{"metadata":{"kernelspec":{"display_name":"Scala","language":"scala","name":"scala"},"language_info":{"codemirror_mode":"text/x-scala","file_extension":".scala","mimetype":"","name":"Scala","nbconverter_exporter":"","version":"2.11.12"}},"nbformat_minor":2,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# RandomForestClassifier with MLib \n\nIn this document we show how to use Spark's MLib Machine Learning functionality using the RandomForest classifier on the IRIS data. Most of the examples that I have found are loading the data from a file. In my example I load it from the Internet, so the only thing you need to execute the example is a working network connection.\n\n## Setup ","metadata":{}},{"cell_type":"code","source":"%%classpath add mvn \norg.apache.spark:spark-sql_2.11:2.3.2\norg.apache.spark:spark-mllib_2.11:2.3.2\n","metadata":{"trusted":true},"execution_count":57,"outputs":[{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"","version_major":2,"version_minor":0},"method":"display_data"},"metadata":{}}]},{"cell_type":"code","source":"import org.apache.spark.sql.SparkSession\n\nval spark = SparkSession.builder()\n .appName(\"Iris NaiveBayes\")\n .master(\"local\")\n .config(\"spark.ui.enabled\", \"false\")\n .getOrCreate()\n","metadata":{"trusted":true},"execution_count":58,"outputs":[{"execution_count":58,"output_type":"execute_result","data":{"text/plain":"org.apache.spark.sql.SparkSession@36de8728"},"metadata":{}}]},{"cell_type":"markdown","source":"## Data Preparation\nWe load the data form the Internet with the help of a url.","metadata":{}},{"cell_type":"code","source":"import java.net.URL\nimport scala.io.Source\nimport spark.implicits._\n\nvar url = \"https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv\"\nval streamString = Source.fromURL(new URL(url)).mkString\nval csvList = streamString.lines.toList\n\nval in = spark.read\n .option(\"header\", \"true\")\n .option(\"inferSchema\", \"true\")\n .csv(csvList.toDS())\n\nin.printSchema()\n","metadata":{"trusted":true},"execution_count":59,"outputs":[{"name":"stdout","output_type":"stream","text":"root\n |-- sepal.length: double (nullable = true)\n |-- sepal.width: double (nullable = true)\n |-- petal.length: double (nullable = true)\n |-- petal.width: double (nullable = true)\n |-- variety: string (nullable = true)\n\n"},{"execution_count":59,"output_type":"execute_result","data":{"text/plain":"org.apache.spark.sql.SparkSession$implicits$@723693c9"},"metadata":{}}]},{"cell_type":"markdown","source":"We rename the fiels because the dot is creating issues. 
We rename the fields because the dot in the column names causes issues. We could escape the field names with backticks, but working with proper field names is preferable:

```scala
val data = in
  .withColumn("sepalLength", in.col("`sepal.length`").cast("double"))
  .withColumn("sepalWidth", in.col("`sepal.width`").cast("double"))
  .withColumn("petalLength", in.col("`petal.length`").cast("double"))
  .withColumn("petalWidth", in.col("`petal.width`").cast("double"))
  .withColumn("label", in.col("variety"))
  .drop("sepal.length", "sepal.width", "petal.length", "petal.width", "variety")

data.printSchema()
data.show()
```

```
root
 |-- sepalLength: double (nullable = true)
 |-- sepalWidth: double (nullable = true)
 |-- petalLength: double (nullable = true)
 |-- petalWidth: double (nullable = true)
 |-- label: string (nullable = true)

+-----------+----------+-----------+----------+------+
|sepalLength|sepalWidth|petalLength|petalWidth| label|
+-----------+----------+-----------+----------+------+
|        5.1|       3.5|        1.4|       0.2|Setosa|
|        4.9|       3.0|        1.4|       0.2|Setosa|
|        4.7|       3.2|        1.3|       0.2|Setosa|
|        4.6|       3.1|        1.5|       0.2|Setosa|
|        5.0|       3.6|        1.4|       0.2|Setosa|
|        5.4|       3.9|        1.7|       0.4|Setosa|
|        4.6|       3.4|        1.4|       0.3|Setosa|
|        5.0|       3.4|        1.5|       0.2|Setosa|
|        4.4|       2.9|        1.4|       0.2|Setosa|
|        4.9|       3.1|        1.5|       0.1|Setosa|
|        5.4|       3.7|        1.5|       0.2|Setosa|
|        4.8|       3.4|        1.6|       0.2|Setosa|
|        4.8|       3.0|        1.4|       0.1|Setosa|
|        4.3|       3.0|        1.1|       0.1|Setosa|
|        5.8|       4.0|        1.2|       0.2|Setosa|
|        5.7|       4.4|        1.5|       0.4|Setosa|
|        5.4|       3.9|        1.3|       0.4|Setosa|
|        5.1|       3.5|        1.4|       0.3|Setosa|
|        5.7|       3.8|        1.7|       0.3|Setosa|
|        5.1|       3.8|        1.5|       0.3|Setosa|
+-----------+----------+-----------+----------+------+
only showing top 20 rows
```

Finally we split the data into a training and a test data set, which concludes our data preparation.

```scala
// Split the data into training and test sets (20% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.8, 0.2), seed = 1234L)
```
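Since `randomSplit` samples rows probabilistically, the 80/20 proportions are only approximate. A minimal sketch of a quick check that all three classes ended up in both splits:

```scala
// Inspect the class distribution per split; the per-class counts will
// vary slightly because randomSplit samples each row independently.
trainingData.groupBy("label").count().show()
testData.groupBy("label").count().show()
```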
### Train and Predict the Data

The features need to be vectorized; we can do this with the VectorAssembler. The labels are still strings, so we convert them to numeric values with the StringIndexer. We classify the data with a RandomForestClassifier and convert the predicted values back to strings with IndexToString.

All these steps are collected in a Pipeline, which we use to fit and to predict (transform):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}

// Build the feature vector
val vectorAssembler = new VectorAssembler()
  .setInputCols(Array("sepalLength", "sepalWidth", "petalLength", "petalWidth"))
  .setOutputCol("features")

// Index labels, adding metadata to the label column.
// Fit on the whole dataset to include all labels in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Train a RandomForest model.
val classifier = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain the indexers and the forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(vectorAssembler, labelIndexer, classifier, labelConverter))

// Train the model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

predictions.select("predictedLabel", "label", "features").show(10)
```

```
+--------------+----------+-----------------+
|predictedLabel|     label|         features|
+--------------+----------+-----------------+
|        Setosa|    Setosa|[4.3,3.0,1.1,0.1]|
|        Setosa|    Setosa|[4.4,2.9,1.4,0.2]|
|        Setosa|    Setosa|[4.4,3.0,1.3,0.2]|
|        Setosa|    Setosa|[4.8,3.1,1.6,0.2]|
|        Setosa|    Setosa|[5.0,3.3,1.4,0.2]|
|        Setosa|    Setosa|[5.0,3.4,1.5,0.2]|
|        Setosa|    Setosa|[5.0,3.6,1.4,0.2]|
|        Setosa|    Setosa|[5.1,3.4,1.5,0.2]|
|    Versicolor|Versicolor|[5.2,2.7,3.9,1.4]|
|        Setosa|    Setosa|[5.2,4.1,1.5,0.1]|
+--------------+----------+-----------------+
only showing top 10 rows
```
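One advantage of a random forest is that the fitted model exposes per-feature importances. A minimal sketch of pulling the trained forest out of the pipeline model and inspecting it (the stage index follows the pipeline order defined above):

```scala
import org.apache.spark.ml.classification.RandomForestClassificationModel

// Stage 2 of the fitted pipeline is the trained forest
// (vectorAssembler = 0, labelIndexer = 1, classifier = 2, labelConverter = 3).
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]

println(s"Trees in the forest: ${rfModel.trees.length}")
println(s"Feature importances: ${rfModel.featureImportances}")
```

The importance vector is ordered like the VectorAssembler input columns, so each entry can be read back against sepalLength, sepalWidth, petalLength, and petalWidth.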
## Evaluation

We need to evaluate the accuracy of our model on the numerical labels. Here is the current schema:

```scala
predictions.printSchema
```

```
root
 |-- sepalLength: double (nullable = true)
 |-- sepalWidth: double (nullable = true)
 |-- petalLength: double (nullable = true)
 |-- petalWidth: double (nullable = true)
 |-- label: string (nullable = true)
 |-- features: vector (nullable = true)
 |-- indexedLabel: double (nullable = false)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)
 |-- predictedLabel: string (nullable = true)
```

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Select (prediction, true label) and compute the test accuracy
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
```

```
0.9583333333333334
```

The model reaches an accuracy of roughly 96% on the held-out test data.
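Accuracy alone does not show which classes get confused with each other. A minimal sketch of a per-class breakdown, feeding the (prediction, indexedLabel) pairs into the RDD-based MulticlassMetrics (this reuses the `spark.implicits._` import from the data preparation step):

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// MulticlassMetrics expects an RDD of (prediction, label) pairs as doubles.
val predictionAndLabels = predictions
  .select("prediction", "indexedLabel")
  .as[(Double, Double)]
  .rdd

val metrics = new MulticlassMetrics(predictionAndLabels)

// Rows are actual classes, columns are predicted classes,
// both ordered by the sorted numeric label values.
println(metrics.confusionMatrix)
```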