Skip to content

Instantly share code, notes, and snippets.

@erikerlandson
Created June 26, 2014 20:42
Show Gist options
  • Save erikerlandson/2f2b3dfe03f3266b577c to your computer and use it in GitHub Desktop.
Save erikerlandson/2f2b3dfe03f3266b577c to your computer and use it in GitHub Desktop.
Demonstrates a function that abstracts cross-validation for an MLlib model, in this case org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import java.lang.Math
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.recommendation.Rating
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.util.MLUtils.kFold
// Convenience preload of rating data from a tab-separated text file.
val txt = sc.textFile("/home/eje/git/ratorade/data/bgr.dat")
val ratings = txt.map { line =>
  // The pattern expects exactly five fields after splitting on tabs (a row with any
  // other field count raises a MatchError); only the first three fields are used.
  val Array(user, item, rating, _, _) = line.split('\t')
  // The raw rating value is divided by 100 before being stored.
  Rating(user.toInt, item.toInt, rating.toDouble / 100.0)
}
/**
 * Cross-validates an ALS collaborative filter: splits `data` into folds, trains one
 * model per fold on that fold's training split, and scores it on the held-out split.
 *
 * @param data   ratings to be partitioned into (train, test) folds
 * @param nFolds number of folds handed to `kFold`; must be at least 2
 * @param seed   seed handed to `kFold`
 * @param rank   second argument to `ALS.train`
 * @param iter   third argument to `ALS.train`
 * @param alpha  fourth argument to `ALS.train`. NOTE(review): in MLlib's four-argument
 *               `train(ratings, rank, iterations, lambda)` overload this position is the
 *               regularization parameter lambda, despite the name used here -- confirm
 *               against the Spark version in use.
 * @return a tuple of (mean of the per-fold RMSE values, RMSE per fold, model per fold).
 *         The first element is the arithmetic mean of the fold RMSEs, not an RMSE
 *         computed over all test ratings pooled together.
 * @throws IllegalArgumentException if `nFolds` is less than 2
 */
def xvalALS(data: RDD[Rating], nFolds: Int = 10, seed: Int = 77,
    rank: Int = 10, iter: Int = 20, alpha: Double = 0.01): (Double, Array[Double], Array[MatrixFactorizationModel]) = {
  // Cross validation needs both a training and a held-out split; without this guard
  // nFolds = 0 would silently yield NaN from the 0.0 / 0 division at the end.
  require(nFolds >= 2, s"nFolds must be at least 2, got $nFolds")

  // (train, test) pairs
  val folds = kFold(data, nFolds, seed)

  // (model, test) pairs: one model trained per fold
  val models = folds.map { case (train, test) => (ALS.train(train, rank, iter, alpha), test) }

  // RMSE of each model on its own held-out fold
  val rmse = models.map { case (model, test) =>
    // truth, keyed as ((user, product), rating)
    val truth = test.map { case Rating(usr, prod, rate) => ((usr, prod), rate) }
    // predictions for the same (user, product) pairs, keyed the same way
    val pred = model
      .predict(test.map { case Rating(usr, prod, _) => (usr, prod) })
      .map { case Rating(usr, prod, rate) => ((usr, prod), rate) }
    // Only keys present on both sides of the join contribute to the error.
    val mse = truth.join(pred).map { case (_, (t, p)) => math.pow(p - t, 2) }.mean()
    math.sqrt(mse)
  }

  // (mean of per-fold RMSE, RMSE per fold, model per fold)
  (rmse.sum / rmse.length, rmse, models.map(_._1))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment