Created
September 13, 2016 20:09
-
-
Save zoltanctoth/231c2af0e74b5794a131644ff43ba0d5 to your computer and use it in GitHub Desktop.
This is a Spark ↔ H2O (Sparkling Water) deep-learning prototype.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.{SparkConf, SparkContext} | |
import org.apache.spark.h2o.{H2OContext, H2OFrame} | |
import org.apache.spark.sql.DataFrame | |
import hex.deeplearning.DeepLearning | |
import water.app.SparkContextSupport | |
import hex.deeplearning.DeepLearningParameters | |
import hex.deeplearning.DeepLearningParameters.Activation | |
import org.apache.spark.h2o.{DoubleHolder, H2OContext, H2OFrame} | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.SQLContext | |
import org.apache.spark.{SparkConf, SparkContext, SparkFiles} | |
import water.app.SparkContextSupport | |
import water.fvec.Frame | |
/** Prototype of Spark <-> H2O (Sparkling Water) deep learning: train on one DataFrame,
  * score another, and return the scored rows as a Spark DataFrame.
  */
object DeepRM {

  /** Trains an H2O deep-learning model on `trainingDF` and scores `testDF` with it.
    *
    * The returned DataFrame is `testDF` with the model's scoring-output columns appended.
    * It is converted from an H2O frame, so it must be consumed while the H2O cluster is
    * still running. This method therefore does NOT stop H2O; the caller owns the
    * H2O/Spark lifecycle (see `main`).
    *
    * NOTE(review): trained models and intermediate H2O frames are left in H2O's store.
    * A long-running caller may want to remove them once the result is no longer needed.
    *
    * @param sparkContext context used to obtain (or create) the shared `H2OContext`
    * @param trainingDF   training data; must contain `labelColumn`
    * @param testDF       rows to score; its `SQLContext` is reused for the result
    * @param labelColumn  name of the response column
    * @param epochs       training epochs (previously hard-coded to 20)
    * @param hidden       hidden-layer sizes (previously hard-coded to 100, 100)
    * @param activation   activation function (previously hard-coded to RectifierWithDropout)
    * @return `testDF`'s columns plus the prediction columns
    * @throws IllegalArgumentException if `labelColumn` is not a column of `trainingDF`
    */
  def predict(
      sparkContext: SparkContext,
      trainingDF: DataFrame,
      testDF: DataFrame,
      labelColumn: String,
      epochs: Double = 20,
      hidden: Seq[Int] = Seq(100, 100),
      activation: Activation = Activation.RectifierWithDropout): DataFrame = {
    require(
      trainingDF.columns.contains(labelColumn),
      s"label column '$labelColumn' not found in training data")

    // Reuse the caller's SQLContext instead of constructing a new one, so the result
    // lives in the same SQL session (temp tables, settings) as the inputs.
    val sqlContext = testDF.sqlContext
    val hc = H2OContext.getOrCreate(sparkContext)
    val testFrame = hc.asH2OFrame(testDF)
    val trainingFrame = hc.asH2OFrame(trainingDF)

    val dlParams = new DeepLearningParameters()
    dlParams._train = trainingFrame.key
    dlParams._response_column = labelColumn
    dlParams._epochs = epochs
    dlParams._activation = activation
    dlParams._hidden = hidden.toArray

    // trainModel() starts the training job; get() blocks until the model is ready.
    val dlModel = new DeepLearning(dlParams).trainModel().get()

    // Score the test data and append the prediction columns to the test frame.
    val prediction: Frame = dlModel.score(testFrame)
    val withPrediction = testFrame.add(prediction)

    // Do not stop H2O here: the DataFrame below is read from H2O when an action runs on it,
    // and in Sparkling Water 1.x H2OContext.stop() may also terminate the JVM.
    hc.asDataFrame(hc.asH2OFrame(withPrediction))(sqlContext)
  }

  /** Entry point.
    *
    * Optional args: `[trainingParquet] [testParquet] [labelColumn]`. The defaults are
    * `data/training2.parquet`, `data/test2.parquet` and `is_late`. The master defaults to
    * `local` only when none was supplied (e.g. by spark-submit).
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setIfMissing("spark.master", "local")
      .setAppName("H2O Test")
    val sc = new SparkContext(conf)
    // Started here so main owns H2O's lifecycle; predict() picks up this same instance.
    val hc = H2OContext.getOrCreate(sc)

    val trainingPath = args.lift(0).getOrElse("data/training2.parquet")
    val testPath = args.lift(1).getOrElse("data/test2.parquet")
    val labelCol = args.lift(2).getOrElse("is_late")

    try {
      val sqlContext = new SQLContext(sc)
      val training = sqlContext.read.parquet(trainingPath)
      val test = sqlContext.read.parquet(testPath)
      val prDF = predict(sc, training, test, labelCol)
      // show() prints the rows itself and returns Unit; wrapping it in println printed "()".
      prDF.show()
    } finally {
      // Stop Spark and H2O only after the H2O-backed result has been consumed.
      hc.stop(stopSparkContext = true)
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment