@vlad17, last active August 9, 2016
[SPARK-16718] benchmark for million song dataset
// See gbm.R for context
// run with options:
// spark-shell --driver-memory 20G --executor-memory 4G --driver-java-options="-Xss500M" -i gbt.spark
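// This script benchmarks GBTClassifier fit time and test-set metrics under the
// stock "variance" impurity versus the "loss-based" impurity proposed in
// SPARK-16718, on UCI's YearPredictionMSD data binarized at year < 2002.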
import org.apache.spark.sql.DataFrame
import sys.process._
import java.io._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation._
// Download
val csvLoc = "/tmp/YearPredictionMSD.txt"
val fileLoc = "http://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip"
if (!new File(csvLoc).exists) {
  // .! runs the command via sys.process and avoids the postfix-operator feature warning
  val _1 = s"wget -q -O $csvLoc.zip $fileLoc".!
  val _2 = s"unzip -o -qq $csvLoc.zip -d /tmp".!
}
// Extract the same way the R example does, and convert to the binary year-guess problem
val raw = spark.read.option("inferSchema", true).option("header", false).csv(csvLoc)
val firstCol = raw.columns.head
val binary = raw.selectExpr("*", s"cast($firstCol < 2002 as double) as label").drop(firstCol)
// Vectorize: pack the 90 feature columns into a single "features" vector column
val explanatory = binary.columns.filterNot(_ == "label")
val clean = new VectorAssembler().setInputCols(explanatory).setOutputCol("features").transform(binary)
// Get training/test: the dataset prescribes the first 463,715 rows as train and the
// rest as test, so index rows with zipWithIndex, which preserves file order
val indexed = clean.rdd.zipWithIndex
val cutoff = 463715
def filterIndex(cond: Long => Boolean): DataFrame = {
  val rdd = indexed.filter(x => cond(x._2)).map(_._1)
  spark.createDataFrame(rdd, clean.schema)
}
val (train, test) = (filterIndex(_ < cutoff), filterIndex(_ >= cutoff))
// Estimators
val ntrees = 700
val shrinkage = 0.001
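// Named after gbm's n.trees/shrinkage parameters, presumably matching the gbm.R run referenced above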
val varianceBased = new GBTClassifier().setMaxIter(ntrees).setSubsamplingRate(0.75).
  setMaxDepth(3).setLossType("bernoulli").setMinInstancesPerNode(10).
  setStepSize(shrinkage).setImpurity("variance").setLabelCol("label").setSeed(123)
// Setters mutate the estimator in place; clone it before switching impurity,
// or both fits end up using whichever impurity was set last (note the transcript
// below shows the original run shared one uid, gbtc_cef98ba41324, across both)
val lossBased = varianceBased.copy(ParamMap.empty).setImpurity("loss-based")
train.cache()
val start = System.nanoTime
val varianceModel = varianceBased.fit(train)
val varianceTime = System.nanoTime - start
val start = System.nanoTime
val lossModel = lossBased.fit(train)
val lossTime = System.nanoTime - start
train.unpersist(true)
test.cache()
val predVariance = varianceModel.transform(test)
val predLoss = lossModel.transform(test)
def evaluate(df: DataFrame): Unit = {
  val eval = new MulticlassClassificationEvaluator()
  for (metric <- Seq("f1", "weightedPrecision", "weightedRecall", "accuracy")) {
    // setMetricName is required; without it every iteration reports the default metric (f1),
    // which is why all four numbers in the recorded run below are identical
    eval.setMetricName(metric)
    println(s" $metric = ${eval.evaluate(df)}")
  }
}
println(s"variance impurity perf (midpoint thresh) seconds ${(varianceTime / 1e9).toLong}")
evaluate(predVariance)
println(s"loss-based impurity perf (midpoint thresh) seconds ${(lossTime / 1e9).toLong}")
evaluate(predLoss)
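// Majority-class baseline: the accuracy of always predicting the most common label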
val counts = test.groupBy("label").count().select("count").as[Double].collect()
counts.max / counts.sum
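// Optional follow-up, a sketch that was not part of the recorded run below: print a
// confusion matrix per model with a plain groupBy, to see whether the two impurity
// settings disagree beyond the aggregate metrics. `confusion` is a hypothetical helper.
def confusion(df: DataFrame): Unit = {
  // label and prediction are both 0.0/1.0 doubles, so this yields at most four rows
  df.groupBy("label", "prediction").count().orderBy("label", "prediction").show()
}
confusion(predVariance)
confusion(predLoss)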
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel).
16/08/09 14:05:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
16/08/09 14:05:56 WARN Utils: Your hostname, vlad-databricks resolves to a loopback address: 127.0.1.1; using 192.168.1.23 instead (on interface enp0s31f6)
16/08/09 14:05:56 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
16/08/09 14:05:57 WARN SparkContext: Use an existing SparkContext, some configuration may not take effect.
Spark context Web UI available at http://192.168.1.23:4040
Spark context available as 'sc' (master = local[*], app id = local-1470776757418).
Spark session available as 'spark'.
Loading /home/vlad/Desktop/gbt.spark...
import org.apache.spark.sql.DataFrame
import sys.process._
import java.io._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation._
csvLoc: String = /tmp/YearPredictionMSD.txt
fileLoc: String = http://archive.ics.uci.edu/ml/machine-learning-databases/00203/YearPredictionMSD.txt.zip
warning: there were two feature warnings; re-run with -feature for details
firstCol: String = _c0
binary: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 89 more fields]
explanatory: Array[String] = Array(_c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _c10, _c11, _c12, _c13, _c14, _c15, _c16, _c17, _c18, _c19, _c20, _c21, _c22, _c23, _c24, _c25, _c26, _c27, _c28, _c29, _c30, _c31, _c32, _c33, _c34, _c35, _c36, _c37, _c38, _c39, _c40, _c41, _c42, _c43, _c44, _c45, _c46, _c47, _c48, _c49, _c50, _c51, _c52, _c53, _c54, _c55, _c56, _c57, _c58, _c59, _c60, _c61, _c62, _c63, _c64, _c65, _c66, _c67, _c68, _c69, _c70, _c71, _c72, _c73, _c74, _c75, _c76, _c77, _c78, _c79, _c80, _c81, _c82, _c83, _c84, _c85, _c86, _c87, _c88, _c89, _c90)
clean: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 90 more fields]
cutoff: Int = 463715
filterIndex: (cond: Long => Boolean)org.apache.spark.sql.DataFrame
train: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 90 more fields]
test: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 90 more fields]
ntrees: Int = 700
shrinkage: Double = 0.001
varianceBased: org.apache.spark.ml.classification.GBTClassifier = gbtc_cef98ba41324
lossBased: varianceBased.type = gbtc_cef98ba41324
16/08/09 14:06:12 WARN Utils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.
res1: train.type = [_c1: double, _c2: double ... 90 more fields]
start: Long = 67129448801407
16/08/09 14:06:16 WARN Executor: 1 block locks were not released by TID = 22:
16/08/09 14:06:16 WARN Executor: 1 block locks were not released by TID = 23:
varianceModel: org.apache.spark.ml.classification.GBTClassificationModel = GBTClassificationModel (uid=gbtc_cef98ba41324) with 700 trees
varianceTime: Long = 1366514285281
start: Long = 68496127610682
16/08/09 14:28:59 WARN Executor: 1 block locks were not released by TID = 51123:
16/08/09 14:28:59 WARN Executor: 1 block locks were not released by TID = 51124:
lossTime: Long = 1390297870579
res2: train.type = [_c1: double, _c2: double ... 90 more fields]
res3: test.type = [_c1: double, _c2: double ... 90 more fields]
predVariance: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 91 more fields]
predLoss: org.apache.spark.sql.DataFrame = [_c1: double, _c2: double ... 91 more fields]
evaluate: (df: org.apache.spark.sql.DataFrame)Unit
variance impurity perf (midpoint thresh) seconds 1366
f1 = 0.6547312025559994
weightedPrecision = 0.6547312025559994
weightedRecall = 0.6547312025559994
accuracy = 0.6547312025559994
loss-based impurity perf (midpoint thresh) seconds 1390
f1 = 0.6547312025559994
weightedPrecision = 0.6547312025559994
weightedRecall = 0.6547312025559994
accuracy = 0.6547312025559994
counts: Array[Double] = Array(26812.0, 24818.0)
res8: Double = 0.5193104784040287
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/ '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.1.0-SNAPSHOT
      /_/
Using Scala version 2.11.8 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_91)
Type in expressions to have them evaluated.
Type :help for more information.
scala> :quit