Created
June 26, 2014 20:42
-
-
Save erikerlandson/2f2b3dfe03f3266b577c to your computer and use it in GitHub Desktop.
Demonstrates a function that abstracts cross-validation for a Spark MLlib model — in this case org.apache.spark.mllib.recommendation.MatrixFactorizationModel.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.Math | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.mllib.recommendation.Rating | |
import org.apache.spark.mllib.recommendation.ALS | |
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel | |
import org.apache.spark.mllib.util.MLUtils.kFold | |
// Preload some Rating data for my own convenience.
// NOTE(review): `sc` is assumed to be the SparkContext provided by spark-shell.
// Each input line is expected to hold exactly five tab-separated fields. Only the first
// three (user id, item id, raw rating) are used. Lines with any other field count fail
// with a scala.MatchError when `ratings` is evaluated.
// NOTE(review): the raw rating is divided by 100.0 — presumably the file stores ratings
// on a 0-100 scale; confirm against the data source.
val txt = sc.textFile("/home/eje/git/ratorade/data/bgr.dat")
val ratings = txt.map { line =>
  line.split('\t') match {
    case Array(user, item, rating, _, _) => Rating(user.toInt, item.toInt, rating.toDouble / 100.0)
  }
}
// Cross-validated training of an ALS collaborative filter.
//
// Splits `data` into `nFolds` (train, test) pairs with MLUtils.kFold. It trains one
// MatrixFactorizationModel per training split and scores each model by RMSE on its
// held-out split.
//
// Parameters:
//   data   - ratings to cross-validate over
//   nFolds - number of folds, passed to kFold
//   seed   - random seed passed to kFold, so the fold assignment is reproducible
//   rank   - number of latent features, passed to ALS.train
//   iter   - number of ALS iterations, passed to ALS.train
//   alpha  - passed as the 4th argument of ALS.train(ratings, rank, iterations, lambda),
//            i.e. the regularization parameter lambda. It is NOT the implicit-feedback
//            "alpha"; the name is kept only for backward compatibility with callers.
//
// Returns (mean of the per-fold RMSEs, RMSE per fold, model per fold).
//
// Test pairs that get no prediction (e.g. user or product absent from that fold's
// training split) are dropped by the join and do not contribute to the RMSE. If no
// pair of a fold can be scored, that fold's RMSE is NaN (and so is the mean).
def xvalALS(data: RDD[Rating], nFolds: Int = 10, seed: Int = 77,
    rank: Int = 10, iter: Int = 20, alpha: Double = 0.01): (Double, Array[Double], Array[MatrixFactorizationModel]) = {
  // (train, test) pairs
  val folds = kFold(data, nFolds, seed)

  // (model, test) pairs; `alpha` is the ALS regularization lambda (see above)
  val models = folds.map { case (train, test) => (ALS.train(train, rank, iter, alpha), test) }

  // model predictions on each test fold: RDD[((user, product), predicted rating)]
  val pred = models.map { case (model, test) =>
    model.predict(test.map { case Rating(usr, prod, _) => (usr, prod) })
      .map { case Rating(usr, prod, rate) => ((usr, prod), rate) }
  }

  // ground truth of each test fold: RDD[((user, product), actual rating)]
  val truth = models.map { case (_, test) =>
    test.map { case Rating(usr, prod, rate) => ((usr, prod), rate) }
  }

  // RMSE for each test fold
  val rmse = truth.zip(pred).map { case (tr, pd) =>
    val sqErr = tr.join(pd).map { case (_, (t, p)) => math.pow(p - t, 2) }.stats()
    // StatCounter's mean is 0.0 for an empty RDD, which would report a perfect fold.
    // Use NaN when nothing could be scored so the failure stays visible.
    if (sqErr.count == 0) Double.NaN else math.sqrt(sqErr.mean)
  }

  // (mean RMSE over folds, RMSE per fold, model per fold)
  (rmse.sum / rmse.length, rmse, models.map(_._1))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment