import java.io.{NotSerializableException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.dstream.InputDStream

// TODO: stats are not being sent; check DirectKafkaInputDStream and other tutorials
/**
 * Much less efficient than DynSeqQueueInputDStream, which is based on a List
 * instead of a map.
 * Follows https://github.com/apache/spark/blob/master/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
 */
class TestCaseDictInputDStream[A: ClassTag]
  (@transient _ssc : StreamingContext)
  extends InputDStream[(TestCaseIdCounter.TestCaseId, A)](_ssc)
  with Logging {
  import TestCaseIdCounter.TestCaseId
  import scala.collection.immutable.HashMap

  @transient val _sc = _ssc.sparkContext
  val numSlices = 2 // TODO: make this configurable

  // testCaseId -> (testCase, currentBatchPointer)
  var testCasesMap : Map[TestCaseId, (Vector[Seq[A]], Int)] = emptyTestCasesMap
  private[this] def emptyTestCasesMap = new HashMap[TestCaseId, (Vector[Seq[A]], Int)]
  private[this] def reset() : Unit = { testCasesMap = emptyTestCasesMap }

  // fail fast on serialization: this DStream keeps mutable local state
  private def writeObject(oos: ObjectOutputStream): Unit = {
    throw new NotSerializableException("TestCaseDictInputDStream doesn't support checkpointing")
  }
  def addTestCase(testCaseId : TestCaseId, testCase : Seq[Seq[A]]) : Unit = synchronized {
    require(!(testCasesMap contains testCaseId), "test cases should not be added more than once")
    testCasesMap += testCaseId -> (testCase.toVector, 0)
  }
  def removeTestCase(testCaseId : TestCaseId) : Unit = synchronized {
    testCasesMap -= testCaseId
  }
  override def start() : Unit = reset()
  override def stop() : Unit = reset()
  override def compute(validTime: Time): Option[RDD[(TestCaseId, A)]] = synchronized {
    // fold over testCasesMap to produce both the batch and the updated map
    val testCasesMapAndBatch =
      testCasesMap.foldLeft((emptyTestCasesMap, List() : Seq[(TestCaseId, A)])) {
        case ((inTestCasesMap, inBatch), (testCaseId, (testCase, currentBatchPointer))) => {
          if (currentBatchPointer < testCase.length) {
            // this test case contributes to the batch:
            // mark the records with testCaseId and add them to the batch
            val outBatch = testCase(currentBatchPointer).map((testCaseId, _)) ++ inBatch
            // move the test case pointer
            val outTestCasesMap = inTestCasesMap + (testCaseId -> (testCase, currentBatchPointer + 1))
            (outTestCasesMap, outBatch)
          } else {
            // this test case doesn't contribute to the batch: don't remove it from the map,
            // as the case might still fail; it is removed when the prop calls removeTestCase
            (inTestCasesMap, inBatch)
          }
        }
      }
    testCasesMap = testCasesMapAndBatch._1
    val batch = testCasesMapAndBatch._2
    if (batch.nonEmpty) {
      logWarning(s"computing batch ${batch.mkString(",")}")
      val rdd = _sc.parallelize(batch, numSlices = numSlices)
      rdd.count // force computation now, otherwise nothing happens for this batch
      Some(rdd)
    } else {
      // return an empty RDD instead of None so downstream actions still run on every batch
      Some(_sc.emptyRDD)
    }
  }
}
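
// A minimal usage sketch, not part of the original gist. It assumes this class is on
// the classpath together with TestCaseIdCounter; the nextId() call below is
// hypothetical, as only the TestCaseIdCounter.TestCaseId type is referenced above.
object TestCaseDictInputDStreamExample {
  import org.apache.spark.SparkConf
  import org.apache.spark.streaming.{Seconds, StreamingContext}

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("TestCaseDictInputDStreamExample")
    val ssc = new StreamingContext(conf, Seconds(1))
    val testCases = new TestCaseDictInputDStream[Int](ssc)
    // a test case is a Seq[Seq[A]]: one inner Seq per streaming batch, emitted in order
    val testCaseId = TestCaseIdCounter.nextId() // hypothetical: allocate a fresh test case id
    testCases.addTestCase(testCaseId, Seq(Seq(1, 2), Seq(3), Seq(4, 5)))
    testCases.print()
    ssc.start()
    ssc.awaitTerminationOrTimeout(5 * 1000)
    ssc.stop()
  }
}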