Skip to content

Instantly share code, notes, and snippets.

@kermitas
Last active November 21, 2017 05:37
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kermitas/41c456c839645ab300d3 to your computer and use it in GitHub Desktop.
Save kermitas/41c456c839645ab300d3 to your computer and use it in GitHub Desktop.
Maybe not perfect, but a simple calling-thread ExecutionContext can help when we already have a thread (for example, one taken from a thread pool) and want to execute a few operations on it in a row.
/**
 * [[ExecutionContext]] that runs every submitted task synchronously in the thread that calls
 * `execute` (so `Future { ... }`, `map`, `flatMap`, ... block the caller until the task is done).
 *
 * Useful when a thread is already available (for example, one taken from a thread pool) and a
 * short series of `Future` steps should run on it back to back.
 *
 * @param reporter receives failures passed to [[reportFailure]], i.e. failures the `Future`
 *                 machinery could not deliver anywhere else (such as an exception thrown by an
 *                 `onComplete` callback). Defaults to `ExecutionContext.defaultReporter`, which
 *                 prints the stack trace.
 */
class CallingThreadExecutionContext(reporter: Throwable => Unit = ExecutionContext.defaultReporter)
  extends ExecutionContext {

  /** Runs `runnable` immediately on the current thread; anything it throws propagates to the caller. */
  override def execute(runnable: Runnable): Unit = runnable.run()

  /**
   * Hands `t` to `reporter` instead of rethrowing it. Rethrowing here would escape from the
   * promise that is dispatching callbacks. That would skip the promise's remaining callbacks and
   * surface in whatever code merely completed the promise.
   */
  override def reportFailure(t: Throwable): Unit = reporter(t)
}
import scala.concurrent.duration._
import org.scalatest.{ FeatureSpec, ShouldMatchers }
import scala.collection.mutable.ListBuffer
/** Shared fixtures and settings for the `CallingThreadExecutionContextTest` spec. */
object CallingThreadExecutionContextTest {

  /** Marker order expected when both `Future` bodies run synchronously in the calling thread. */
  object test1 {
    val expectedResult: Seq[Int] = Seq(1, 2, 3, 4, 5)
  }

  /** Marker order expected when the `Future` bodies are still sleeping on another pool. */
  object test2 {
    val expectedResult: Seq[Int] = Seq(1, 3, 5)
  }

  /** How long each `Future` body sleeps before recording its marker, in milliseconds. */
  val inFutureSleepTimeInMilliseconds: Int = 250

  /** Upper bound for awaiting a future's completion (method syntax; no `postfixOps` needed). */
  val futureResultTimeout: FiniteDuration = 5.seconds

  /** When `true`, the spec logs intermediate state through ScalaTest's `info`. */
  val beVerbose: Boolean = true
}
/**
 * Checks that `CallingThreadExecutionContext` runs `Future` work synchronously in the caller's
 * thread, contrasting it with the global (thread-pool backed) `ExecutionContext`.
 */
class CallingThreadExecutionContextTest extends FeatureSpec with ShouldMatchers {

  import CallingThreadExecutionContextTest._

  scenario("simple test of calling thread ExecutionContext") {
    val callingThreadExecutionContext = new CallingThreadExecutionContext
    test(callingThreadExecutionContext) shouldEqual test1.expectedResult

    // Timing-based: relies on the global pool's futures still sleeping when `test` takes its snapshot.
    val normalThreadExecutionContext = scala.concurrent.ExecutionContext.Implicits.global
    test(normalThreadExecutionContext) shouldEqual test2.expectedResult
  }

  scenario("calling thread ExecutionContext in a series of f.map() and f.flatMap()") {
    implicit val callingThreadExecutionContext: ExecutionContext = new CallingThreadExecutionContext
    val callingThreadId = Thread.currentThread.getId
    val threadsId = ListBuffer.empty[Long]

    // Records the id of the thread running step 0 and returns it.
    def recordFirstStep(): Long = {
      val currentThreadId = Thread.currentThread.getId
      threadsId += currentThreadId
      currentThreadId
    }

    // Records the id of the thread running `step`; fails the chain if it differs from the previous step's.
    def recordStep(step: Int, previousThreadId: Long): Long = {
      val currentThreadId = Thread.currentThread.getId
      threadsId += currentThreadId
      if (previousThreadId != currentThreadId)
        throw new Exception(s"Thread id mismatch at step $step (previous $previousThreadId, current $currentThreadId).")
      currentThreadId
    }

    val future: Future[Long] =
      Future(recordFirstStep())                    // STEP 0
        .map(recordStep(1, _))                      // STEP 1
        .flatMap(id => Future(recordStep(2, id)))   // STEP 2
        .map(recordStep(3, _))                      // STEP 3
        .map(recordStep(4, _))                      // STEP 4
        .map(recordStep(5, _))                      // STEP 5
        .flatMap(id => Future(recordStep(6, id)))   // STEP 6
        .map(recordStep(7, _))                      // STEP 7

    // Wait first so the thread ids are logged even when a step failed.
    Await.ready(future, futureResultTimeout)
    if (beVerbose) info(s"threads id: ${threadsId.zipWithIndex.map(_.swap).mkString(",")}")

    // Rethrows a step's mismatch exception, which fails the scenario.
    val lastThreadId = Await.result(future, futureResultTimeout)
    lastThreadId shouldEqual callingThreadId
    threadsId.toList shouldEqual List.fill(8)(callingThreadId)
  }

  /**
   * Records markers 1, 3, 5 from the calling thread and 2, 4 from two `Future`s submitted to `ec`.
   * Each `Future` sleeps `inFutureSleepTimeInMilliseconds` before recording its marker.
   *
   * @param ec context the two `Future` bodies are submitted to
   * @return immutable snapshot of the markers recorded once the calling thread has recorded 5
   */
  protected def test(ec: ExecutionContext): Seq[Int] = {
    val executionOrder = ListBuffer.empty[Int]

    // Appends are synchronized because, on a pooled `ec`, both futures may append concurrently.
    def record(marker: Int): Unit = executionOrder.synchronized { executionOrder += marker }

    record(1)
    Future {
      Thread.sleep(inFutureSleepTimeInMilliseconds)
      record(2)
    }(ec)
    record(3)
    Future {
      Thread.sleep(inFutureSleepTimeInMilliseconds)
      record(4)
    }(ec)
    record(5)

    // Snapshot so futures still running on `ec` cannot mutate what the caller compares.
    val snapshot = executionOrder.synchronized(executionOrder.toList)
    if (beVerbose) info(snapshot.toString)
    snapshot
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment