Skip to content

Instantly share code, notes, and snippets.

@AtlasPilotPuppy
Last active August 29, 2015 14:03
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save AtlasPilotPuppy/dc22f128faff450555ef to your computer and use it in GitHub Desktop.
Save AtlasPilotPuppy/dc22f128faff450555ef to your computer and use it in GitHub Desktop.
Uses values in HBase tables to train and test an ALS model in MLlib.
import org.apache.spark.rdd.NewHadoopRDD
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.Rating
import scala.collection.mutable.ArrayBuffer
// Builds an HBase client Configuration that makes TableInputFormat read `tableName`.
//
// hbaseConfigFileName may be either a path on the local filesystem (as passed below,
// "/home/hadoop/hbase/conf/hbase-site.xml") or the name of a classpath resource.
// Configuration.addResource(String) only searches the classpath, so an absolute filesystem
// path handed to it is never read (and, in Hadoop's default quiet mode, skipped without any
// error). An existing local file is therefore added as a Hadoop Path, which Hadoop reads
// straight from the local filesystem; anything else keeps the classpath lookup.
// NOTE(review): a classpath name that also exists as a file in the working directory will
// now resolve to that local file.
val hbaseConfiguration = (hbaseConfigFileName: String, tableName: String) => {
  val conf = HBaseConfiguration.create()
  if (new java.io.File(hbaseConfigFileName).isFile)
    conf.addResource(new org.apache.hadoop.fs.Path(hbaseConfigFileName))
  else
    conf.addResource(hbaseConfigFileName)
  conf.set(TableInputFormat.INPUT_TABLE, tableName)
  conf
}
// Scan the "leads_test" table through TableInputFormat and, for every row, collect the cell
// lists that Result.getColumn returns for three qualifiers of the "data" family.
// Each element is a List of three cell lists, always in the order person_id, sold_at, offer_id.
val cols = {
  val hbaseRows = new NewHadoopRDD(
    sc,
    classOf[TableInputFormat],
    classOf[ImmutableBytesWritable],
    classOf[Result],
    hbaseConfiguration("/home/hadoop/hbase/conf/hbase-site.xml", "leads_test"))
  hbaseRows.map { case (_, result) =>
    List("person_id", "sold_at", "offer_id").map { qualifier =>
      result.getColumn("data".getBytes, qualifier.getBytes)
    }
  }
}
// Turn each row's three cell lists into three Strings: List(person_id, sold_at, offer_id).
//
// A row is kept only if every column has at least one cell. Testing each column (instead of
// requiring the summed cell counts to equal 3) guarantees reduceLeft below never sees an
// empty list, and still accepts columns that carry more than one version.
// For each column, the cell with the highest timestamp wins, and its value is decoded as
// UTF-8 (byte-by-byte `toChar` would corrupt any non-ASCII data).
val row_vals = cols
  .filter(columns => columns.forall(cells => !cells.isEmpty))
  .map(columns => columns.map { cells =>
    val newest = cells.asScala.reduceLeft { (a, b) =>
      if (a.getTimestamp > b.getTimestamp) a else b
    }
    new String(newest.getValue, "UTF-8")
  })
// Drop rows whose person_id or offer_id is the literal string "None", and replace sold_at
// with a binary signal: 0.0 when it is "None", 1.0 otherwise.
// Each result is a List(person_id: String, soldFlag: Double, offer_id: String).
val cleaned = row_vals
  .filter(fields => fields(0) != "None" && fields(2) != "None")
  .map { fields =>
    val soldFlag = if (fields(1) == "None") 0.0 else 1.0
    List(fields(0), soldFlag, fields(2))
  }
// Aggregate the sold flags per (person_id, offer_id) pair into 4 partitions, and use each
// pair's total as its ALS rating.
// IDs are parsed with toInt, so a non-numeric ID fails the job.
//
// reduceByKey is used directly: after a groupByKey every key is unique, so a following
// reduceByKey never calls its function, and the "sum" would just be one arbitrary element.
val summed = cleaned
  .map(row => ((row(0).toString.toInt, row(2).toString.toInt), row(1).toString.toDouble))
  .reduceByKey(_ + _, 4)
val ratings = summed.map { case ((user, product), total) => Rating(user, product, total) }
// Factorize the rating matrix with ALS, naming the positional hyper-parameters
// passed to ALS.train(ratings, rank, iterations, lambda).
val model = {
  val rank = 5
  val numIterations = 20
  val lambda = 0.01
  ALS.train(ratings, rank, numIterations, lambda)
}
// The (user, product) pairs of the training ratings, used as the query set for prediction.
val usersProducts = ratings.map(r => (r.user, r.product))
// Model predictions, keyed by (user, product) so they can be joined with the observed ratings.
val predictions = model
  .predict(usersProducts)
  .map(p => ((p.user, p.product), p.rating))
// Pair every observed rating with the model's prediction for the same (user, product):
// ((user, product), (observed, predicted)).
val ratesAndPreds = ratings
  .map(r => ((r.user, r.product), r.rating))
  .join(predictions)
// Mean squared difference between observed and predicted ratings.
// NOTE: these are the same ratings ALS.train was fitted on, so this is training error,
// not held-out test error.
val MSE = ratesAndPreds.values.map { case (observed, predicted) =>
  val diff = observed - predicted
  diff * diff
}.mean()
println(s"Mean Squared Error = $MSE")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment