Skip to content

Instantly share code, notes, and snippets.

@dwickern
Last active August 29, 2015 14:09
Show Gist options
  • Save dwickern/4d93a2bdf993a9e59ba3 to your computer and use it in GitHub Desktop.
Save dwickern/4d93a2bdf993a9e59ba3 to your computer and use it in GitHub Desktop.
Akka TestKit mixin for testing unordered messages
import akka.testkit._
import org.scalatest._
import scala.collection.mutable
import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util._
/**
 * Extensions to [[akka.testkit.TestKit]] for testing actors
 * which publish messages in no particular order.
 *
 * APIs are analogous to the `TestKit` expectation methods, but instead of
 * blocking they only register an expectation. Expectations registered during
 * a test are evaluated after the test body completes successfully:
 *  - messages are received, in any order, until every expectation has matched,
 *    within a window equal to the sum of the expectations' durations (dilated
 *    by the TestKit time factor);
 *  - messages matching no expectation are tolerated and only listed in the
 *    failure report;
 *  - if `expectUnorderedNoMsg` was called, no further message may arrive
 *    within its (dilated) window afterwards.
 */
trait UnorderedTestKit extends SuiteMixin { self: Suite with TestKitBase =>

  /** Expects a message equal to `value`, using the default single-expect timeout. */
  def expectUnorderedMsg(value: Any): Unit = expectUnorderedMsg(defaultTimeout, null, value)

  /** Expects a message equal to `value`; `max` is added to the verification window. */
  def expectUnorderedMsg(max: FiniteDuration, value: Any): Unit = expectUnorderedMsg(max, null, value)

  /** Expects a message equal to `value`; `hint` is shown in the failure report. */
  def expectUnorderedMsg(hint: String, value: Any): Unit = expectUnorderedMsg(defaultTimeout, hint, value)

  /**
   * Registers an expectation that some message equal (`==`) to `value` is received.
   *
   * @param max   time this expectation adds to the verification window (undilated)
   * @param hint  optional text shown in the failure report; may be null
   * @param value the expected message
   */
  def expectUnorderedMsg(max: FiniteDuration, hint: String, value: Any): Unit =
    expectInternal(max, s"Did not encounter message equal to $value", hint) {
      case `value` =>
    }

  /** Requests that no further message arrives for the default timeout once all expectations are met. */
  def expectUnorderedNoMsg(): Unit = expectUnorderedNoMsg(defaultTimeout)

  /**
   * Requests that, once all other expectations of the current test are met,
   * no further message arrives within `max` (dilated). This check also runs
   * when no other expectation was registered. A later call replaces an earlier one.
   */
  def expectUnorderedNoMsg(max: FiniteDuration): Unit = {
    expectNoMsgFor = max
  }

  /** Expects a message that is an instance of `T`, using the default timeout. */
  def expectUnorderedMsgType[T](implicit t: ClassTag[T]): Unit = expectUnorderedMsgClass(defaultTimeout, t.runtimeClass)

  /** Expects a message that is an instance of `T`; `max` is added to the verification window. */
  def expectUnorderedMsgType[T](max: FiniteDuration)(implicit t: ClassTag[T]): Unit = expectUnorderedMsgClass(max, t.runtimeClass)

  /** Expects a message that is an instance of `c`, using the default timeout. */
  def expectUnorderedMsgClass(c: Class[_]): Unit = expectUnorderedMsgClass(defaultTimeout, c)

  /**
   * Expects a message that is an instance of `c`. Primitive classes (e.g. the
   * `int` produced by `ClassTag[Int]`) are matched against their boxed wrapper,
   * because received messages are always boxed.
   */
  def expectUnorderedMsgClass(max: FiniteDuration, c: Class[_]): Unit = {
    val runtimeClass = boxed(c)
    expectInternal(max, s"Did not encounter message of type ${c.getName}", null) {
      case msg if runtimeClass.isInstance(msg) =>
    }
  }

  /**
   * Expects a message for which `pf` is defined and does not throw.
   * Throwing inside `pf` (e.g. a failed assertion) rejects that message; the
   * last such failure is included in the report if the expectation stays unmet.
   */
  def expectUnorderedMsgPF(max: FiniteDuration = defaultTimeout, hint: String = null)(pf: PartialFunction[Any, Unit]): Unit = {
    expectInternal(max, "Did not encounter message matching the partial function", hint)(pf)
  }

  /** The expectations to evaluate at the end of the current test, in registration order. */
  private val expectations = mutable.LinkedHashSet[Expectation]()

  /** Length of the trailing "no further message" window; zero disables the check. */
  private var expectNoMsgFor: FiniteDuration = Duration.Zero

  /** Primitive classes mapped to the wrapper class their boxed values have at runtime. */
  private val primitiveToBoxed: Map[Class[_], Class[_]] = Map(
    java.lang.Boolean.TYPE -> classOf[java.lang.Boolean],
    java.lang.Byte.TYPE -> classOf[java.lang.Byte],
    java.lang.Character.TYPE -> classOf[java.lang.Character],
    java.lang.Short.TYPE -> classOf[java.lang.Short],
    java.lang.Integer.TYPE -> classOf[java.lang.Integer],
    java.lang.Long.TYPE -> classOf[java.lang.Long],
    java.lang.Float.TYPE -> classOf[java.lang.Float],
    java.lang.Double.TYPE -> classOf[java.lang.Double],
    java.lang.Void.TYPE -> classOf[scala.runtime.BoxedUnit]
  )

  /** Returns the boxed wrapper for a primitive class, or `c` itself otherwise. */
  private def boxed(c: Class[_]): Class[_] = primitiveToBoxed.getOrElse(c, c)

  /**
   * Default per-expectation timeout, deliberately undilated: `within` and the
   * no-message check apply the time factor themselves, so dilating here would
   * apply it twice.
   */
  private def defaultTimeout: FiniteDuration = testKitSettings.SingleExpectDefaultTimeout

  /** Registers an expectation satisfied by the first message for which `pf` is defined and does not throw. */
  private def expectInternal(max: FiniteDuration, message: String, hint: String)(pf: PartialFunction[Any, Unit]): Unit = {
    val description = Option(hint).fold(message) { h => s"$message ($h)" }
    expectations += new Expectation {
      // failure thrown by `pf` the last time it was defined for a message but rejected it
      private var lastError: Option[Throwable] = None

      def execute(msg: Any): Boolean =
        pf.isDefinedAt(msg) && (Try(pf(msg)) match {
          case Success(_) => true
          case Failure(t) =>
            lastError = Some(t)
            false
        })

      def duration: FiniteDuration = max

      override def toString: String =
        lastError.fold(description) { t => s"$description; last failure: $t" }
    }
  }

  /** Verifies all registered expectations, then the optional trailing no-message window. */
  private def verifyExpectations(): Unit = {
    awaitExpectations()
    verifyNoExtraMessages()
  }

  /**
   * Receives messages until every expectation has matched. The window is the
   * sum of all expectation durations. On timeout, throws an `AssertionError`
   * listing the unmet expectations and the messages that matched none.
   */
  private def awaitExpectations(): Unit = if (expectations.nonEmpty) {
    val max = expectations.foldLeft(Duration.Zero) { (accum, exp) => accum + exp.duration }
    // kept as a list so duplicates and arrival order survive into the report
    val unhandled = mutable.ListBuffer[Any]()
    try {
      within(max) {
        while (expectations.nonEmpty) {
          fishForMessage() {
            case Expectation(xp) =>
              expectations -= xp
              true
            case msg =>
              unhandled += msg
              false
          }
        }
      }
    } catch {
      case ex: AssertionError if expectations.nonEmpty =>
        throw new AssertionError(unmetReport(unhandled.toList), ex)
    }
  }

  /**
   * If a no-message window was requested, fails when any message arrives within it.
   * This runs outside `within`, so the window is not limited by the expectation budget.
   */
  private def verifyNoExtraMessages(): Unit = if (expectNoMsgFor > Duration.Zero) {
    Option(receiveOne(expectNoMsgFor.dilated)).foreach { msg =>
      throw new AssertionError(s"Expected no additional messages but received: $msg")
    }
  }

  /** Builds the failure report. Uses `mkString` rather than `stripMargin`, so user `toString`s are not rewritten. */
  private def unmetReport(unhandled: List[Any]): String = {
    val unmet = expectations.toList.zipWithIndex.map { case (exp, i) => s"\t(${i + 1}) $exp" }
    val lines =
      (s"${expectations.size} expectation(s) were unmet:" :: unmet) :::
        (s"${unhandled.size} message(s) were unhandled:" :: unhandled.map(msg => s"\t$msg"))
    lines.mkString("\n", "\n", "\n")
  }

  /**
   * Resets state, runs the test, and verifies expectations only when the test
   * body succeeded. A failed, canceled or pending test keeps its own outcome
   * rather than being masked by unmet-expectation errors.
   */
  protected abstract override def withFixture(test: NoArgTest): Outcome = {
    expectations.clear()
    expectNoMsgFor = Duration.Zero
    val result = super.withFixture(test)
    if (result.isSucceeded) verifyExpectations()
    result
  }

  /** A pending expectation on some future message. */
  trait Expectation {
    /** Returns true if `msg` satisfies this expectation. */
    def execute(msg: Any): Boolean

    /** Time this expectation contributes to the overall verification window. */
    def duration: FiniteDuration
  }

  object Expectation {
    /** Tests the `msg` against all of the expectations, and extracts the first (in registration order) that matches */
    def unapply(msg: Any): Option[Expectation] = expectations.find(_.execute(msg))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment