Skip to content

Instantly share code, notes, and snippets.

@johnynek
Last active January 30, 2020 03:54
Show Gist options
  • Save johnynek/48ab7f060dda1bf0edc7be8229576048 to your computer and use it in GitHub Desktop.
Save johnynek/48ab7f060dda1bf0edc7be8229576048 to your computer and use it in GitHub Desktop.
A minimal example of a test runner for https://github.com/sbt/test-interface. Any comments or reviews welcome.
/**
* copyright 2020 P. Oscar Boykin <oscar.boykin@gmail.com>
* licensed under Apache V2 license: https://www.apache.org/licenses/LICENSE-2.0.html
* or, at your option, GPLv2 or later: https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html
*/
package testrunner
import io.github.classgraph._
import java.io.PrintWriter
import java.net.{URL, URLClassLoader}
import java.nio.file.{Path, Paths}
import sbt.testing._
import scala.collection.immutable.Queue
import scala.collection.JavaConverters._
object TestRunner {

  /**
   * Which part of a test class a [[TaskDef]] should run.
   * Converted to an sbt.testing [[Selector]] by `toSelector`.
   */
  sealed abstract class Select {
    import Select._

    /** The sbt.testing selector corresponding to this selection. */
    final def toSelector: Selector =
      this match {
        case TestNameContains(n) => new TestWildcardSelector(n)
        case EntireSuite => suiteSelector
        case NestedTest(s, t) => new NestedTestSelector(s, t)
        case NestedSuite(s) => new NestedSuiteSelector(s)
      }
  }

  object Select {
    // SuiteSelector takes no constructor arguments, so one shared instance is reused
    private val suiteSelector: Selector = new SuiteSelector

    case class TestNameContains(name: String) extends Select
    case object EntireSuite extends Select
    case class NestedTest(suiteName: String, testName: String) extends Select
    case class NestedSuite(suiteName: String) extends Select
  }

  /** Whether a test was explicitly requested; passed as the `explicitlySpecified` flag of a [[TaskDef]]. */
  sealed abstract class Explicitness(val isExplicit: Boolean)

  object Explicitness {
    case object Explicit extends Explicitness(true)
    case object NotExplicit extends Explicitness(false)
  }

  /** Fully-qualified name of a test class (for Scala objects: without the trailing `$`). */
  case class TestClassName(name: String)

  /** The result of a ClassGraph scan, exposing framework and test-class discovery. */
  sealed abstract class Scanned {
    /** Instantiates every concrete [[Framework]] implementation visible to the scan. */
    def findFrameworks: List[Framework]

    /**
     * Finds the classes in `testJars` that match one of `f`'s fingerprints.
     * Results are de-duplicated and sorted by class name.
     */
    def findTestClassesFromJars(f: Framework, testJars: ClassInfoList): List[(TestClassName, Fingerprint)]

    /** Every class found by the scan. */
    def allClasses: ClassInfoList

    /** Releases the resources held by the underlying ClassGraph scan result. */
    def close(): Unit
  }

  /** Scans exactly the given classpath entries, ignoring parent class loaders. */
  def scanJars(cp: List[Path]): Scanned =
    fromClassGraph(
      (new ClassGraph)
        .overrideClasspath(cp.asJava)
        .ignoreParentClassLoaders())

  /** Scans the classpath reachable from `cl`. */
  def scanClassLoader(cl: ClassLoader): Scanned =
    fromClassGraph((new ClassGraph).addClassLoader(cl))

  private def fromClassGraph(withCL: ClassGraph): Scanned =
    new Scanned {
      private val scan = withCL
        .enableClassInfo
        .enableMethodInfo // need this to check for constructors
        .enableAnnotationInfo // need this for annotation scanning
        .ignoreMethodVisibility // we need to see private constructors
        .scan

      def close(): Unit = scan.close()

      def allClasses: ClassInfoList = scan.getAllClasses

      // Implementors (for an interface) or subclasses (for a class) of `nm`.
      // getClassInfo returns null when `nm` is not in the scan; that means nothing can extend it here.
      private def subsOf(nm: String): ClassInfoList =
        Option(scan.getClassInfo(nm)) match {
          case None => new ClassInfoList(List.empty[ClassInfo].asJava)
          case Some(ci) if ci.isInterface() => scan.getClassesImplementing(nm)
          case Some(_) => scan.getSubclasses(nm)
        }

      def findFrameworks: List[Framework] = {
        val fwclass = classOf[Framework]
        subsOf(fwclass.getName)
          .iterator
          .asScala
          // abstract base classes and sub-interfaces of Framework cannot be instantiated
          .filterNot { ci => ci.isAbstract || ci.isInterface }
          .map { ci =>
            // Class.newInstance is deprecated; go through the no-arg constructor explicitly
            val obj = ci.loadClass().getDeclaredConstructor().newInstance()
            fwclass.cast(obj)
          }
          .toList
      }

      // Scala objects compile to a class named `X$`.
      // For module fingerprints keep only those, reporting them as `X`.
      // Otherwise keep only the classes without a trailing `$`.
      private def moduleFilter(mod: Boolean, names: Iterator[ClassInfo]): Iterator[(ClassInfo, TestClassName)] =
        if (mod) names.flatMap { ci =>
          val n = ci.getName()
          if (n.endsWith("$")) (ci, TestClassName(n.dropRight(1))) :: Nil
          else Nil
        }
        else names.flatMap { ci =>
          val n = ci.getName()
          if (n.endsWith("$")) Nil
          else (ci, TestClassName(n)) :: Nil
        }

      private def hasNoArgConstructor(ci: ClassInfo): Boolean =
        ci.getConstructorInfo
          .asScala
          .exists(_.getParameterInfo().isEmpty)

      // Every candidate that no fingerprint selected must be abstract or an interface.
      // A concrete one would be a test that silently never runs, so fail loudly instead.
      private def requireAllSelected(
        candidates: ClassInfoList,
        selected: Seq[(ClassInfo, TestClassName, Fingerprint)]): List[(TestClassName, Fingerprint)] = {
        val notSelected = candidates.exclude(new ClassInfoList(selected.map(_._1).asJava))
        val errs = notSelected.iterator.asScala.filterNot { ci => ci.isAbstract || ci.isInterface }.toList
        if (errs.nonEmpty) {
          sys.error(s"tests: ${errs} ignored, this is likely an error with your tests")
        }
        else {
          selected.iterator.map { case (_, t, f) => (t, f) }.toList
        }
      }

      private def findSubClasses(subs: Seq[SubclassFingerprint], tests: ClassInfoList): List[(TestClassName, Fingerprint)] =
        subs.groupBy(_.superclassName)
          .iterator
          .flatMap { case (className, sameSuper) =>
            // these are all candidates, but we need to ensure they match the other requirements:
            // reqNoArg = { true, false } x isModule = { true, false }
            val candidates = subsOf(className).intersect(tests)
            val selectedNames: Seq[(ClassInfo, TestClassName, Fingerprint)] =
              sameSuper.flatMap { sub =>
                val c1 = candidates.iterator.asScala
                val c2 = if (sub.requireNoArgConstructor) c1.filter(hasNoArgConstructor) else c1
                moduleFilter(sub.isModule, c2).map { case (ci, t) => (ci, t, sub) }
              }
            requireAllSelected(candidates, selectedNames)
          }
          .toList

      private def findAnnotated(anns: Seq[AnnotatedFingerprint], tests: ClassInfoList): List[(TestClassName, Fingerprint)] =
        anns.groupBy(_.annotationName)
          .iterator
          .flatMap { case (aname, sameAnnotation) =>
            // a class qualifies if it, or any of its methods, carries the annotation
            val classes = scan.getClassesWithAnnotation(aname)
            val methods = scan.getClassesWithMethodAnnotation(aname)
            val candidates = classes.union(methods).intersect(tests)
            val selectedNames: Seq[(ClassInfo, TestClassName, Fingerprint)] =
              sameAnnotation.flatMap { ann =>
                moduleFilter(ann.isModule, candidates.iterator.asScala).map { case (ci, t) => (ci, t, ann) }
              }
            requireAllSelected(candidates, selectedNames)
          }
          .toList

      def findTestClassesFromJars(f: Framework, tests: ClassInfoList): List[(TestClassName, Fingerprint)] = {
        val fps = f.fingerprints.toList
        fps.foreach {
          case _: SubclassFingerprint => ()
          case _: AnnotatedFingerprint => ()
          case other => sys.error(s"unknown fingerprint type: $other")
        }
        val subclasses = fps.collect { case sc: SubclassFingerprint => sc }
        val anns = fps.collect { case an: AnnotatedFingerprint => an }
        (findSubClasses(subclasses, tests) ::: findAnnotated(anns, tests))
          .distinct
          .sortBy(_._1.name)
      }
    }

  /** Builds a [[TaskDef]] for class `cn`, matched by `fp`, restricted to `selectors`. */
  def makeTask(cn: TestClassName, fp: Fingerprint, explicit: Explicitness, selectors: Seq[Select]): TaskDef =
    new TaskDef(cn.name, fp, explicit.isExplicit, selectors.iterator.map(_.toSelector).toArray)

  /**
   * We don't want to depend on cats just to get a few functions
   */
  trait Monoid[A] {
    def empty: A
    def combine(a: A, b: A): A

    /** Combines all elements left to right; `empty` for an empty input. */
    def combineAll(i: Iterable[A]): A =
      if (i.isEmpty) empty
      else i.reduce(combine(_, _))
  }

  object Monoid {
    def apply[A](implicit monoid: Monoid[A]): Monoid[A] = monoid

    /** Key-wise union of maps; values present under the same key are combined with `B`. */
    implicit def mapMonoid[A, B](implicit B: Monoid[B]): Monoid[Map[A, B]] =
      new Monoid[Map[A, B]] {
        def empty = Map.empty

        def combine(a: Map[A, B], b: Map[A, B]): Map[A, B] = {
          // fold the smaller map into the larger one.
          // `fn` keeps the value order as combine(aValue, bValue) whichever side is larger.
          val (fn, big, small) =
            if (a.size >= b.size) ({ (x: B, y: B) => B.combine(x, y) }, a, b)
            else ({ (x: B, y: B) => B.combine(y, x) }, b, a)
          small.foldLeft(big) { case (acc, (k, sv)) =>
            acc.get(k) match {
              case None => acc.updated(k, sv)
              case Some(bv) => acc.updated(k, fn(bv, sv))
            }
          }
        }
      }
  }

  implicit class ListFoldMap[A](val list: List[A]) extends AnyVal {
    /** Maps every element with `fn` and combines the results left to right; `m.empty` for Nil. */
    def foldMap[B](fn: A => B)(implicit m: Monoid[B]): B = {
      @annotation.tailrec
      def loop(rest: List[A], acc: B): B =
        rest match {
          case Nil => acc
          case h :: t => loop(t, m.combine(acc, fn(h)))
        }
      list match {
        case Nil => m.empty
        case h :: t => loop(t, fn(h))
      }
    }
  }

  /**
   * Aggregated results for events sharing one [[Status]].
   *
   * @param count         number of events
   * @param totalDuration summed durations, or None if any event lacked one
   * @param failures      throwables attached to the events
   */
  case class Statistics(count: Long, totalDuration: Option[Long], failures: List[Throwable]) {
    /** Sums two statistics; the total duration is known only when both sides know theirs. */
    def +(that: Statistics): Statistics = {
      val dur = for {
        a <- totalDuration
        b <- that.totalDuration
      } yield a + b
      Statistics(count + that.count, dur, failures ::: that.failures)
    }
  }

  object Statistics {
    /** Statistics for a single event; a negative `d` is treated as "duration unknown". */
    def fromDur(d: Long, failure: Option[Throwable]): Statistics =
      Statistics(1L, if (d < 0) None else Some(d), failure.toList)

    implicit val statisticsMonoid: Monoid[Statistics] =
      new Monoid[Statistics] {
        val empty: Statistics = Statistics(0L, Some(0L), Nil)
        def combine(a: Statistics, b: Statistics) = a + b
      }

    /** Statistics for a single sbt.testing [[Event]]. */
    def fromEvent(e: Event): Statistics = {
      val optT = e.throwable()
      fromDur(
        e.duration(),
        if (optT.isEmpty) None else Some(optT.get()))
    }
  }

  sealed abstract class LogKind(val kind: String)

  object LogKind {
    case object Debug extends LogKind("DEBUG")
    case object Error extends LogKind("ERROR")
    case object Info extends LogKind("INFO")
    case object Warn extends LogKind("WARN")
  }

  /**
   * An sbt.testing [[Logger]] that records every message and trace in arrival order.
   * Access to the queue is guarded by a private mutex.
   */
  class AggregatingLogger extends Logger {
    private val mutex = new AnyRef
    private var queue: Queue[Either[Throwable, (LogKind, String)]] =
      Queue.empty

    /** Snapshot of everything logged so far. */
    def getLog(): Queue[Either[Throwable, (LogKind, String)]] =
      mutex.synchronized { queue }

    private def put(lk: LogKind, s: String): Unit =
      mutex.synchronized {
        queue = queue :+ Right((lk, s))
      }

    def ansiCodesSupported(): Boolean = false
    def debug(d: String): Unit = put(LogKind.Debug, d)
    def error(e: String): Unit = put(LogKind.Error, e)
    def info(i: String): Unit = put(LogKind.Info, i)

    def trace(t: Throwable): Unit =
      mutex.synchronized {
        queue = queue :+ Left(t)
      }

    def warn(w: String): Unit = put(LogKind.Warn, w)
  }

  /** An [[EventHandler]] that aggregates event [[Statistics]] per [[Status]], guarded by a private mutex. */
  class AggregatingEventHandler extends EventHandler {
    private val mutex = new AnyRef
    private var stats = Map.empty[Status, Statistics]

    /** Snapshot of the statistics gathered so far. */
    def getStats(): Map[Status, Statistics] =
      mutex.synchronized {
        stats
      }

    def handle(e: Event): Unit =
      mutex.synchronized {
        stats = Monoid[Map[Status, Statistics]].combine(stats, Map(e.status() -> Statistics.fromEvent(e)))
      }
  }

  /** Log output and per-status statistics collected while running a task. */
  case class TaskResult(log: Queue[Either[Throwable, (LogKind, String)]], stats: Map[Status, Statistics])

  object TaskResult {
    implicit val taskResultMonoid: Monoid[TaskResult] =
      new Monoid[TaskResult] {
        val empty: TaskResult = TaskResult(Queue.empty, Map.empty)
        def combine(a: TaskResult, b: TaskResult): TaskResult =
          TaskResult(
            a.log ++ b.log,
            Monoid[Map[Status, Statistics]].combine(a.stats, b.stats))
      }
  }

  /** An event handler plus loggers to hand to a [[Task]], and a way to read back what they collected. */
  trait HandlerLogger {
    def eventHandler: EventHandler
    def loggers: Seq[Logger]
    def results(): TaskResult
  }

  /** A fresh [[HandlerLogger]] backed by one [[AggregatingLogger]] and one [[AggregatingEventHandler]]. */
  def makeHandlerLogger(): HandlerLogger =
    new HandlerLogger {
      val log = new AggregatingLogger
      val eventHandler = new AggregatingEventHandler
      val loggers = log :: Nil
      def results() =
        TaskResult(log.getLog(), eventHandler.getStats())
    }

  /**
   * Executes `task`, then recursively every task it returns.
   * Each task's log and statistics are gathered separately and combined by [[TaskDef]].
   */
  def executeAll(task: Task): Map[TaskDef, TaskResult] = {
    val hl = makeHandlerLogger()
    val next = task.execute(hl.eventHandler, hl.loggers.toArray)
    Monoid[Map[TaskDef, TaskResult]]
      .combine(
        Map(task.taskDef -> hl.results()),
        next.toList.foldMap(executeAll(_)))
  }

  /**
   * Writes a report of one framework's results to `out`.
   * Each test class (sorted by name) gets its log, then its per-status counts.
   * `doneMessage` is printed last, when it is non-empty.
   */
  def complete(fw: Framework, doneMessage: String, results: Map[TaskDef, TaskResult], out: PrintWriter): Unit = {
    def printResult(fqn: String, tr: TaskResult): Unit = {
      out.println(s"test class: $fqn")
      tr.log.foreach {
        case Left(err) =>
          out.println("")
          err.printStackTrace(out)
          out.println("")
        case Right((k, msg)) =>
          out.println(s"${k.kind}: $msg")
      }
      // finally the stats
      out.println("")
      tr
        .stats
        .toList
        .map { case (s, stat) => (s.toString, stat) }
        .sortBy(_._1)
        .foreach { case (k, Statistics(cnt, optTime, errs)) =>
          val time = optTime match {
            case None => ""
            case Some(t) =>
              // durations are treated as milliseconds here (divided by 1000 for seconds)
              val s = t.toDouble / 1000.0
              s", $s seconds"
          }
          out.println(s"$k: count = $cnt$time, ${errs.size} errors")
        }
    }

    out.println(s"results from: ${fw.name}")
    results
      .toList
      .sortBy(_._1.fullyQualifiedName)
      .foreach { case (td, res) =>
        printResult(td.fullyQualifiedName, res)
      }
    if (doneMessage.nonEmpty) out.println(doneMessage)
  }

  /**
   * Discovers frameworks on `cl` and tests in `testJars`, runs every test, and reports to `out`.
   *
   * cl should be the full classpath, including all the test frameworks and the tests.
   * testJars should be the paths to the jars that contain the tests for this run.
   *
   * Throws (via sys.error) if no tests are found or if any test failed or errored.
   */
  def runAll(cl: ClassLoader, testJars: List[Path], args: Array[String], remoteArgs: Array[String], out: PrintWriter): Unit = {
    val scannedFW = scanClassLoader(cl)
    // both scans are closed even when discovery throws (e.g. a concrete test class was ignored)
    val (fws, tests) =
      try {
        val frameworks: List[Framework] = scannedFW.findFrameworks
        val scannedTests = scanJars(testJars)
        try {
          val byFramework: List[(String, List[(TestClassName, Fingerprint, Framework)])] =
            frameworks
              .iterator
              .flatMap { fw =>
                scannedFW.findTestClassesFromJars(fw, scannedTests.allClasses)
                  .map { case (tn, fp) => (tn, fp, fw) }
              }
              .toList
              .groupBy(_._3.name)
              .toList
              .sortBy(_._1)
          (frameworks, byFramework)
        }
        finally scannedTests.close()
      }
      finally scannedFW.close()

    if (tests.isEmpty) sys.error(s"found 0 tests in $testJars. Frameworks: $fws")
    else {
      val allResults = tests.foldMap {
        // groupBy never produces an empty group; this case only keeps the match exhaustive
        case (_, Nil) => Monoid[TaskResult].empty
        case (_, ne@(h :: _)) =>
          val fw = h._3
          val runner = fw.runner(args, remoteArgs, cl)
          val taskDefs = ne.map { case (cn, fp, _) => makeTask(cn, fp, Explicitness.NotExplicit, Select.EntireSuite :: Nil) }
          val tasks = runner.tasks(taskDefs.toArray)
          val results: Map[TaskDef, TaskResult] = tasks.toList.foldMap(executeAll(_))
          val message = runner.done()
          complete(fw, message, results, out)
          Monoid[TaskResult].combineAll(results.values)
      }
      val failed = allResults.stats.getOrElse(Status.Failure, Monoid[Statistics].empty).count
      val errored = allResults.stats.getOrElse(Status.Error, Monoid[Statistics].empty).count
      if (failed != 0L || errored != 0L) sys.error(s"test had $failed failures, $errored errors")
      else ()
    }
  }

  /**
   * A class loader over `paths`, parented to the current thread's context class loader.
   */
  def makeClassLoader(paths: List[Path]): ClassLoader = {
    val root = Thread.currentThread.getContextClassLoader()
    // Path.toUri gives an absolute, escaped file: URI, with a trailing '/' for existing
    // directories (which URLClassLoader needs to treat them as directories, not jars).
    // Hand-building "file://" + path makes the first segment of a relative path the URL host
    // and does not escape spaces or Windows drive letters.
    new URLClassLoader(
      paths.iterator.map(_.toUri.toURL).toArray,
      root)
  }

  /**
   * args (every flag takes exactly one value):
   * --dep_jar: a jar with a dependency of the test framework
   * --test_jar: a jar containing a test
   * --arg: an arg to the test framework
   * --remote_arg: an arg to remotely started JVMs
   * --output: the path to write into (required)
   */
  def main(args: Array[String]): Unit = {
    val deps = List.newBuilder[Path]
    val tests = List.newBuilder[Path]
    val testArgs = List.newBuilder[String]
    val remoteArgs = List.newBuilder[String]
    var output: String = ""

    @annotation.tailrec
    def loop(idx: Int): Unit =
      if (idx >= args.length) ()
      else if (idx + 1 >= args.length)
        // a trailing flag without its value is reported rather than silently ignored
        sys.error(s"missing value for argument ${args(idx)} in ${args.toList}")
      else {
        val value = args(idx + 1)
        args(idx) match {
          case "--dep_jar" =>
            deps += Paths.get(value)
          case "--test_jar" =>
            tests += Paths.get(value)
          case "--arg" =>
            testArgs += value
          case "--remote_arg" =>
            remoteArgs += value
          case "--output" =>
            output = value
          case other => sys.error(s"unexpected arg $other in ${args.toList}")
        }
        loop(idx + 2)
      }

    loop(0)
    if (output == "") sys.error(s"expected to get a non-empty output argument, didn't find one in: ${args.toList}")
    val testJars = tests.result()
    val cl = makeClassLoader(deps.result() ::: testJars)
    val out = new PrintWriter(Paths.get(output).toFile, "UTF-8")
    try {
      runAll(cl, testJars, testArgs.result().toArray, remoteArgs.result().toArray, out)
    }
    finally {
      out.close()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment