
@pschatzmann
Created November 2, 2018 12:29
# RandomForestClassifier with MLlib

In this document we show how to use Spark's MLlib machine learning functionality with the RandomForest classifier on the Iris data. Most of the examples that I have found load the data from a file. In my example I load it from the Internet, so the only thing you need to execute the example is a working network connection.

## Setup

```
%%classpath add mvn
org.apache.spark:spark-sql_2.11:2.3.2
org.apache.spark:spark-mllib_2.11:2.3.2
```

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
    .appName("Iris RandomForest")
    .master("local")
    .config("spark.ui.enabled", "false")
    .getOrCreate()
```

## Data Preparation

We load the data from the Internet with the help of a URL.

```scala
import java.net.URL
import scala.io.Source
import spark.implicits._

val url = "https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv"
val streamString = Source.fromURL(new URL(url)).mkString
val csvList = streamString.lines.toList

val in = spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv(csvList.toDS())

in.printSchema()
```

```
root
 |-- sepal.length: double (nullable = true)
 |-- sepal.width: double (nullable = true)
 |-- petal.length: double (nullable = true)
 |-- petal.width: double (nullable = true)
 |-- variety: string (nullable = true)
```

We rename the fields because the dot in the column names creates issues. We could escape the field names with backticks, but working with proper field names is preferable:

```scala
val data = in
    .withColumn("sepalLength", in.col("`sepal.length`").cast("double"))
    .withColumn("sepalWidth", in.col("`sepal.width`").cast("double"))
    .withColumn("petalLength", in.col("`petal.length`").cast("double"))
    .withColumn("petalWidth", in.col("`petal.width`").cast("double"))
    .withColumn("label", in.col("variety"))
    .drop("sepal.length", "sepal.width", "petal.length", "petal.width", "variety")

data.printSchema()
data.show()
```

```
root
 |-- sepalLength: double (nullable = true)
 |-- sepalWidth: double (nullable = true)
 |-- petalLength: double (nullable = true)
 |-- petalWidth: double (nullable = true)
 |-- label: string (nullable = true)

+-----------+----------+-----------+----------+------+
|sepalLength|sepalWidth|petalLength|petalWidth| label|
+-----------+----------+-----------+----------+------+
|        5.1|       3.5|        1.4|       0.2|Setosa|
|        4.9|       3.0|        1.4|       0.2|Setosa|
|        4.7|       3.2|        1.3|       0.2|Setosa|
|        4.6|       3.1|        1.5|       0.2|Setosa|
|        5.0|       3.6|        1.4|       0.2|Setosa|
|        5.4|       3.9|        1.7|       0.4|Setosa|
|        4.6|       3.4|        1.4|       0.3|Setosa|
|        5.0|       3.4|        1.5|       0.2|Setosa|
|        4.4|       2.9|        1.4|       0.2|Setosa|
|        4.9|       3.1|        1.5|       0.1|Setosa|
|        5.4|       3.7|        1.5|       0.2|Setosa|
|        4.8|       3.4|        1.6|       0.2|Setosa|
|        4.8|       3.0|        1.4|       0.1|Setosa|
|        4.3|       3.0|        1.1|       0.1|Setosa|
|        5.8|       4.0|        1.2|       0.2|Setosa|
|        5.7|       4.4|        1.5|       0.4|Setosa|
|        5.4|       3.9|        1.3|       0.4|Setosa|
|        5.1|       3.5|        1.4|       0.3|Setosa|
|        5.7|       3.8|        1.7|       0.3|Setosa|
|        5.1|       3.8|        1.5|       0.3|Setosa|
+-----------+----------+-----------+----------+------+
only showing top 20 rows
```

Finally, we split the data into a training and a test dataset, and this concludes our data preparation.

```scala
// Split the data into training and test sets (20% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.8, 0.2), seed = 1234L)
```

### Train and Predict the Data

The features need to be vectorized; we can do this with the VectorAssembler. The labels are still Strings, and we need to convert them to numeric values; we do this with the StringIndexer. We classify the data with a RandomForestClassifier and convert the predicted values back to strings with IndexToString.

All these steps are collected in a Pipeline which we use to fit and to predict (transform):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.VectorAssembler

// Build the feature vector
val vectorAssembler = new VectorAssembler()
    .setInputCols(Array("sepalLength", "sepalWidth", "petalLength", "petalWidth"))
    .setOutputCol("features")

// Index labels, adding metadata to the label column.
// Fit on the whole dataset to include all labels in the index.
val labelIndexer = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("indexedLabel")
    .fit(data)

// Train a RandomForest model.
val classifier = new RandomForestClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("features")

// Convert indexed labels back to the original labels.
val labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("predictedLabel")
    .setLabels(labelIndexer.labels)

// Chain the indexers and the forest in a Pipeline.
val pipeline = new Pipeline()
    .setStages(Array(vectorAssembler, labelIndexer, classifier, labelConverter))

// Train the model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

predictions.select("predictedLabel", "label", "features").show(10)
```

```
+--------------+----------+-----------------+
|predictedLabel|     label|         features|
+--------------+----------+-----------------+
|        Setosa|    Setosa|[4.3,3.0,1.1,0.1]|
|        Setosa|    Setosa|[4.4,2.9,1.4,0.2]|
|        Setosa|    Setosa|[4.4,3.0,1.3,0.2]|
|        Setosa|    Setosa|[4.8,3.1,1.6,0.2]|
|        Setosa|    Setosa|[5.0,3.3,1.4,0.2]|
|        Setosa|    Setosa|[5.0,3.4,1.5,0.2]|
|        Setosa|    Setosa|[5.0,3.6,1.4,0.2]|
|        Setosa|    Setosa|[5.1,3.4,1.5,0.2]|
|    Versicolor|Versicolor|[5.2,2.7,3.9,1.4]|
|        Setosa|    Setosa|[5.2,4.1,1.5,0.1]|
+--------------+----------+-----------------+
only showing top 10 rows
```

## Evaluation

We need to evaluate the accuracy of our model on the numerical labels. Here is the current schema:

```scala
predictions.printSchema
```

```
root
 |-- sepalLength: double (nullable = true)
 |-- sepalWidth: double (nullable = true)
 |-- petalLength: double (nullable = true)
 |-- petalWidth: double (nullable = true)
 |-- label: string (nullable = true)
 |-- features: vector (nullable = true)
 |-- indexedLabel: double (nullable = false)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)
 |-- predictedLabel: string (nullable = true)
```

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Select (prediction, true label) and compute the test accuracy
val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol("indexedLabel")
    .setPredictionCol("prediction")
    .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
```

```
0.9583333333333334
```
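The accuracy above comes from a single 80/20 split with the classifier's default settings. As a possible next step, the forest's hyperparameters could be tuned with cross-validation over the same Pipeline and evaluator. This is a sketch, not part of the original notebook; the `numTrees` and `maxDepth` grid values are arbitrary illustrative choices:

```scala
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Grid over a few RandomForest hyperparameters (values chosen for illustration)
val paramGrid = new ParamGridBuilder()
    .addGrid(classifier.numTrees, Array(10, 20, 50))
    .addGrid(classifier.maxDepth, Array(3, 5, 7))
    .build()

// 3-fold cross-validation over the whole pipeline, scored with the
// accuracy evaluator defined above
val cv = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(3)

// Fit on the training data only; the held-out test set stays untouched
val cvModel = cv.fit(trainingData)
val cvAccuracy = evaluator.evaluate(cvModel.transform(testData))
```

On a dataset as small and well-separated as Iris the tuned accuracy may not differ much from the default, but the same pattern scales to harder problems.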
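Since the fitted Pipeline bundles the assembler, the indexer, the forest, and the label converter, it can also be persisted as a unit and reloaded later without retraining. A minimal sketch (the `/tmp/iris-rf-model` path is an arbitrary choice for this example):

```scala
import org.apache.spark.ml.PipelineModel

// Persist the fitted pipeline, overwriting any previous save at that path
model.write.overwrite().save("/tmp/iris-rf-model")

// Reload it; the loaded model transforms data exactly like the original
val loaded = PipelineModel.load("/tmp/iris-rf-model")
loaded.transform(testData).select("predictedLabel", "label").show(5)
```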