Skip to content

Instantly share code, notes, and snippets.

@jrabary
Created April 3, 2014 12:42
Show Gist options
  • Save jrabary/9953562 to your computer and use it in GitHub Desktop.
Save jrabary/9953562 to your computer and use it in GitHub Desktop.
RDD cartesian need cached data illustration
object CartesianTest {
case class DataFormat(id: Long, view: String, value: Array[Double])
def randomSplit(
input: RDD[DataFormat],
numFolds: Int,
trainingFraction: Double,
query: String,
target: String,
seed: Long ) = {
val weights = Array(trainingFraction, 1-trainingFraction)
val inputByIds = input.filter( _.view.matches(s"$target|$query")).groupBy(_.id)
val paired = inputByIds.filter( v => v._2.length > 1)
val unPaired = inputByIds.filter( v => v._2.length <= 1)
val rg = new Random(seed)
(1 to numFolds).map {_ =>
val splits = paired.randomSplit(weights, rg.nextInt)
(splits(0).flatMap(_._2), splits(1).union(unPaired).flatMap(_._2))
}.toIterator
}
def main(args: Array[String]) {
val conf = new SparkConf()
.setMaster("local[8]")
.setAppName("CartesianTest")
.setSparkHome(System.getenv("SPARK_HOME"))
.set("spark.executor.memory","2g")
val spark = new SparkContext(conf)
val view1 = (1 to 500).map (DataFormat(_,"view1",new Array[Double](300)))
val view2 = (1 to 500).map (DataFormat(_,"view2",new Array[Double](300)))
val data = spark.parallelize(view1,8).union(spark.parallelize(view2,8))
val folds = randomSplit(data,10,0.5,"view1","view2",0)
val (train,test) = folds.next
// Get a correct result when cached
//val testQuery = test.filter(_.view matches "view1").cache
//val testTarget = test.filter(_.view matches "view2").cache
val testQuery = test.filter(_.view matches "view1")
val testTarget = test.filter(_.view matches "view2")
val testQueryTargetPair = testQuery.cartesian(testTarget)
println(s"${testQuery.count} * ${testTarget.count} = ${testQueryTargetPair.count}")
spark.stop()
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment