Skip to content

Instantly share code, notes, and snippets.

@vinaysshenoy
Created September 12, 2020 05:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vinaysshenoy/eba388d52c286528b7d8473188a5db8c to your computer and use it in GitHub Desktop.
Save vinaysshenoy/eba388d52c286528b7d8473188a5db8c to your computer and use it in GitHub Desktop.
Composite JUnit 5 extension for composing a group of domain-specific extensions
import org.junit.jupiter.api.extension.AfterAllCallback
import org.junit.jupiter.api.extension.AfterEachCallback
import org.junit.jupiter.api.extension.AfterTestExecutionCallback
import org.junit.jupiter.api.extension.BeforeAllCallback
import org.junit.jupiter.api.extension.BeforeEachCallback
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback
import org.junit.jupiter.api.extension.ConditionEvaluationResult
import org.junit.jupiter.api.extension.ExecutionCondition
import org.junit.jupiter.api.extension.Extension
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.api.extension.InvocationInterceptor
import org.junit.jupiter.api.extension.LifecycleMethodExecutionExceptionHandler
import org.junit.jupiter.api.extension.ParameterContext
import org.junit.jupiter.api.extension.ParameterResolutionException
import org.junit.jupiter.api.extension.ParameterResolver
import org.junit.jupiter.api.extension.TestExecutionExceptionHandler
import org.junit.jupiter.api.extension.TestInstanceFactory
import org.junit.jupiter.api.extension.TestInstanceFactoryContext
import org.junit.jupiter.api.extension.TestInstancePostProcessor
import org.junit.jupiter.api.extension.TestInstancePreDestroyCallback
import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider
import org.junit.jupiter.api.extension.TestWatcher
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Parameter
import java.util.Optional
/**
 * A custom [Extension] that groups several other extensions behind a single registration.
 *
 * This is useful when you want the same set of extensions applied to many test classes
 * without creating a base class for them to inherit from.
 *
 * Each extension point implemented here fans out to the registered extensions that implement
 * the same point:
 * - "before"-style callbacks run in registration order and stop at the first failure;
 * - "after"-style callbacks, pre-destroy callbacks and [TestWatcher]s run in reverse
 *   registration order (watchers in registration order) and are ALL invoked even if one fails;
 *   the first failure is rethrown with later ones attached as suppressed exceptions.
 *
 * Notes:
 * - [InvocationInterceptor] and [TestTemplateInvocationContextProvider] extensions are rejected
 *   with an [IllegalStateException].
 * - Because this class implements [TestInstanceFactory], JUnit treats it as the instance
 *   factory of any class it is registered on. At most one grouped extension may implement
 *   [TestInstanceFactory].
 * - This class does no synchronization of its own.
 **/
class CompositeExtension(
    private var extensions: List<Extension> = emptyList()
) : ExecutionCondition,
    TestInstanceFactory,
    TestInstancePostProcessor,
    TestInstancePreDestroyCallback,
    ParameterResolver,
    TestWatcher,
    BeforeAllCallback,
    BeforeEachCallback,
    BeforeTestExecutionCallback,
    AfterTestExecutionCallback,
    AfterEachCallback,
    AfterAllCallback,
    TestExecutionExceptionHandler,
    LifecycleMethodExecutionExceptionHandler {

    /**
     * Remembers which grouped resolver answered `true` in [supportsParameter], so that
     * [resolveParameter] delegates to that same resolver.
     */
    private val parameterResolverLookup = mutableMapOf<Parameter, ParameterResolver>()

    init {
        // The unsupported-type list lives in the companion object, so it is guaranteed to be
        // initialized here (an instance property declared after this block would still be null).
        extensions.forEach(::validateExtension)
    }

    /**
     * Adds [extension] to this group and returns `this` for chaining.
     *
     * @throws IllegalStateException if [extension] implements an unsupported extension type;
     *   in that case the group is left unchanged.
     */
    fun register(extension: Extension): CompositeExtension {
        // Validate first so a rejected extension never ends up in the list.
        validateExtension(extension)
        extensions += extension
        return this
    }

    /** Throws [IllegalStateException] if [extension] implements a type this class cannot compose. */
    private fun validateExtension(extension: Extension) {
        val firstUnsupportedExtensionType = UNSUPPORTED_EXTENSION_TYPES.firstOrNull {
            it.isAssignableFrom(extension::class.java)
        }
        if (firstUnsupportedExtensionType != null) {
            throw IllegalStateException(
                "Extensions of type ${firstUnsupportedExtensionType.name} are not supported yet!"
            )
        }
    }

    /**
     * Returns the first disabled result from the grouped conditions, evaluating them in
     * registration order and stopping there (as JUnit does). If every condition is enabled,
     * the last condition's result is returned. If no conditions are grouped, the result is
     * "enabled" rather than a crash.
     */
    override fun evaluateExecutionCondition(context: ExtensionContext): ConditionEvaluationResult {
        var lastEnabled: ConditionEvaluationResult? = null
        for (condition in findExtensions<ExecutionCondition>()) {
            val result = condition.evaluateExecutionCondition(context)
            if (result.isDisabled) return result
            lastEnabled = result
        }
        return lastEnabled
            ?: ConditionEvaluationResult.enabled("No ExecutionCondition registered in CompositeExtension")
    }

    /**
     * Delegates to the single grouped [TestInstanceFactory] if there is one. Otherwise the
     * test class is instantiated directly.
     *
     * @throws IllegalStateException if more than one grouped extension is a [TestInstanceFactory].
     */
    override fun createTestInstance(
        factoryContext: TestInstanceFactoryContext,
        extensionContext: ExtensionContext
    ): Any {
        val instanceFactories = findExtensions<TestInstanceFactory>()
        return when {
            instanceFactories.isEmpty() -> instantiateTestClass(factoryContext)
            instanceFactories.size > 1 -> {
                val foundFactories = instanceFactories.joinToString { it.javaClass.name }
                throw IllegalStateException(
                    "Only one `TestInstanceFactory` can be registered! Found: [$foundFactories]"
                )
            }
            else -> instanceFactories.single().createTestInstance(factoryContext, extensionContext)
        }
    }

    /**
     * Instantiates the test class through its declared constructor.
     *
     * This is the fallback used when no grouped [TestInstanceFactory] exists. For `@Nested`
     * (inner) classes, JUnit supplies the enclosing instance, which is passed to the inner
     * class's constructor. Exceptions thrown by the constructor are unwrapped so they propagate
     * as-is, as the deprecated `Class.newInstance()` did.
     */
    private fun instantiateTestClass(factoryContext: TestInstanceFactoryContext): Any {
        val testClass = factoryContext.testClass
        val outerInstance: Any? = factoryContext.outerInstance.orElse(null)
        val constructor = if (outerInstance == null) {
            testClass.getDeclaredConstructor()
        } else {
            testClass.getDeclaredConstructor(testClass.enclosingClass)
        }
        constructor.isAccessible = true
        return try {
            if (outerInstance == null) constructor.newInstance() else constructor.newInstance(outerInstance)
        } catch (e: InvocationTargetException) {
            throw e.targetException
        }
    }

    /** Runs every grouped post-processor in registration order, stopping at the first failure. */
    override fun postProcessTestInstance(testInstance: Any, context: ExtensionContext) {
        findExtensions<TestInstancePostProcessor>().forEach { it.postProcessTestInstance(testInstance, context) }
    }

    /** Runs every grouped pre-destroy callback, even if an earlier one throws. */
    override fun preDestroyTestInstance(context: ExtensionContext) {
        invokeAll(findExtensions<TestInstancePreDestroyCallback>()) { it.preDestroyTestInstance(context) }
    }

    /**
     * Reports whether any grouped resolver supports the parameter. The first one that does
     * is cached for [resolveParameter].
     */
    override fun supportsParameter(parameterContext: ParameterContext, extensionContext: ExtensionContext): Boolean {
        val resolver = findParameterResolver(parameterContext, extensionContext) ?: return false
        parameterResolverLookup[parameterContext.parameter] = resolver
        return true
    }

    /**
     * Resolves the parameter with the resolver chosen in [supportsParameter]. If nothing is
     * cached, the grouped resolvers are searched again.
     *
     * @throws ParameterResolutionException if no grouped resolver supports the parameter.
     */
    override fun resolveParameter(parameterContext: ParameterContext, extensionContext: ExtensionContext): Any {
        val resolver = parameterResolverLookup[parameterContext.parameter]
            ?: findParameterResolver(parameterContext, extensionContext)
            ?: throw ParameterResolutionException(
                "No registered ParameterResolver supports parameter [${parameterContext.parameter}]"
            )
        return resolver.resolveParameter(parameterContext, extensionContext)
    }

    /** Returns the first grouped resolver that supports the parameter, or `null`. */
    private fun findParameterResolver(
        parameterContext: ParameterContext,
        extensionContext: ExtensionContext
    ): ParameterResolver? = findExtensions<ParameterResolver>().firstOrNull {
        it.supportsParameter(parameterContext, extensionContext)
    }

    override fun testDisabled(context: ExtensionContext, reason: Optional<String>) {
        invokeAll(findExtensions<TestWatcher>()) { it.testDisabled(context, reason) }
    }

    override fun testSuccessful(context: ExtensionContext) {
        invokeAll(findExtensions<TestWatcher>()) { it.testSuccessful(context) }
    }

    override fun testAborted(context: ExtensionContext, cause: Throwable?) {
        invokeAll(findExtensions<TestWatcher>()) { it.testAborted(context, cause) }
    }

    override fun testFailed(context: ExtensionContext?, cause: Throwable?) {
        invokeAll(findExtensions<TestWatcher>()) { it.testFailed(context, cause) }
    }

    override fun beforeAll(context: ExtensionContext) {
        findExtensions<BeforeAllCallback>().forEach { it.beforeAll(context) }
    }

    override fun beforeEach(context: ExtensionContext) {
        findExtensions<BeforeEachCallback>().forEach { it.beforeEach(context) }
    }

    override fun beforeTestExecution(context: ExtensionContext) {
        findExtensions<BeforeTestExecutionCallback>().forEach { it.beforeTestExecution(context) }
    }

    override fun afterTestExecution(context: ExtensionContext) {
        // Reversed to retain JUnit 5's wrapping behaviour; every callback runs even if one throws.
        invokeAll(findExtensions<AfterTestExecutionCallback>().reversed()) { it.afterTestExecution(context) }
    }

    override fun afterEach(context: ExtensionContext) {
        // Reversed to retain JUnit 5's wrapping behaviour; every callback runs even if one throws.
        invokeAll(findExtensions<AfterEachCallback>().reversed()) { it.afterEach(context) }
    }

    override fun afterAll(context: ExtensionContext) {
        // Reversed to retain JUnit 5's wrapping behaviour; every callback runs even if one throws.
        invokeAll(findExtensions<AfterAllCallback>().reversed()) { it.afterAll(context) }
    }

    override fun handleTestExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers(findExtensions<TestExecutionExceptionHandler>(), throwable) { handler, pending ->
            handler.handleTestExecutionException(context, pending)
        }
    }

    override fun handleBeforeAllMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers(findExtensions<LifecycleMethodExecutionExceptionHandler>(), throwable) { handler, pending ->
            handler.handleBeforeAllMethodExecutionException(context, pending)
        }
    }

    override fun handleBeforeEachMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers(findExtensions<LifecycleMethodExecutionExceptionHandler>(), throwable) { handler, pending ->
            handler.handleBeforeEachMethodExecutionException(context, pending)
        }
    }

    override fun handleAfterEachMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers(findExtensions<LifecycleMethodExecutionExceptionHandler>(), throwable) { handler, pending ->
            handler.handleAfterEachMethodExecutionException(context, pending)
        }
    }

    override fun handleAfterAllMethodExecutionException(context: ExtensionContext, throwable: Throwable) {
        chainExceptionHandlers(findExtensions<LifecycleMethodExecutionExceptionHandler>(), throwable) { handler, pending ->
            handler.handleAfterAllMethodExecutionException(context, pending)
        }
    }

    /** Returns the grouped extensions implementing [T], in registration order. */
    private inline fun <reified T> findExtensions(): List<T> {
        return extensions.filterIsInstance<T>()
    }

    /**
     * Passes [throwable] through [handlers] the way JUnit chains exception handlers.
     *
     * Each handler must do one of the following:
     * 1. Swallow the supplied throwable, preventing propagation. Later handlers are then
     *    skipped and this function returns normally.
     * 2. Rethrow the supplied throwable as-is.
     * 3. Throw a new exception, possibly wrapping the supplied one.
     *
     * In cases 2 and 3, the next handler receives whatever was thrown. If the last handler
     * (or an empty handler list) leaves a throwable pending, it is rethrown.
     */
    private inline fun <T> chainExceptionHandlers(
        handlers: List<T>,
        throwable: Throwable,
        handle: (handler: T, pending: Throwable) -> Unit
    ) {
        var pending = throwable
        for (handler in handlers) {
            pending = tryAndReturnException { handle(handler, pending) } ?: return
        }
        throw pending
    }

    /**
     * Invokes [action] on every element of [targets], even when some invocations throw.
     * The first failure is rethrown afterwards, with any later failures attached via
     * [Throwable.addSuppressed].
     */
    private inline fun <T> invokeAll(targets: List<T>, action: (T) -> Unit) {
        var firstFailure: Throwable? = null
        for (target in targets) {
            try {
                action(target)
            } catch (e: Throwable) {
                val failure = firstFailure
                if (failure == null) {
                    firstFailure = e
                } else if (failure !== e) {
                    // Guard: a Throwable cannot suppress itself (addSuppressed would throw).
                    failure.addSuppressed(e)
                }
            }
        }
        firstFailure?.let { throw it }
    }

    /** Runs [block] and returns whatever it threw, or `null` if it completed normally. */
    private inline fun tryAndReturnException(block: () -> Unit): Throwable? =
        try {
            block()
            null
        } catch (e: Throwable) {
            e
        }

    private companion object {
        /** Extension types that cannot be meaningfully composed by this class. */
        val UNSUPPORTED_EXTENSION_TYPES: Set<Class<out Extension>> = setOf(
            InvocationInterceptor::class.java,
            TestTemplateInvocationContextProvider::class.java
        )
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment