Created
April 3, 2014 12:42
-
-
Save jrabary/9953562 to your computer and use it in GitHub Desktop.
Illustration that RDD `cartesian` needs cached input data to produce consistent counts.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object CartesianTest {

  /** One observation: an entity id, the view it belongs to, and its feature vector. */
  case class DataFormat(id: Long, view: String, value: Array[Double])

  /**
   * Produces `numFolds` (train, test) splits of the records whose view matches
   * either `query` or `target`.
   *
   * Records are grouped by id. Ids that appear in more than one view ("paired")
   * are randomly divided between the train and test sides according to
   * `trainingFraction`; ids seen in only one view always land on the test side.
   *
   * @param input            full data set
   * @param numFolds         number of (train, test) pairs to generate
   * @param trainingFraction fraction of paired ids assigned to training
   * @param query            name of the query view
   * @param target           name of the target view
   * @param seed             seed for the generator that draws each fold's split seed
   * @return an iterator of (train, test) RDD pairs
   */
  def randomSplit(
      input: RDD[DataFormat],
      numFolds: Int,
      trainingFraction: Double,
      query: String,
      target: String,
      seed: Long) = {
    val splitWeights = Array(trainingFraction, 1 - trainingFraction)
    // Restrict to the two views of interest and group records by entity id.
    val byId = input.filter(_.view.matches(s"$target|$query")).groupBy(_.id)
    // Ids present in both views vs. ids present in only one.
    val paired = byId.filter(_._2.length > 1)
    val unPaired = byId.filter(_._2.length <= 1)
    val rng = new Random(seed)
    (1 to numFolds).map { _ =>
      // Each fold gets its own split seed drawn from `rng`. NOTE(review): the
      // seed is captured eagerly per fold, but the resulting RDDs are lazy —
      // recomputing them re-runs the split (same seed, so presumably the same
      // partition-level draws; the gist exists to probe exactly this).
      val parts = paired.randomSplit(splitWeights, rng.nextInt)
      (parts(0).flatMap(_._2), parts(1).union(unPaired).flatMap(_._2))
    }.toIterator
  }

  /**
   * Builds two synthetic 500-row views, splits them, and compares the count of
   * the cartesian product of the first test fold's query/target halves against
   * the product of their individual counts.
   *
   * Without `.cache` on the filtered RDDs, each action (`count` on the two
   * filters and on the cartesian) may recompute the lazy random split
   * independently, so the printed equation can fail to hold — uncomment the
   * cached variants below to see the consistent result.
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[8]")
      .setAppName("CartesianTest")
      .setSparkHome(System.getenv("SPARK_HOME"))
      .set("spark.executor.memory", "2g")
    val spark = new SparkContext(conf)

    // Two parallel views over the same id range, each row a 300-dim zero vector.
    val view1 = (1 to 500).map(DataFormat(_, "view1", new Array[Double](300)))
    val view2 = (1 to 500).map(DataFormat(_, "view2", new Array[Double](300)))
    val data = spark.parallelize(view1, 8).union(spark.parallelize(view2, 8))

    val folds = randomSplit(data, 10, 0.5, "view1", "view2", 0)
    val (train, test) = folds.next

    // Get a correct result when cached
    //val testQuery = test.filter(_.view matches "view1").cache
    //val testTarget = test.filter(_.view matches "view2").cache
    val testQuery = test.filter(_.view matches "view1")
    val testTarget = test.filter(_.view matches "view2")
    val crossed = testQuery.cartesian(testTarget)
    println(s"${testQuery.count} * ${testTarget.count} = ${crossed.count}")
    spark.stop()
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment