Skip to content

Instantly share code, notes, and snippets.

@xgdgsc
Forked from AtlasPilotPuppy/SparkHbaseALS.scala
Last active August 29, 2015 14:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xgdgsc/54c9bc225fde77c1de51 to your computer and use it in GitHub Desktop.
import org.apache.spark.rdd.NewHadoopRDD
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.Rating
import scala.collection.mutable.ArrayBuffer
// Builds an HBase Configuration for scanning `tableName` through TableInputFormat,
// layering the settings from the given hbase-site.xml file over the defaults.
val hbaseConfiguration = (hbaseConfigFileName: String, tableName: String) => {
  val conf = HBaseConfiguration.create()   // renamed: no longer shadows the enclosing val
  conf.addResource(hbaseConfigFileName)
  conf.set(TableInputFormat.INPUT_TABLE, tableName)
  conf
}
// Scan the HBase table "leads_test" (configured from the local hbase-site.xml) as an
// RDD of (ImmutableBytesWritable, Result), drop the row keys, and project each Result
// to a 3-element list of cell lists for the "data"-family columns person_id, sold_at
// and offer_id (in that order — downstream code indexes by position).
// NOTE(review): constructing NewHadoopRDD directly (instead of sc.newAPIHadoopRDD) and
// Result.getColumn are old Spark/HBase APIs — confirm they exist in the target versions.
val cols = new NewHadoopRDD(sc, classOf[TableInputFormat],classOf[ImmutableBytesWritable],classOf[Result],hbaseConfiguration("/home/hadoop/hbase/conf/hbase-site.xml", "leads_test")).map(tuple => tuple._2).map(
result => result.getColumn("data".getBytes, "person_id".getBytes) :: result.getColumn("data".getBytes, "sold_at".getBytes):: result.getColumn("data".getBytes, "offer_id".getBytes):: Nil)
// Keep rows whose three cell lists hold three cell versions in total, then for each
// column pick the cell with the newest timestamp and decode its value bytes into a
// String char-by-char (only safe for single-byte-per-char encodings — TODO confirm).
// NOTE(review): the filter only checks that the lengths SUM to 3; a 3/0/0 split would
// pass and then crash reduceLeft on an empty list — presumably exactly one version per
// column is expected here. Verify against the table's max-versions setting.
val row_vals = cols.filter(item => item.map(i=> i.length).reduceLeft((a,b)=>a+b) == 3).map(row => row.map(ele => new String(ele.asScala.reduceLeft{
(a,b) => if (a.getTimestamp > b.getTimestamp) a else b}.getValue.map(_.toChar))))
val cleaned = row_vals.filter(row => row(0) != "None" && row(2) != "None").map(row => row(0) :: (if (row(1) =="None") 0.0 else 1.0) :: row(2) :: Nil)
val summed = cleaned.map(row => ((row(0).toString.toInt,row(2).toString.toInt), row(1).toString.toDouble)).groupByKey(4).reduceByKey((a,b) => ArrayBuffer( a.reduce(_+_) + b.reduce(_+_) ))
val ratings = summed.map(row => Rating(row._1._1, row._1._2, row._2.head))
val model = ALS.train(ratings, 5, 20, 0.01)
// Strip the rating value, keeping just the (user, product) pairs to predict on.
val usersProducts = ratings.map(r => (r.user, r.product))
// Score every observed (user, product) pair with the trained model,
// keyed by the pair so it can be joined against the actual ratings.
val predictions =
  model.predict(usersProducts).map(p => ((p.user, p.product), p.rating))
// Pair each actual rating with its prediction: ((user, product), (actual, predicted)).
val ratesAndPreds = ratings
  .map(r => ((r.user, r.product), r.rating))
  .join(predictions)
// Mean squared error between observed and predicted ratings over all joined pairs.
val MSE = ratesAndPreds.map { case (_, (actual, predicted)) =>
  val diff = actual - predicted
  diff * diff
}.mean()
println("Mean Squared Error = " + MSE)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment