Last active
November 21, 2017 05:37
-
-
Save kermitas/41c456c839645ab300d3 to your computer and use it in GitHub Desktop.
It may not be perfect, but a simple calling-thread ExecutionContext can help when we already have a thread (for example, one taken from a thread pool) and want to execute a few operations in a row on it.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.concurrent.ExecutionContext

/**
 * [[ExecutionContext]] that executes every submitted action synchronously in the calling thread
 * (and by that makes `Future { ... }`, `map`, `flatMap`, ... blocking for the caller).
 *
 * Actions are run by a direct call, so an action that submits further actions to this context
 * runs them nested on the same thread's stack.
 *
 * @param reporter invoked by [[reportFailure]] with failures that escaped a callback;
 *                 defaults to [[ExecutionContext.defaultReporter]].
 */
class CallingThreadExecutionContext(reporter: Throwable => Unit = ExecutionContext.defaultReporter)
  extends ExecutionContext {

  /** Runs `runnable` immediately on the current thread; anything it throws propagates to the caller. */
  override def execute(runnable: Runnable): Unit = runnable.run()

  /**
   * Hands `t` to `reporter` instead of rethrowing it.
   *
   * Because callbacks run inline, rethrowing here would unwind through the code that completed the
   * future (e.g. `promise.success(...)`), aborting delivery of the remaining callbacks and surfacing
   * the error in an unrelated caller.
   */
  override def reportFailure(t: Throwable): Unit = reporter(t)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.mutable.ListBuffer
import scala.concurrent.{ Await, ExecutionContext, Future }
import scala.concurrent.duration._

import org.scalatest.{ FeatureSpec, ShouldMatchers }
/** Shared fixtures and settings for [[CallingThreadExecutionContextTest]]. */
object CallingThreadExecutionContextTest {

  /** Expected order from `test(...)` when the futures run synchronously on the calling thread. */
  object test1 {
    val expectedResult: Seq[Int] = Seq(1, 2, 3, 4, 5)
  }

  /** Expected order from `test(...)` when the futures run on a pool and are still sleeping when it returns. */
  object test2 {
    val expectedResult: Seq[Int] = Seq(1, 3, 5)
  }

  /** How long each future created by `test(...)` sleeps before recording its step. */
  val inFutureSleepTimeInMilliseconds: Int = 250

  /** Upper bound when awaiting a future; `5.seconds` avoids the postfix-operator language feature. */
  val futureResultTimeout: FiniteDuration = 5.seconds

  /** When true, scenarios log intermediate results via `info`. */
  val beVerbose: Boolean = true
}
/** Checks that [[CallingThreadExecutionContext]] runs futures and their callbacks on the caller's thread. */
class CallingThreadExecutionContextTest extends FeatureSpec with ShouldMatchers {
  import CallingThreadExecutionContextTest._

  scenario("simple test of calling thread ExecutionContext") {
    // Calling-thread context: each Future body (including its sleep) finishes before `test` continues.
    val callingThreadExecutionContext = new CallingThreadExecutionContext
    test(callingThreadExecutionContext) shouldEqual test1.expectedResult

    // Thread-pool context: the futures are still sleeping when `test` takes its snapshot, so steps
    // 2 and 4 are expected to be missing. This is timing based: it assumes `test` returns within
    // `inFutureSleepTimeInMilliseconds`.
    val normalThreadExecutionContext = scala.concurrent.ExecutionContext.Implicits.global
    test(normalThreadExecutionContext) shouldEqual test2.expectedResult
  }

  scenario("calling thread ExecutionContext in a series of f.map() and f.flatMap()") {
    implicit val callingThreadExecutionContext: ExecutionContext = new CallingThreadExecutionContext
    val callerThreadId = Thread.currentThread.getId
    val threadsId = ListBuffer.empty[Long]

    // Records the thread running step `number` and fails the chain if it differs from the thread
    // that ran the previous step (`None` for the very first step, which has no predecessor).
    def step(number: Int, previousThreadId: Option[Long]): Long = {
      val currentThreadId = Thread.currentThread.getId
      threadsId += currentThreadId
      previousThreadId.filter(_ != currentThreadId).foreach { previous =>
        throw new Exception(s"Thread id mismatch at step $number (previous $previous, current $currentThreadId).")
      }
      currentThreadId
    }

    // Steps 2 and 6 start nested futures via flatMap; every other step is a plain map callback.
    val future: Future[Long] =
      Future(step(0, None))
        .map(id => step(1, Some(id)))
        .flatMap(id => Future(step(2, Some(id))))
        .map(id => step(3, Some(id)))
        .map(id => step(4, Some(id)))
        .map(id => step(5, Some(id)))
        .flatMap(id => Future(step(6, Some(id))))
        .map(id => step(7, Some(id)))

    val completedFuture = Await.ready(future, futureResultTimeout)
    if (beVerbose) info(s"threads id: ${threadsId.zipWithIndex.map(_.swap).mkString(",")}")

    // Await.result rethrows a step's mismatch exception if the chain failed.
    Await.result(completedFuture, futureResultTimeout) shouldEqual callerThreadId
    threadsId.toList shouldEqual List.fill(8)(callerThreadId)
  }

  /**
   * Records steps 1, 3 and 5 on the current thread. Steps 2 and 4 are recorded by futures submitted
   * to `ec`, each of which sleeps `inFutureSleepTimeInMilliseconds` first.
   *
   * @param ec context the two futures are submitted to
   * @return an immutable snapshot of the order taken when this method finishes. With an asynchronous
   *         `ec` the futures may keep appending to the internal buffer afterwards, so the live buffer
   *         is deliberately not returned.
   */
  protected def test(ec: ExecutionContext): Seq[Int] = {
    val executionOrder = ListBuffer.empty[Int]
    executionOrder += 1
    Future {
      Thread.sleep(inFutureSleepTimeInMilliseconds)
      executionOrder += 2
    }(ec)
    executionOrder += 3
    Future {
      Thread.sleep(inFutureSleepTimeInMilliseconds)
      executionOrder += 4
    }(ec)
    executionOrder += 5
    val snapshot = executionOrder.toVector
    if (beVerbose) info(snapshot.toString)
    snapshot
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment