spark-shell --jars ~/dev/hadoop-aws-2.7.0.jar --conf 'spark.driver.extraJavaOptions=-Ddata-warehouse-url-read=s3n://grovo-data-warehouse-dev'
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.util.{Date, TimeZone, UUID}

import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.functions._

import com.grovo.data.common.dao.CsvReportDao
import com.grovo.data.common.domain.{Predicates, TimePeriod}
import com.grovo.data.common.domain.Predicates.EnterpriseId
import com.grovo.data.common.domain.TimePeriod.Days30
import com.grovo.data.common.etl.SnapshotBuilder
import com.grovo.data.common.util.Config.SNAPSHOT_TASKS_PARALLELISM
import com.grovo.data.common.util.DataFields._
import com.grovo.data.common.util.DateConversions.processingHours
import com.grovo.data.common.util.TaskSupport.parallelTasks
import com.grovo.data.cubes.enterprise.dao.EnterpriseAggregateDao
import com.grovo.data.cubes.lessonsession.dao.LessonSessionDetailDao
import com.grovo.data.snapshots.content.dao.LessonDataSnapshotDao
import com.grovo.data.snapshots.enterprise.dao.EnterpriseDataSnapshotsDao
import com.grovo.data.snapshots.identity.dao.UserDataSnapshotDao
import com.grovo.data.snapshots.identity.domain.UserDataSnapshot
case class RecommendedLesson(lessonId: String, rating: Double)
case class UserRecommendations(userId: String, lessons: Array[Rating])
val spark: SparkSession = SparkSession
  .builder()
  .master("local")
  .appName("Testing")
  .getOrCreate()
import spark.implicits._

implicit val saveMode = SaveMode.ErrorIfExists

// Work in GMT so the snapshot date parses consistently.
TimeZone.setDefault(TimeZone.getTimeZone("GMT"))
val fmt = new SimpleDateFormat("MM/dd/yyyy HH:mm")
fmt.setTimeZone(TimeZone.getDefault)
val date = fmt.parse("05/16/2017 04:00")
val lessonSessionDetailDao = new LessonSessionDetailDao
val lessonDataSnapshotDao = new LessonDataSnapshotDao
val period = TimePeriod.Days30
val galaxyEnterpriseId = "e5463edc-7306-4062-a3da-d9f19b04047f"

// Keep only self-directed sessions (assigned sessions carry a userAssignmentId),
// plus the distinct lessons belonging to the Galaxy enterprise.
val lessonSessionDetails = lessonSessionDetailDao
  .readFromDataWarehouse(spark, date, period, None)
  .where($"userAssignmentId".isNull)
val grovoLessonData = lessonDataSnapshotDao
  .readFromDataWarehouse(spark, date, None)
  .where($"enterpriseId".equalTo(galaxyEnterpriseId))
  .select($"lessonId")
  .distinct
// Index lessonIds to the integer ids that mllib's ALS requires.
val lessonIndexer = new StringIndexer()
  .setInputCol("lessonId")
  .setOutputCol("lessonNumber")
val liModel = lessonIndexer.fit(grovoLessonData)
val indexedLessons = liModel.transform(grovoLessonData)
// Score each (user, lesson) pair from implicit feedback: every view counts once
// and every completion adds two more.
val grovoLessonSessionDetails = indexedLessons.join(lessonSessionDetails, Seq("lessonId"))
  .groupBy($"userId", $"lessonId", $"lessonNumber")
  .agg(
    count($"*") as "totalViews",
    sum(when($"consumedLesson" === true, 1).otherwise(0)) as "totalCompletions"
  )
  .select(
    $"userId",
    $"lessonId",
    $"lessonNumber",
    ($"totalViews" + $"totalCompletions" * 2) as "userLessonScore"
  )
val userIndexer = new StringIndexer()
  .setInputCol("userId")
  .setOutputCol("userNumber")
val userIndexModel = userIndexer.fit(grovoLessonSessionDetails)
val indexedSessions = userIndexModel.transform(grovoLessonSessionDetails).cache
// mllib's ALS takes Rating(user: Int, product: Int, rating: Double).
val ratings = indexedSessions.map { row =>
  Rating(
    getFieldValue[Double](row, "userNumber").get.toInt,
    getFieldValue[Double](row, "lessonNumber").get.toInt,
    getFieldValue[Long](row, "userLessonScore").get
  )
}
val rank = 10          // number of latent factors
val numIterations = 10
val alpha = 0.01       // confidence weight for implicit feedback
val lambda = 0.01      // regularization
val model = ALS.trainImplicit(ratings.rdd, rank, numIterations, lambda, alpha)
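// Helper (a sketch, not part of the original flow): ALS only ever sees the
// integer indices the StringIndexers produced, so keep the indexers' label
// arrays to translate a Rating's user/product back to the original ids;
// StringIndexerModel.labels(i) is the string that was encoded as i.
val userLabels = userIndexModel.labels
val lessonLabels = liModel.labels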
// Get recommendations from the model: the top 100 products for user 15.
val reccs = spark.sparkContext.parallelize(model.recommendProducts(15, 100)).toDS

// Show the recommendations with the user's already-viewed lessons filtered out.
reccs
  .select($"user", $"product")
  .except(ratings.where($"user" === 15).select($"user", $"product"))
  .join(reccs, Seq("user", "product"))
  .orderBy($"rating".desc)
  .show
indexedSessions.select($"userid", $"userNumber").distinct.foreach(row => { | |
val userNumber = getFieldValue[Double](row, "userNumber").get.toInt | |
val reccs = spark.sparkContext.parallelize(model.recommendProducts(userNumber, 10)).toDS | |
reccs.show | |
}) | |
// The same RDD restriction rules out calling the model from a UDF; batch the
// top-100 recommendations per user on the driver with recommendProductsForUsers.
val userRecommendations = model.recommendProductsForUsers(100).toDF("userNumber", "recommendations")
indexedSessions.select($"userId", $"userNumber".cast("int") as "userNumber")
  .distinct.limit(10)
  .join(userRecommendations, Seq("userNumber"))
  .show