@zoltanctoth
Created September 13, 2016 20:09
A Spark <-> H2O (Sparkling Water) deep learning prototype: it trains an H2O deep learning model on a Spark DataFrame and scores a test DataFrame with it.
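The gist ships no build configuration; below is a minimal build.sbt sketch, assuming Scala 2.10 and the Spark 1.5.2 / Sparkling Water 1.5.x line implied by the commented-out assembly paths in main(). The version numbers are placeholders, not taken from the gist, and should be matched to your own Spark and Sparkling Water installation.

// build.sbt (sketch; versions are assumptions)
scalaVersion := "2.10.6"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core"           % "1.5.2" % "provided",
  "org.apache.spark" %% "spark-sql"            % "1.5.2" % "provided",
  "ai.h2o"           %% "sparkling-water-core" % "1.5.10"
)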
import hex.deeplearning.DeepLearning
import hex.deeplearning.DeepLearningParameters
import hex.deeplearning.DeepLearningParameters.Activation
import org.apache.spark.h2o.{H2OContext, H2OFrame}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
import water.fvec.Frame
object DeepRM {

  /** Trains an H2O deep learning model on `trainingDF` and returns `testDF`
    * with the model's prediction columns appended. */
  def predict(sparkContext: SparkContext, trainingDF: DataFrame, testDF: DataFrame, labelColumn: String): DataFrame = {
    val sqlContext = new SQLContext(sparkContext)
    val hc = H2OContext.getOrCreate(sparkContext)

    // Publish the Spark DataFrames as H2O frames
    val testFrame = hc.asH2OFrame(testDF)
    val trainingFrame = hc.asH2OFrame(trainingDF)

    // Configure the deep learning model
    val dlParams = new DeepLearningParameters()
    dlParams._train = trainingFrame.key
    dlParams._response_column = labelColumn
    dlParams._epochs = 20
    dlParams._activation = Activation.RectifierWithDropout
    dlParams._hidden = Array[Int](100, 100)

    // Train the model
    val dl = new DeepLearning(dlParams)
    val dlModel = dl.trainModel.get

    // Predict and append the prediction columns to the test frame
    val prediction: Frame = dlModel.score(testFrame)
    val ret = testFrame.add(prediction)

    // Convert the result back into a Spark DataFrame and shut down H2O
    val predictedDF = hc.asDataFrame(hc.asH2OFrame(ret))(sqlContext)
    hc.stop()
    predictedDF
  }
  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("local").setAppName("H2O Test")
    val sc = new SparkContext(conf)
    //sc.addJar("/Users/zoltanctoth/src/sparkling-water/assembly/build/libs/sparkling-water-assembly-1.5.99999-SNAPSHOT-all.jar")
    //sc.addJar("/opt/spark/lib/spark-assembly-1.5.2-hadoop2.6.0.jar")

    val sqlContext = new org.apache.spark.sql.SQLContext(sc)

    // Load the training and test sets and score the test set
    val training = sqlContext.read.parquet("data/training2.parquet")
    val test = sqlContext.read.parquet("data/test2.parquet")
    val labelCol = "is_late"
    val prDF = predict(sc, training, test, labelCol)

    // show() prints the DataFrame itself and returns Unit, so it is not wrapped in println
    prDF.show()

    sc.stop()
  }
}
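To try the predict helper without the data/*.parquet files, something along these lines could be pasted into a spark-shell that has the Sparkling Water assembly on its classpath (where sc is the predefined SparkContext). This is only a sketch: the Flight case class, its columns, and the toy rows are hypothetical stand-ins, only the "is_late" label column comes from the code above, and two rows are far too few for a meaningful model; the point is just to illustrate the call shape.

// Hypothetical schema; only "is_late" matches the gist above
case class Flight(distance: Double, dep_delay: Double, is_late: Int)

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val training = sqlContext.createDataFrame(Seq(
  Flight(300.0,  5.0, 0),
  Flight(1200.0, 45.0, 1)
))
val test = sqlContext.createDataFrame(Seq(Flight(800.0, 20.0, 0)))

// Appends H2O's prediction columns to the test rows
DeepRM.predict(sc, training, test, "is_late").show()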